# integration_test.py (forked from schangpi/pytorch-seq2seq)
# Original listing: 114 lines (101 loc), 4.24 KB.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import argparse
import logging
import torch
import torchtext
import seq2seq
from seq2seq.trainer import SupervisedTrainer
from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq
from seq2seq.loss import Perplexity
from seq2seq.dataset import SourceField, TargetField
from seq2seq.evaluator import Predictor, Evaluator
from seq2seq.util.checkpoint import Checkpoint
# Command-line options for the integration test.
parser = argparse.ArgumentParser()
# Both data paths are always read when the datasets are built, so they are
# mandatory; argparse now reports their absence instead of torchtext failing
# on path=None.
parser.add_argument('--train_path', action='store', dest='train_path', required=True,
                    help='Path to train data')
parser.add_argument('--dev_path', action='store', dest='dev_path', required=True,
                    help='Path to dev data')
parser.add_argument('--expt_dir', action='store', dest='expt_dir', default='./experiment',
                    help='Path to experiment directory. If load_checkpoint is True, then path to checkpoint directory has to be provided')
parser.add_argument('--load_checkpoint', action='store', dest='load_checkpoint',
                    help='The name of the checkpoint to load, usually an encoded time string')
parser.add_argument('--resume', action='store_true', dest='resume',
                    default=False,
                    help='Indicates if training has to be resumed from the latest checkpoint')
parser.add_argument('--log-level', dest='log_level',
                    default='info',
                    help='Logging level.')
opt = parser.parse_args()

# Resolve the level name (case-insensitively) up front, so a typo produces a
# usage error instead of an AttributeError from getattr(). Every name that
# worked before (debug, info, warn/warning, error, fatal/critical, notset)
# still resolves to its integer level.
log_level = getattr(logging, opt.log_level.upper(), None)
if not isinstance(log_level, int):
    parser.error("invalid --log-level: {!r}".format(opt.log_level))

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=log_level)
logging.info(opt)
# Prepare dataset.
# Upper bound compared against len(example.src) / len(example.tgt) by
# len_filter; longer pairs are excluded from both splits.
max_len = 50
# One Field object per column. The train and dev datasets share them, so a
# single vocabulary is built for each side.
src = SourceField()
tgt = TargetField()
def len_filter(example):
    """Return True when neither side of *example* is longer than ``max_len``.

    Passed as ``filter_pred`` to the TSV datasets, so pairs for which this
    returns False are left out of them.
    """
    # De Morgan form of "src fits and tgt fits"; a too-long src still
    # short-circuits before example.tgt is touched.
    return not (len(example.src) > max_len or len(example.tgt) > max_len)
def _tsv_split(path):
    """Read a tab-separated (src, tgt) file into a torchtext TabularDataset.

    Rows go through ``len_filter`` as the dataset's ``filter_pred``, so pairs
    that exceed ``max_len`` on either side are dropped.
    """
    columns = [('src', src), ('tgt', tgt)]
    return torchtext.data.TabularDataset(
        path=path,
        format='tsv',
        fields=columns,
        filter_pred=len_filter,
    )


train = _tsv_split(opt.train_path)
dev = _tsv_split(opt.dev_path)

# Both vocabularies are built from the training split only (max_size=50000).
for field in (src, tgt):
    field.build_vocab(train, max_size=50000)
input_vocab, output_vocab = src.vocab, tgt.vocab
# Prepare loss: perplexity with a uniform weight of 1 for every target
# vocabulary entry. The padding index is handed to Perplexity as its second
# argument (presumably so padded positions are masked out; confirm in
# seq2seq.loss).
weight, pad = torch.ones(len(tgt.vocab)), tgt.vocab.stoi[tgt.pad_token]
loss = Perplexity(weight, pad)

# Move the loss to the GPU when one is present.
if torch.cuda.is_available():
    loss.cuda()
# Obtain a model: either load a saved checkpoint, or build and train one.
# NOTE: this rebinds the name ``seq2seq``, shadowing the package imported at
# the top of the file. That is harmless here because every
# ``from seq2seq... import`` has already run.
if opt.load_checkpoint is not None:
    # Build the checkpoint path once and reuse it for logging and loading.
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    logging.info("loading checkpoint from %s", checkpoint_path)
    if opt.resume:
        # No training happens on this branch, so --resume has no effect.
        logging.warning("--resume is ignored when --load_checkpoint is given")
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    # Use the vocabularies stored with the checkpoint rather than the ones
    # just rebuilt from --train_path.
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    seq2seq = None
    optimizer = None  # presumably lets the trainer pick its default optimizer; confirm
    if not opt.resume:
        # Initialize model: bidirectional LSTM encoder, attention decoder.
        # The decoder's hidden size is hidden_size * 2, presumably to match
        # the concatenated forward/backward encoder states.
        hidden_size = 128
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional,
                             rnn_cell='lstm',
                             variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab), max_len, hidden_size * 2,
                             dropout_p=0.2, use_attention=True,
                             bidirectional=bidirectional,
                             rnn_cell='lstm',
                             eos_id=tgt.eos_id, sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        # Uniform initialisation of every parameter in [-0.08, 0.08].
        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

    # Train. With --resume, seq2seq is still None here and resume=True asks
    # the trainer to continue from the latest checkpoint in expt_dir.
    t = SupervisedTrainer(loss=loss, batch_size=32,
                          checkpoint_every=50,
                          print_every=10, expt_dir=opt.expt_dir)
    seq2seq = t.train(seq2seq, train,
                      num_epochs=6, dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)
# Evaluate on the dev split. The integration test fails if the dev loss is
# not below 1.5.
evaluator = Evaluator(loss=loss, batch_size=32)
dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
logging.info("dev loss: %s, dev accuracy: %s", dev_loss, accuracy)
assert dev_loss < 1.5, "dev loss too high: {}".format(dev_loss)

# Spot-check one prediction: the expected output is the input tokens in
# reverse order. Reverse the token list rather than the raw string, so the
# check stays correct for multi-character tokens.
predictor = Predictor(seq2seq, input_vocab, output_vocab)
inp_seq = "1 3 5 7 9"
seq = predictor.predict(inp_seq.split())
expected = " ".join(reversed(inp_seq.split()))
# The last predicted token is dropped (presumably the end-of-sequence marker;
# confirm against Predictor.predict).
predicted = " ".join(seq[:-1])
assert predicted == expected, "expected {!r}, got {!r}".format(expected, predicted)