Skip to content

Commit

Permalink
adding projection layer
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann committed Jan 25, 2017
1 parent 88508c0 commit a7a5cdc
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 47 deletions.
45 changes: 29 additions & 16 deletions snli/model.py
Original file line number Diff line number Diff line change
class Encoder(nn.Module):
    """Sentence encoder: an (optionally bidirectional) LSTM over embedded tokens.

    Produces a fixed-size vector per sentence from the LSTM's final hidden
    state(s).
    """

    def __init__(self, config):
        super(Encoder, self).__init__()
        self.config = config
        # The RNN consumes projected embeddings when the projection layer is
        # enabled, raw word embeddings otherwise.
        input_size = config.d_proj if config.projection else config.d_embed
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
                           num_layers=config.n_layers, dropout=config.dp_ratio,
                           bidirectional=config.birnn)

    def forward(self, inputs):
        """Encode a (seq_len, batch, input_size) tensor of embeddings.

        Returns a (batch, d_hidden) tensor — or (batch, 2*d_hidden) for a
        bidirectional LSTM, where the last layer's forward and backward final
        hidden states are concatenated.
        """
        batch_size = inputs.size()[1]
        state_shape = self.config.n_cells, batch_size, self.config.d_hidden
        # Allocate initial hidden/cell states on the same device and dtype as
        # the input (via inputs.data.new) instead of hard-coding .cuda().
        h0 = c0 = Variable(inputs.data.new(*state_shape).zero_())
        outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
        return ht[-1] if not self.config.birnn else ht[-2:].view(batch_size, -1)


class SNLIClassifier(nn.Module):
    """SNLI entailment classifier.

    Embeds premise and hypothesis, optionally projects the embeddings,
    encodes each with a shared Encoder, and classifies the concatenated
    sentence vectors with an MLP.
    """

    def __init__(self, config):
        super(SNLIClassifier, self).__init__()
        self.config = config
        self.embed = nn.Embedding(config.n_embed, config.d_embed)
        # Optional linear projection of word vectors before the encoder.
        self.projection = Linear(config.d_embed, config.d_proj)
        self.encoder = Encoder(config)
        self.dropout = nn.Dropout(p=config.dp_ratio)
        self.relu = nn.ReLU()
        # Premise and hypothesis encodings are concatenated -> 2*d_hidden,
        # doubled again when the encoder is bidirectional.
        seq_in_size = 2*config.d_hidden
        if self.config.birnn:
            seq_in_size *= 2
        lin_config = [seq_in_size]*2
        self.out = nn.Sequential(
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(seq_in_size, config.d_out))

    def forward(self, batch):
        """Score a torchtext batch; returns (batch, d_out) class scores."""
        prem_embed = self.embed(batch.premise)
        hypo_embed = self.embed(batch.hypothesis)
        if self.config.fix_emb:
            # Rewrap the data to detach it from the embedding table, so
            # gradients do not flow into the (frozen) embeddings.
            prem_embed = Variable(prem_embed.data)
            hypo_embed = Variable(hypo_embed.data)
        if self.config.projection:
            prem_embed = self.relu(self.projection(prem_embed))
            hypo_embed = self.relu(self.projection(hypo_embed))
        premise = self.encoder(prem_embed)
        hypothesis = self.encoder(hypo_embed)
        scores = self.out(torch.cat([premise, hypothesis], 1))
        return scores
65 changes: 38 additions & 27 deletions snli/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import glob

import torch
import torch.optim as O
Expand All @@ -17,27 +18,28 @@
# Torchtext fields: tokenized premise/hypothesis sequences and a
# single-label answer field.
inputs = data.Field()
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs, answers)

# Build the vocabulary. When pre-trained word vectors are requested, reuse a
# cached vector tensor if one exists; otherwise build the vectors and cache
# them for subsequent runs.
if args.word_vectors and os.path.isfile(args.vector_cache):
    inputs.build_vocab(train, dev, test, lower=args.lower)
    inputs.vocab.vectors = torch.load(args.vector_cache)
else:
    if args.word_vectors:
        inputs.build_vocab(train, dev, test, vectors=(args.data_cache, args.word_vectors, args.d_embed), lower=args.lower)
        os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True)
        torch.save(inputs.vocab.vectors, args.vector_cache)
    else:
        inputs.build_vocab(train, dev, test, lower=args.lower)
answers.build_vocab(train)

# Bucketed iterators batch examples of similar length to minimise padding.
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test), batch_size=args.batch_size, device=args.gpu)

# Reuse the parsed argument namespace as the model config, augmented with
# vocabulary-derived sizes.
config = args
config.n_embed = len(inputs.vocab)
config.d_out = len(answers.vocab)
# One set of LSTM cells per layer; doubled for bidirectional RNNs.
config.n_cells = config.n_layers
if config.birnn:
    config.n_cells *= 2

model = SNLIClassifier(config)
Expand All @@ -51,9 +53,10 @@
# Training bookkeeping and fixed-width console log formats.
start = time.time()
best_dev_acc = -1
train_iter.repeat = False
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
# Ensure the snapshot directory exists before the first torch.save call.
os.makedirs(args.save_path, exist_ok=True)
print(header)

for epoch in range(args.epochs):
Expand All @@ -69,26 +72,34 @@
loss = criterion(answer, batch.label)
loss.backward(); opt.step()
if iterations % args.save_every == 0:
torch.save(model, os.path.join(args.save_path,
'snapshot_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)))
if iterations % args.val_every == 0:
model.eval(); val_iter.init_epoch()
snapshot_prefix = os.path.join(args.save_path, 'snapshot')
snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
torch.save(model, snapshot_path)
for f in glob.glob(snapshot_prefix + '*'):
if f != snapshot_path:
os.remove(f)
if iterations % args.dev_every == 0:
model.eval(); dev_iter.init_epoch()
n_dev_correct, dev_loss = 0, 0
for dev_batch_idx, dev_batch in enumerate(val_iter):
for dev_batch_idx, dev_batch in enumerate(dev_iter):
answer = model(dev_batch)
n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
dev_loss = criterion(answer, dev_batch.label)
dev_acc = 100. * n_dev_correct / len(val)
print(val_log_template.format(time.time()-start,
epoch, iterations, batch_idx, len(train_iter),
100. * batch_idx / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
dev_acc = 100. * n_dev_correct / len(dev)
print(dev_log_template.format(time.time()-start,
epoch, iterations, 1+batch_idx, len(train_iter),
100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
if dev_acc > best_dev_acc:
best_dev_acc = dev_acc
torch.save(model, os.path.join(args.save_path,
'best_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)))
snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)
torch.save(model, snapshot_path)
for f in glob.glob(snapshot_prefix + '*'):
if f != snapshot_path:
os.remove(f)
elif iterations % args.log_every == 0:
print(log_template.format(time.time()-start,
epoch, iterations, batch_idx, len(train_iter),
100. * batch_idx / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))
epoch, iterations, 1+batch_idx, len(train_iter),
100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))


11 changes: 7 additions & 4 deletions snli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ def get_args():
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--d_embed', type=int, default=300)
parser.add_argument('--d_proj', type=int, default=300)
parser.add_argument('--d_hidden', type=int, default=300)
parser.add_argument('--n_layers', type=int, default=2)
parser.add_argument('--log_every', type=int, default=50)
parser.add_argument('--lr', type=float, default=.001)
parser.add_argument('--val_every', type=int, default=1000)
parser.add_argument('--dev_every', type=int, default=1000)
parser.add_argument('--save_every', type=int, default=1000)
parser.add_argument('--dp_ratio', type=int, default=0.0)
parser.add_argument('--bidirectional', action='store_true')
parser.add_argument('--no-bidirectional', action='store_false', dest='birnn')
parser.add_argument('--preserve-case', action='store_false', dest='lower')
parser.add_argument('--no-projection', action='store_false', dest='projection')
parser.add_argument('--train_embed', action='store_false', dest='fix_emb')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--save_path', type=str, default='')
parser.add_argument('--save_path', type=str, default='results')
parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache'))
parser.add_argument('--vocab_cache', type=str, default=os.path.join(os.getcwd(), '.vocab_cache/input_vocab.pt'))
parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt'))
parser.add_argument('--word_vectors', type=str, default='glove.42B')
args = parser.parse_args()
return args

0 comments on commit a7a5cdc

Please sign in to comment.