fixed warnings, changed shape annotation, improved initialization
Seanny123 committed Oct 18, 2018
1 parent 7fd08b6 commit 8335f96
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions modules.py
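As background for the diff below, here is a minimal sketch (not part of the commit) of the two PyTorch idioms it swaps: calling softmax without an explicit dim emits a deprecation warning on recent PyTorch releases, and torch.zeros builds the zero initial state directly instead of allocating through an existing tensor's .data.new(). The tensor shapes and the hidden size of 64 are illustrative only.

import torch
import torch.nn.functional as tf
from torch.autograd import Variable

x = torch.randn(16, 5)                                   # illustrative (batch_size, input_size) tensor
w_old = tf.softmax(x)                                    # implicit dim: emits a deprecation UserWarning
w_new = tf.softmax(x, dim=1)                             # explicit dim: warning-free, each row sums to 1

h_old = Variable(x.data.new(1, x.size(0), 64).zero_())   # old init style: allocate via an existing tensor
h_new = Variable(torch.zeros(1, x.size(0), 64))          # new init style: plain zero tensor, same shape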
@@ -5,7 +5,7 @@
 
 
 def init_hidden(x, hidden_size: int):
-    return Variable(x.data.new(1, x.size(0), hidden_size).zero_())
+    return Variable(torch.zeros(1, x.size(0), hidden_size))
 
 
 class Encoder(nn.Module):
@@ -25,25 +25,24 @@ def __init__(self, input_size: int, hidden_size: int, T: int):
         self.attn_linear = nn.Linear(in_features=2 * hidden_size + T - 1, out_features=1)
 
     def forward(self, input_data):
-        # input_data: batch_size * T - 1 * input_size
-        input_weighted = Variable(input_data.data.new(input_data.size(0), self.T - 1, self.input_size).zero_())
-        input_encoded = Variable(input_data.data.new(input_data.size(0), self.T - 1, self.hidden_size).zero_())
+        # input_data: (batch_size, T - 1, input_size)
+        input_weighted = Variable(torch.zeros(input_data.size(0), self.T - 1, self.input_size))
+        input_encoded = Variable(torch.zeros(input_data.size(0), self.T - 1, self.hidden_size))
         # hidden, cell: initial states with dimension hidden_size
         hidden = init_hidden(input_data, self.hidden_size)  # 1 * batch_size * hidden_size
         cell = init_hidden(input_data, self.hidden_size)
-        # hidden.requires_grad = False
-        # cell.requires_grad = False
+
         for t in range(self.T - 1):
             # Eqn. 8: concatenate the hidden states with each predictor
             x = torch.cat((hidden.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                            cell.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                            input_data.permute(0, 2, 1)), dim=2)  # batch_size * input_size * (2*hidden_size + T - 1)
-            # Eqn. 9: Get attention weights
+            # Eqn. 8: Get attention weights
             x = self.attn_linear(x.view(-1, self.hidden_size * 2 + self.T - 1))  # (batch_size * input_size) * 1
-            attn_weights = tf.softmax(
-                x.view(-1, self.input_size))  # batch_size * input_size, attn weights with values sum up to 1.
+            # Eqn. 9: Softmax the attention weights
+            attn_weights = tf.softmax(x.view(-1, self.input_size), dim=1)  # (batch_size, input_size)
             # Eqn. 10: LSTM
-            weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # batch_size * input_size
+            weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # (batch_size, input_size)
             # Fix the warning about non-contiguous memory
             # see https://discuss.pytorch.org/t/dataparallel-issue-with-flatten-parameter/8282
             self.lstm_layer.flatten_parameters()
@@ -75,34 +74,34 @@ def __init__(self, encoder_hidden_size: int, decoder_hidden_size: int, T: int):
         self.fc.weight.data.normal_()
 
     def forward(self, input_encoded, y_history):
-        # input_encoded: batch_size * T - 1 * encoder_hidden_size
-        # y_history: batch_size * (T-1)
+        # input_encoded: (batch_size, T - 1, encoder_hidden_size)
+        # y_history: (batch_size, (T-1))
         # Initialize hidden and cell, 1 * batch_size * decoder_hidden_size
         hidden = init_hidden(input_encoded, self.decoder_hidden_size)
         cell = init_hidden(input_encoded, self.decoder_hidden_size)
-        # hidden.requires_grad = False
-        # cell.requires_grad = False
+        context = Variable(torch.zeros(input_encoded.size(0), self.encoder_hidden_size))
+
         for t in range(self.T - 1):
             # Eqn. 12-13: compute attention weights
-            # batch_size * T * (2*decoder_hidden_size + encoder_hidden_size)
+            # (batch_size, T, (2*decoder_hidden_size + encoder_hidden_size))
             x = torch.cat((hidden.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                            cell.repeat(self.T - 1, 1, 1).permute(1, 0, 2), input_encoded), dim=2)
             x = tf.softmax(
                     self.attn_layer(
                         x.view(-1, 2 * self.decoder_hidden_size + self.encoder_hidden_size)
-                    ).view(-1, self.T - 1))  # batch_size * T - 1, row sum up to 1
+                    ).view(-1, self.T - 1),
+                    dim=1
+                )  # batch_size * T - 1, row sum up to 1
             # Eqn. 14: compute context vector
-            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]  # batch_size * encoder_hidden_size
+            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]  # (batch_size, encoder_hidden_size)
 
-            if t < self.T - 1:
-                # Eqn. 15
-                y_tilde = self.fc(torch.cat((context, y_history[:, t].unsqueeze(1)), dim=1))  # batch_size * 1
-                # Eqn. 16: LSTM
-                self.lstm_layer.flatten_parameters()
-                _, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
-                hidden = lstm_output[0]  # 1 * batch_size * decoder_hidden_size
-                cell = lstm_output[1]  # 1 * batch_size * decoder_hidden_size
+            # Eqn. 15
+            y_tilde = self.fc(torch.cat((context, y_history[:, t].unsqueeze(1)), dim=1))  # batch_size * 1
+            # Eqn. 16: LSTM
+            self.lstm_layer.flatten_parameters()
+            _, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
+            hidden = lstm_output[0]  # 1 * batch_size * decoder_hidden_size
+            cell = lstm_output[1]  # 1 * batch_size * decoder_hidden_size
 
         # Eqn. 22: final output
-        y_pred = self.fc_final(torch.cat((hidden[0], context), dim=1))
-        return y_pred
+        return self.fc_final(torch.cat((hidden[0], context), dim=1))
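To exercise the changed code paths end to end, a hedged smoke-test sketch (not from the repository) follows. It assumes modules.py is importable as shown, that Encoder.forward returns the (input_weighted, input_encoded) pair it allocates, and that fc_final maps to a single output value; the batch size, window length T and hidden sizes are arbitrary.

import torch
from torch.autograd import Variable

from modules import Encoder, Decoder              # assumes this file is on the import path

batch_size, T, input_size = 16, 10, 5              # arbitrary illustration values
encoder_hidden_size = decoder_hidden_size = 64

encoder = Encoder(input_size=input_size, hidden_size=encoder_hidden_size, T=T)
decoder = Decoder(encoder_hidden_size=encoder_hidden_size,
                  decoder_hidden_size=decoder_hidden_size, T=T)

# Driving series: (batch_size, T - 1, input_size); target history: (batch_size, T - 1)
X = Variable(torch.randn(batch_size, T - 1, input_size))
y_history = Variable(torch.randn(batch_size, T - 1))

input_weighted, input_encoded = encoder(X)         # assumed return order of Encoder.forward
y_pred = decoder(input_encoded, y_history)         # final prediction from fc_final
print(input_encoded.shape, y_pred.shape)           # expect (16, 9, 64) and (16, 1) under these assumptions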
