From a7a5cdc598afd6afa3d0e9d355360bebc4947c29 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Wed, 25 Jan 2017 02:11:47 +0000 Subject: [PATCH] adding projection layer --- snli/model.py | 45 ++++++++++++++++++++++------------- snli/train.py | 65 ++++++++++++++++++++++++++++++--------------------- snli/util.py | 11 +++++---- 3 files changed, 74 insertions(+), 47 deletions(-) diff --git a/snli/model.py b/snli/model.py index 33087f69d4..4de4f100e2 100644 --- a/snli/model.py +++ b/snli/model.py @@ -22,16 +22,17 @@ class Encoder(nn.Module): def __init__(self, config): super(Encoder, self).__init__() self.config = config - self.rnn = nn.LSTM(input_size=config.d_embed, hidden_size=config.d_hidden, + input_size = config.d_proj if config.projection else config.d_embed + self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden, num_layers=config.n_layers, dropout=config.dp_ratio, - bidirectional=config.bidirectional) + bidirectional=config.birnn) def forward(self, inputs): batch_size = inputs.size()[1] - h0 = Variable(torch.zeros(self.config.n_cells, batch_size, self.config.d_hidden)).cuda() - c0 = Variable(torch.zeros(self.config.n_cells, batch_size, self.config.d_hidden)).cuda() - _, (hn, _) = self.rnn(inputs, (h0, c0)) - return hn[-1] if not self.config.bidirectional else hn[-2:].view(batch_size, -1) + state_shape = self.config.n_cells, batch_size, self.config.d_hidden + h0 = c0 = Variable(inputs.data.new(*state_shape).zero_()) + outputs, (ht, ct) = self.rnn(inputs, (h0, c0)) + return ht[-1] if not self.config.birnn else ht[-2:].view(batch_size, -1) class SNLIClassifier(nn.Module): @@ -40,24 +41,36 @@ def __init__(self, config): super(SNLIClassifier, self).__init__() self.config = config self.embed = nn.Embedding(config.n_embed, config.d_embed) + self.projection = Linear(config.d_embed, config.d_proj) self.encoder = Encoder(config) + self.dropout = nn.Dropout(p=config.dp_ratio) + self.relu = nn.ReLU() seq_in_size = 2*config.d_hidden - if self.config.bidirectional: + if self.config.birnn: seq_in_size *= 2 lin_config = [seq_in_size]*2 self.out = nn.Sequential( Linear(*lin_config), - nn.ReLU(), - nn.Dropout(p=config.dp_ratio), + self.relu, + self.dropout, Linear(*lin_config), - nn.ReLU(), + self.relu, + self.dropout, Linear(*lin_config), - nn.ReLU(), - nn.Dropout(p=config.dp_ratio), + self.relu, + self.dropout, Linear(seq_in_size, config.d_out)) def forward(self, batch): - premise = self.encoder(self.embed(batch.premise)) - hypothesis = self.encoder(self.embed(batch.hypothesis)) - answer = self.out(torch.cat([premise, hypothesis], 1)) - return answer + prem_embed = self.embed(batch.premise) + hypo_embed = self.embed(batch.hypothesis) + if self.config.fix_emb: + prem_embed = Variable(prem_embed.data) + hypo_embed = Variable(hypo_embed.data) + if self.config.projection: + prem_embed = self.relu(self.projection(prem_embed)) + hypo_embed = self.relu(self.projection(hypo_embed)) + premise = self.encoder(prem_embed) + hypothesis = self.encoder(hypo_embed) + scores = self.out(torch.cat([premise, hypothesis], 1)) + return scores diff --git a/snli/train.py b/snli/train.py index 3e4179866b..eefaa56a3f 100644 --- a/snli/train.py +++ b/snli/train.py @@ -1,5 +1,6 @@ import os import time +import glob import torch import torch.optim as O @@ -17,27 +18,28 @@ inputs = data.Field() answers = data.Field(sequential=False) -train, val, test = datasets.SNLI.splits(inputs, answers) +train, dev, test = datasets.SNLI.splits(inputs, answers) -if os.path.isfile(args.vocab_cache): - inputs.build_vocab(train, lower=args.lower) - inputs.vocab.vectors = torch.load(args.vocab_cache) +if args.word_vectors and os.path.isfile(args.vector_cache): + inputs.build_vocab(train, dev, test, lower=args.lower) + inputs.vocab.vectors = torch.load(args.vector_cache) else: - inputs.build_vocab(train, vectors=(args.data_cache, args.word_vectors, args.d_embed), lower=args.lower) - os.makedirs(os.path.dirname(args.vocab_cache), exist_okay=True) - torch.save(inputs.vocab.vectors, args.vocab_cache) + if args.word_vectors: + inputs.build_vocab(train, dev, test, vectors=(args.data_cache, args.word_vectors, args.d_embed), lower=args.lower) + os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True) + torch.save(inputs.vocab.vectors, args.vector_cache) + else: + inputs.build_vocab(train, dev, test, lower=args.lower) answers.build_vocab(train) -train_iter, val_iter, test_iter = data.BucketIterator.splits( - (train, val, test), batch_size=args.batch_size, device=args.gpu) -print(train_iter.batch_size) -print(len(train_iter)) +train_iter, dev_iter, test_iter = data.BucketIterator.splits( + (train, dev, test), batch_size=args.batch_size, device=args.gpu) config = args config.n_embed = len(inputs.vocab) config.d_out = len(answers.vocab) config.n_cells = config.n_layers -if config.bidirectional: +if config.birnn: config.n_cells *= 2 model = SNLIClassifier(config) @@ -51,9 +53,10 @@ start = time.time() best_dev_acc = -1 train_iter.repeat = False -header = ' Time Epoch Iteration Progress (%Epoch) Loss Val/Loss Accuracy Val/Accuracy' -val_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) +header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' +dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(',')) +os.makedirs(args.save_path, exist_ok=True) print(header) for epoch in range(args.epochs): @@ -69,26 +72,34 @@ loss = criterion(answer, batch.label) loss.backward(); opt.step() if iterations % args.save_every == 0: - torch.save(model, os.path.join(args.save_path, - 'snapshot_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations))) - if iterations % args.val_every == 0: - model.eval(); val_iter.init_epoch() + snapshot_prefix = os.path.join(args.save_path, 'snapshot') + snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations) + torch.save(model, snapshot_path) + for f in glob.glob(snapshot_prefix + '*'): + if f != snapshot_path: + os.remove(f) + if iterations % args.dev_every == 0: + model.eval(); dev_iter.init_epoch() n_dev_correct, dev_loss = 0, 0 - for dev_batch_idx, dev_batch in enumerate(val_iter): + for dev_batch_idx, dev_batch in enumerate(dev_iter): answer = model(dev_batch) n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum() dev_loss = criterion(answer, dev_batch.label) - dev_acc = 100. * n_dev_correct / len(val) - print(val_log_template.format(time.time()-start, - epoch, iterations, batch_idx, len(train_iter), - 100. * batch_idx / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc)) + dev_acc = 100. * n_dev_correct / len(dev) + print(dev_log_template.format(time.time()-start, + epoch, iterations, 1+batch_idx, len(train_iter), + 100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc)) if dev_acc > best_dev_acc: best_dev_acc = dev_acc - torch.save(model, os.path.join(args.save_path, - 'best_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations))) + snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') + snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations) + torch.save(model, snapshot_path) + for f in glob.glob(snapshot_prefix + '*'): + if f != snapshot_path: + os.remove(f) elif iterations % args.log_every == 0: print(log_template.format(time.time()-start, - epoch, iterations, batch_idx, len(train_iter), - 100. * batch_idx / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12)) + epoch, iterations, 1+batch_idx, len(train_iter), + 100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12)) diff --git a/snli/util.py b/snli/util.py index e1a7be328c..dcdf5f23db 100644 --- a/snli/util.py +++ b/snli/util.py @@ -6,19 +6,22 @@ def get_args(): parser.add_argument('--epochs', type=int, default=50) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--d_embed', type=int, default=300) + parser.add_argument('--d_proj', type=int, default=300) parser.add_argument('--d_hidden', type=int, default=300) parser.add_argument('--n_layers', type=int, default=2) parser.add_argument('--log_every', type=int, default=50) parser.add_argument('--lr', type=float, default=.001) - parser.add_argument('--val_every', type=int, default=1000) + parser.add_argument('--dev_every', type=int, default=1000) parser.add_argument('--save_every', type=int, default=1000) parser.add_argument('--dp_ratio', type=int, default=0.0) - parser.add_argument('--bidirectional', action='store_true') + parser.add_argument('--no-bidirectional', action='store_false', dest='birnn') parser.add_argument('--preserve-case', action='store_false', dest='lower') + parser.add_argument('--no-projection', action='store_false', dest='projection') + parser.add_argument('--train_embed', action='store_false', dest='fix_emb') parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--save_path', type=str, default='') + parser.add_argument('--save_path', type=str, default='results') parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache')) - parser.add_argument('--vocab_cache', type=str, default=os.path.join(os.getcwd(), '.vocab_cache/input_vocab.pt')) + parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt')) parser.add_argument('--word_vectors', type=str, default='glove.42B') args = parser.parse_args() return args