define_lstm_model.py
import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_layer_size=32, num_layers=2, output_size=1, dropout=0.2):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size

        # project the raw input up to the LSTM's feature size
        self.linear_1 = nn.Linear(input_size, hidden_layer_size)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(hidden_layer_size, hidden_size=self.hidden_layer_size,
                            num_layers=num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # the final hidden states of all LSTM layers are concatenated before the output layer
        self.linear_2 = nn.Linear(num_layers * hidden_layer_size, output_size)
        self.init_weights()

    def init_weights(self):
        for name, param in self.lstm.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0.0)
            elif 'weight_ih' in name:
                nn.init.kaiming_normal_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)

    def forward(self, x):
        batchsize = x.shape[0]

        # layer 1
        x = self.linear_1(x)
        x = self.relu(x)

        # LSTM layer
        lstm_out, (h_n, c_n) = self.lstm(x)

        # reshape the final hidden state from [num_layers, batch, hidden]
        # into [batch, num_layers * hidden] for `linear_2`
        x = h_n.permute(1, 0, 2).reshape(batchsize, -1)

        # layer 2
        x = self.dropout(x)
        predictions = self.linear_2(x)
        return predictions[:, -1]
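
A minimal usage sketch of the model above; the batch size, window length, and random input below are illustrative, not from the original file:

# illustrative smoke test: 16 univariate sequences of 20 time steps each
model = LSTMModel(input_size=1, hidden_layer_size=32, num_layers=2, output_size=1, dropout=0.2)
x = torch.randn(16, 20, 1)   # [batch, sequence length, input_size]
out = model(x)
print(out.shape)             # torch.Size([16]) -- one prediction per sequence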