diff --git a/modules.py b/modules.py
index 4ba8fe9..320423a 100644
--- a/modules.py
+++ b/modules.py
@@ -5,7 +5,7 @@
 
 
 def init_hidden(x, hidden_size: int):
-    return Variable(x.data.new(1, x.size(0), hidden_size).zero_())
+    return Variable(torch.zeros(1, x.size(0), hidden_size))
 
 
 class Encoder(nn.Module):
@@ -25,25 +25,24 @@ def __init__(self, input_size: int, hidden_size: int, T: int):
         self.attn_linear = nn.Linear(in_features=2 * hidden_size + T - 1, out_features=1)
 
     def forward(self, input_data):
-        # input_data: batch_size * T - 1 * input_size
-        input_weighted = Variable(input_data.data.new(input_data.size(0), self.T - 1, self.input_size).zero_())
-        input_encoded = Variable(input_data.data.new(input_data.size(0), self.T - 1, self.hidden_size).zero_())
+        # input_data: (batch_size, T - 1, input_size)
+        input_weighted = Variable(torch.zeros(input_data.size(0), self.T - 1, self.input_size))
+        input_encoded = Variable(torch.zeros(input_data.size(0), self.T - 1, self.hidden_size))
         # hidden, cell: initial states with dimension hidden_size
         hidden = init_hidden(input_data, self.hidden_size)  # 1 * batch_size * hidden_size
         cell = init_hidden(input_data, self.hidden_size)
-        # hidden.requires_grad = False
-        # cell.requires_grad = False
+
         for t in range(self.T - 1):
             # Eqn. 8: concatenate the hidden states with each predictor
             x = torch.cat((hidden.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                            cell.repeat(self.input_size, 1, 1).permute(1, 0, 2),
                            input_data.permute(0, 2, 1)), dim=2)  # batch_size * input_size * (2*hidden_size + T - 1)
-            # Eqn. 9: Get attention weights
+            # Eqn. 8: Get attention weights
            x = self.attn_linear(x.view(-1, self.hidden_size * 2 + self.T - 1))  # (batch_size * input_size) * 1
-            attn_weights = tf.softmax(
-                x.view(-1, self.input_size))  # batch_size * input_size, attn weights with values sum up to 1.
+            # Eqn. 9: Softmax the attention weights
+            attn_weights = tf.softmax(x.view(-1, self.input_size), dim=1)  # (batch_size, input_size)
             # Eqn. 10: LSTM
-            weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # batch_size * input_size
+            weighted_input = torch.mul(attn_weights, input_data[:, t, :])  # (batch_size, input_size)
             # Fix the warning about non-contiguous memory
             # see https://discuss.pytorch.org/t/dataparallel-issue-with-flatten-parameter/8282
             self.lstm_layer.flatten_parameters()
@@ -75,34 +74,34 @@ def __init__(self, encoder_hidden_size: int, decoder_hidden_size: int, T: int):
         self.fc.weight.data.normal_()
 
     def forward(self, input_encoded, y_history):
-        # input_encoded: batch_size * T - 1 * encoder_hidden_size
-        # y_history: batch_size * (T-1)
+        # input_encoded: (batch_size, T - 1, encoder_hidden_size)
+        # y_history: (batch_size, (T-1))
         # Initialize hidden and cell, 1 * batch_size * decoder_hidden_size
         hidden = init_hidden(input_encoded, self.decoder_hidden_size)
         cell = init_hidden(input_encoded, self.decoder_hidden_size)
-        # hidden.requires_grad = False
-        # cell.requires_grad = False
+        context = Variable(torch.zeros(input_encoded.size(0), self.encoder_hidden_size))
+
         for t in range(self.T - 1):
             # Eqn. 12-13: compute attention weights
-            # batch_size * T * (2*decoder_hidden_size + encoder_hidden_size)
+            # (batch_size, T, (2*decoder_hidden_size + encoder_hidden_size))
             x = torch.cat((hidden.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                            cell.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
                            input_encoded), dim=2)
             x = tf.softmax(
                     self.attn_layer(
                         x.view(-1, 2 * self.decoder_hidden_size + self.encoder_hidden_size)
-                    ).view(-1, self.T - 1))  # batch_size * T - 1, row sum up to 1
+                    ).view(-1, self.T - 1),
+                    dim=1
+                )  # batch_size * T - 1, row sum up to 1
             # Eqn. 14: compute context vector
-            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]  # batch_size * encoder_hidden_size
+            context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :]  # (batch_size, encoder_hidden_size)
-            if t < self.T - 1:
-                # Eqn. 15
-                y_tilde = self.fc(torch.cat((context, y_history[:, t].unsqueeze(1)), dim=1))  # batch_size * 1
-                # Eqn. 16: LSTM
-                self.lstm_layer.flatten_parameters()
-                _, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
-                hidden = lstm_output[0]  # 1 * batch_size * decoder_hidden_size
-                cell = lstm_output[1]  # 1 * batch_size * decoder_hidden_size
+            # Eqn. 15
+            y_tilde = self.fc(torch.cat((context, y_history[:, t].unsqueeze(1)), dim=1))  # batch_size * 1
+            # Eqn. 16: LSTM
+            self.lstm_layer.flatten_parameters()
+            _, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
+            hidden = lstm_output[0]  # 1 * batch_size * decoder_hidden_size
+            cell = lstm_output[1]  # 1 * batch_size * decoder_hidden_size
 
         # Eqn. 22: final output
-        y_pred = self.fc_final(torch.cat((hidden[0], context), dim=1))
-        return y_pred
+        return self.fc_final(torch.cat((hidden[0], context), dim=1))
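
A minimal smoke test for the patched classes is sketched below (it is not part of the patch). It assumes modules.py exposes Encoder and Decoder with the constructor signatures visible in the hunks above, that Encoder.forward returns the (input_weighted, input_encoded) pair it builds, and that Decoder.forward yields a (batch_size, 1) prediction; the batch size, window length T, and number of driving series are illustrative values only.

# Hypothetical usage sketch, not part of the patch: runs the patched
# forward passes on random data and checks the output shapes.
import torch
from torch.autograd import Variable

from modules import Encoder, Decoder

batch_size, T, input_size = 16, 10, 81           # illustrative sizes only
encoder_hidden_size = decoder_hidden_size = 64   # illustrative sizes only

encoder = Encoder(input_size=input_size, hidden_size=encoder_hidden_size, T=T)
decoder = Decoder(encoder_hidden_size=encoder_hidden_size,
                  decoder_hidden_size=decoder_hidden_size, T=T)

# Driving series x_1..x_{T-1} and target history y_1..y_{T-1}
input_data = Variable(torch.randn(batch_size, T - 1, input_size))
y_history = Variable(torch.randn(batch_size, T - 1))

# Assumes Encoder.forward returns (input_weighted, input_encoded)
_, input_encoded = encoder(input_data)        # (batch_size, T - 1, encoder_hidden_size)
y_pred = decoder(input_encoded, y_history)    # (batch_size, 1)
print(y_pred.size())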