
Commit

first commit
timbmg committed Mar 5, 2018
0 parents commit 8e2d8b7
Showing 6 changed files with 681 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
https://arxiv.org/abs/1511.06349
74 changes: 74 additions & 0 deletions inference.py
@@ -0,0 +1,74 @@
import os
import json
import torch
import argparse

from model import SentenceVAE
from utils import to_var, idx2word, interpolate


def main(args):

    with open(args.data_dir+'/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
    )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % (args.load_checkpoint))

    model.eval()

    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-c', '--load_checkpoint', type=str)
    parser.add_argument('-n', '--num_samples', type=int, default=10)

    parser.add_argument('-dd', '--data_dir', type=str, default='data')
    parser.add_argument('-ms', '--max_sequence_length', type=int, default=50)
    parser.add_argument('-eb', '--embedding_size', type=int, default=300)
    parser.add_argument('-rnn', '--rnn_type', type=str, default='gru')
    parser.add_argument('-hs', '--hidden_size', type=int, default=256)
    parser.add_argument('-wd', '--word_dropout', type=float, default=0.5)
    parser.add_argument('-ls', '--latent_size', type=int, default=16)
    parser.add_argument('-nl', '--num_layers', type=int, default=1)
    parser.add_argument('-bi', '--bidirectional', action='store_true')

    args = parser.parse_args()

    args.rnn_type = args.rnn_type.lower()

    # note: model.py in this commit only builds 'rnn' and 'gru'; its 'lstm' branch is commented out
    assert args.rnn_type in ['rnn', 'lstm', 'gru']

    main(args)
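
inference.py imports to_var, idx2word and interpolate from utils.py, which is one of the six changed files but is not expanded above. The sketch below is only a guess at those helpers, inferred from their call sites here (a pre-0.4 Variable wrapper, string vocab keys because the vocab is loaded from JSON, and plain linear interpolation between the two latent points); the actual utils.py in this commit may differ.

import numpy as np
import torch
from torch.autograd import Variable


def to_var(x):
    # wrap a tensor in a Variable (pre-0.4 PyTorch) and move it to the GPU when one is available
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def idx2word(idx, i2w, pad_idx):
    # map a batch of index sequences back to whitespace-joined sentences,
    # cutting each sentence at the first <pad> token
    sentences = []
    for sent in idx:
        words = []
        for token in sent:
            token = int(token)
            if token == pad_idx:
                break
            words.append(i2w[str(token)])  # vocab was loaded from JSON, so keys are strings
        sentences.append(' '.join(words))
    return sentences


def interpolate(start, end, steps):
    # linear interpolation between two latent vectors; returns an array of
    # shape (steps + 2, latent_size) whose rows include both endpoints
    interpolation = np.zeros((start.shape[0], steps + 2))
    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps + 2)
    return interpolation.T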
185 changes: 185 additions & 0 deletions model.py
@@ -0,0 +1,185 @@
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from utils import to_var

class SentenceVAE(nn.Module):

    def __init__(self, vocab_size, embedding_size, rnn_type, hidden_size, word_dropout, latent_size,
                 sos_idx, eos_idx, pad_idx, max_sequence_length, num_layers=1, bidirectional=False):

        super().__init__()
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

        self.max_sequence_length = max_sequence_length
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx

        self.latent_size = latent_size

        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.word_dropout = nn.Dropout(p=word_dropout)

        if rnn_type == 'rnn':
            rnn = nn.RNN
        elif rnn_type == 'gru':
            rnn = nn.GRU
        # elif rnn_type == 'lstm':
        #     rnn = nn.LSTM  # left out for now: the LSTM hidden state is an (h, c) tuple and would need extra handling below
        else:
            raise ValueError("Unsupported rnn_type: %s" % rnn_type)

        self.encoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, batch_first=True)
        self.decoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, batch_first=True)

        self.hidden_factor = (2 if bidirectional else 1) * num_layers

        self.hidden2mean = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.hidden2logv = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.latent2hidden = nn.Linear(latent_size, hidden_size * self.hidden_factor)
        self.outputs2vocab = nn.Linear(hidden_size * (2 if bidirectional else 1), vocab_size)

    def forward(self, input_sequence, length):

        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            hidden = hidden.squeeze()

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean

        # DECODER
        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # decoder input
        input_embedding = self.word_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # project outputs to vocab
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
        logp = logp.view(b, s, self.embedding.num_embeddings)

        return logp, mean, logv, z
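
    # forward() returns everything an ELBO-style objective needs: `logp` feeds the
    # reconstruction NLL against the target tokens, while `mean` and `logv` give the
    # Gaussian KL term  KL = -0.5 * sum(1 + logv - mean^2 - exp(logv)),  which the
    # linked paper anneals over the course of training. The training loop itself is
    # in one of the changed files that is not expanded above.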


    def inference(self, n=4, z=None):

        if z is None:
            batch_size = n
            z = to_var(torch.randn([batch_size, self.latent_size]))
        else:
            batch_size = z.size(0)

        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            # mirror forward(): only add the layer dimension when it is not already there
            hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch which are still generating
        sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()

        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()  # idx of still generating sequences with respect to current loop

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:

            if t == 0:
                input_sequence = to_var(torch.Tensor(batch_size).fill_(self.sos_idx).long())

            input_sequence = input_sequence.unsqueeze(1)

            input_embedding = self.embedding(input_sequence)

            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)

            input_sequence = self._sample(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence, sequence_running, t)

            # update global running sequences
            sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                input_sequence = input_sequence[running_seqs]
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()

            t += 1

        return generations, z

    def _sample(self, dist, mode='greedy'):

        if mode == 'greedy':
            _, sample = torch.topk(dist, 1, dim=-1)
            sample = sample.squeeze()

        return sample

    def _save_sample(self, save_to, sample, running_seqs, t):
        # select only still running
        running_latest = save_to[running_seqs]
        # update token at position t
        running_latest[:, t] = sample.data
        # save back
        save_to[running_seqs] = running_latest

        return save_to
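
Once a checkpoint exists, the sampling and interpolation in inference.py can be run with the flags defined there, for example (the checkpoint path below is just a placeholder):

python inference.py -c <path/to/checkpoint.pt> -n 10

Note that the model flags (-eb, -rnn, -hs, -wd, -ls, -nl, -bi) have to match the values used when the checkpoint was trained, otherwise load_state_dict will fail on mismatched parameter shapes.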
