forked from timbmg/Sentence-VAE
Commit 8e2d8b7 (0 parents)
Showing 6 changed files with 681 additions and 0 deletions.
@@ -0,0 +1 @@
https://arxiv.org/abs/1511.06349
@@ -0,0 +1,74 @@
import os
import json
import torch
import argparse

from model import SentenceVAE
from utils import to_var, idx2word, interpolate


def main(args):

    # load the vocabulary built during preprocessing
    with open(args.data_dir + '/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % args.load_checkpoint)

    model.eval()

    # sample sentences from the prior
    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # decode an interpolation between two random points in latent space
    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('-c', '--load_checkpoint', type=str)
    parser.add_argument('-n', '--num_samples', type=int, default=10)

    parser.add_argument('-dd', '--data_dir', type=str, default='data')
    parser.add_argument('-ms', '--max_sequence_length', type=int, default=50)
    parser.add_argument('-eb', '--embedding_size', type=int, default=300)
    parser.add_argument('-rnn', '--rnn_type', type=str, default='gru')
    parser.add_argument('-hs', '--hidden_size', type=int, default=256)
    parser.add_argument('-wd', '--word_dropout', type=float, default=0.5)
    parser.add_argument('-ls', '--latent_size', type=int, default=16)
    parser.add_argument('-nl', '--num_layers', type=int, default=1)
    parser.add_argument('-bi', '--bidirectional', action='store_true')

    args = parser.parse_args()

    args.rnn_type = args.rnn_type.lower()

    assert args.rnn_type in ['rnn', 'lstm', 'gru']

    main(args)
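
Note: the interpolate helper imported from utils above is not among the hunks shown in this commit. A minimal sketch of what it is assumed to do (evenly spaced linear interpolation between two latent vectors, endpoints included, so model.inference(z=...) receives a (steps + 2, latent_size) batch) might look like this; it is illustrative, not the committed implementation:

import numpy as np

def interpolate(start, end, steps):
    # one linearly spaced path per latent dimension, endpoints included
    interpolation = np.zeros((start.shape[0], steps + 2))
    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps + 2)
    # transpose to (steps + 2, latent_size) so each row is one latent vector
    return interpolation.T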
@@ -0,0 +1,185 @@
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from utils import to_var


class SentenceVAE(nn.Module):

    def __init__(self, vocab_size, embedding_size, rnn_type, hidden_size, word_dropout, latent_size,
                 sos_idx, eos_idx, pad_idx, max_sequence_length, num_layers=1, bidirectional=False):

        super().__init__()
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

        self.max_sequence_length = max_sequence_length
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx

        self.latent_size = latent_size

        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.word_dropout = nn.Dropout(p=word_dropout)

        if rnn_type == 'rnn':
            rnn = nn.RNN
        elif rnn_type == 'gru':
            rnn = nn.GRU
        # elif rnn_type == 'lstm':
        #     rnn = nn.LSTM
        else:
            raise ValueError()

        self.encoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, batch_first=True)
        self.decoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional, batch_first=True)

        # number of hidden state slices produced by the encoder RNN
        self.hidden_factor = (2 if bidirectional else 1) * num_layers

        self.hidden2mean = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.hidden2logv = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.latent2hidden = nn.Linear(latent_size, hidden_size * self.hidden_factor)
        self.outputs2vocab = nn.Linear(hidden_size * (2 if bidirectional else 1), vocab_size)

    def forward(self, input_sequence, length):

        batch_size = input_sequence.size(0)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        # ENCODER
        input_embedding = self.embedding(input_sequence)

        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
        else:
            hidden = hidden.squeeze()

        # REPARAMETERIZATION
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)

        z = to_var(torch.randn([batch_size, self.latent_size]))
        z = z * std + mean

        # DECODER
        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # decoder input
        input_embedding = self.word_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)

        # process outputs
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _, reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b, s, _ = padded_outputs.size()

        # project outputs to vocab
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
        logp = logp.view(b, s, self.embedding.num_embeddings)

        return logp, mean, logv, z

    def inference(self, n=4, z=None):

        if z is None:
            batch_size = n
            z = to_var(torch.randn([batch_size, self.latent_size]))
        else:
            batch_size = z.size(0)

        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)

        hidden = hidden.unsqueeze(0)

        # required for dynamic stopping of sentence generation
        sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch which are still generating
        sequence_mask = torch.ones(batch_size, out=self.tensor()).byte()

        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()  # idx of still generating sequences with respect to current loop

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:

            if t == 0:
                input_sequence = to_var(torch.Tensor(batch_size).fill_(self.sos_idx).long())

            input_sequence = input_sequence.unsqueeze(1)

            input_embedding = self.embedding(input_sequence)

            output, hidden = self.decoder_rnn(input_embedding, hidden)

            logits = self.outputs2vocab(output)

            input_sequence = self._sample(logits)

            # save next input
            generations = self._save_sample(generations, input_sequence, sequence_running, t)

            # update global running sequences
            sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
            sequence_running = sequence_idx.masked_select(sequence_mask)

            # update local running sequences
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)

            # prune input and hidden state according to local update
            if len(running_seqs) > 0:
                input_sequence = input_sequence[running_seqs]
                hidden = hidden[:, running_seqs]

                running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()

            t += 1

        return generations, z

    def _sample(self, dist, mode='greedy'):

        if mode == 'greedy':
            _, sample = torch.topk(dist, 1, dim=-1)
            sample = sample.squeeze()

        return sample

    def _save_sample(self, save_to, sample, running_seqs, t):
        # select only still running
        running_latest = save_to[running_seqs]
        # update token at position t
        running_latest[:, t] = sample.data
        # save back
        save_to[running_seqs] = running_latest

        return save_to
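
Note: the training objective is not part of the files shown in this commit. Purely as an assumed sketch of how the values returned by forward() (logp, mean, logv, z) would feed the usual VAE objective from the paper linked above (token-level NLL plus a KL term); the function name and kl_weight parameter are illustrative, not taken from this commit:

import torch

def vae_loss(logp, target, mean, logv, kl_weight=1.0):
    # reconstruction: negative log-likelihood of the target tokens under the
    # decoder's log-softmax outputs (padding positions would normally be
    # excluded, e.g. via ignore_index=pad_idx)
    target = target[:, :logp.size(1)].contiguous().view(-1)
    logp = logp.view(-1, logp.size(2))
    nll = torch.nn.functional.nll_loss(logp, target, reduction='sum')

    # KL divergence between q(z|x) = N(mean, exp(logv)) and the prior N(0, I)
    kl = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

    # kl_weight is typically annealed from 0 to 1 over the course of training
    return nll + kl_weight * kl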