Commit

a slightly better discriminator
suragnair committed Jul 18, 2017
1 parent d64d1a6 commit cd40aed
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions discriminator.py
@@ -14,13 +14,13 @@ def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False
         self.gpu = gpu
 
         self.embeddings = nn.Embedding(vocab_size, embedding_dim)
-        self.gru = nn.GRU(embedding_dim, hidden_dim, dropout=dropout)
-        self.gru2hidden = nn.Linear(hidden_dim, hidden_dim)
+        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout)
+        self.gru2hidden = nn.Linear(2*2*hidden_dim, hidden_dim)
         self.dropout_linear = nn.Dropout(p=dropout)
         self.hidden2out = nn.Linear(hidden_dim, 1)
 
     def init_hidden(self, batch_size):
-        h = autograd.Variable(torch.zeros(1, batch_size, self.hidden_dim))
+        h = autograd.Variable(torch.zeros(2*2*1, batch_size, self.hidden_dim))
 
         if self.gpu:
             return h.cuda()
@@ -31,8 +31,9 @@ def forward(self, input, hidden):
         # input dim                                                # batch_size x seq_len
         emb = self.embeddings(input)                               # batch_size x seq_len x embedding_dim
         emb = emb.permute(1, 0, 2)                                 # seq_len x batch_size x embedding_dim
-        _, hidden = self.gru(emb, hidden)                          # 1 x batch_size x hidden_dim (out)
-        out = self.gru2hidden(hidden.view(-1, self.hidden_dim))    # batch_size x hidden_dim
+        _, hidden = self.gru(emb, hidden)                          # 4 x batch_size x hidden_dim
+        hidden = hidden.permute(1, 0, 2)                           # batch_size x 4 x hidden_dim
+        out = self.gru2hidden(hidden.view(-1, 4*self.hidden_dim))  # batch_size x hidden_dim*4
         out = F.relu(out)
         out = self.dropout_linear(out)
         out = self.hidden2out(out)                                 # batch_size x 1
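For reference, here is a minimal, self-contained sketch of how the discriminator reads after this change. It is a reconstruction under assumptions, not the repository file: it targets a modern PyTorch API (plain tensors instead of autograd.Variable), adds a .contiguous() before .view, and the final sigmoid plus the sizes in the shape check are illustrative choices of this sketch. The 2*2 factor is num_layers * num_directions for the two-layer bidirectional GRU, so the GRU hands back 4 final hidden states per sequence.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, dropout=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gpu = gpu

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # 2 layers x 2 directions -> 4 final hidden states per sequence
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout)
        self.gru2hidden = nn.Linear(2 * 2 * hidden_dim, hidden_dim)
        self.dropout_linear = nn.Dropout(p=dropout)
        self.hidden2out = nn.Linear(hidden_dim, 1)

    def init_hidden(self, batch_size):
        # num_layers * num_directions = 2 * 2 = 4 hidden-state slots
        h = torch.zeros(2 * 2 * 1, batch_size, self.hidden_dim)
        return h.cuda() if self.gpu else h

    def forward(self, input, hidden):
        emb = self.embeddings(input)                # batch_size x seq_len x embedding_dim
        emb = emb.permute(1, 0, 2)                  # seq_len x batch_size x embedding_dim
        _, hidden = self.gru(emb, hidden)           # 4 x batch_size x hidden_dim
        hidden = hidden.permute(1, 0, 2)            # batch_size x 4 x hidden_dim
        out = self.gru2hidden(hidden.contiguous().view(-1, 4 * self.hidden_dim))  # batch_size x hidden_dim
        out = F.relu(out)
        out = self.dropout_linear(out)
        out = self.hidden2out(out)                  # batch_size x 1
        # The diff above is truncated here; squashing to (0, 1) with a sigmoid
        # is this sketch's choice for a real/fake probability.
        return torch.sigmoid(out)


# Quick shape check with made-up sizes
if __name__ == "__main__":
    d = Discriminator(embedding_dim=32, hidden_dim=64, vocab_size=1000, max_seq_len=20)
    tokens = torch.randint(0, 1000, (8, 20))        # batch_size x seq_len
    h0 = d.init_hidden(batch_size=8)
    print(d(tokens, h0).shape)                      # torch.Size([8, 1])

Concatenating all four final hidden states (both layers, both directions) gives the classifier a richer summary of the sequence than the single final state used before, which is why gru2hidden and init_hidden both pick up the 2*2 factor.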
