forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integration tests (facebookresearch#1153)
* Unit Test agent. * Give 20 random negative candidates to choose from. * Make parameters changable. * Copyright. * Add test for fairseq. * Fix broken unit test. * Shorten test so it runs in ~6 seconds. * Update fairseq integration test. * Quieter test output. * Change how skipping is done in torchagent.
- Loading branch information
1 parent
be30b41
commit 24b49f1
Showing
9 changed files
with
254 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
try: | ||
from fairseq import models, optim, criterions | ||
except ImportError: | ||
raise RuntimeError( | ||
raise ImportError( | ||
"Please run \"pip install -U 'git+https://github.com/pytorch/" | ||
"[email protected]#egg=fairseq'\"" | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
#!/usr/bin/env | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
|
||
""" | ||
These agents contain a number of "unit test" corpora, or | ||
fake corpora that ensure models can learn simple behavior easily. | ||
They are useful as unit tests for the basic models. | ||
The corpora are all randomly, but deterministically generated | ||
""" | ||
|
||
from parlai.core.teachers import DialogTeacher | ||
import random | ||
import itertools | ||
|
||
# default parameters | ||
VOCAB_SIZE = 7 | ||
EXAMPLE_SIZE = 4 | ||
NUM_CANDIDATES = 10 | ||
NUM_TRAIN = 500 | ||
NUM_TEST = 100 | ||
|
||
|
||
class CandidateTeacher(DialogTeacher): | ||
""" | ||
Candidate teacher produces several candidates, one of which is a repeat | ||
of the input. A good ranker should easily identify the correct response. | ||
""" | ||
def __init__(self, opt, shared=None, | ||
vocab_size=VOCAB_SIZE, | ||
example_size=EXAMPLE_SIZE, | ||
num_candidates=NUM_CANDIDATES, | ||
num_train=NUM_TRAIN, | ||
num_test=NUM_TEST): | ||
""" | ||
:param int vocab_size: size of the vocabulary | ||
:param int example_size: length of each example | ||
:param int num_candidates: number of label_candidates generated | ||
:param int num_train: size of the training set | ||
:param int num_test: size of the valid/test sets | ||
""" | ||
self.opt = opt | ||
opt['datafile'] = opt['datatype'].split(':')[0] | ||
self.datafile = opt['datafile'] | ||
|
||
self.vocab_size = vocab_size | ||
self.example_size = example_size | ||
self.num_candidates = num_candidates | ||
self.num_train = num_train | ||
self.num_test = num_test | ||
|
||
# set up the vocabulary | ||
self.words = list(map(str, range(self.vocab_size))) | ||
|
||
super().__init__(opt, shared) | ||
|
||
def num_examples(self): | ||
if self.datafile == 'train': | ||
return self.num_train | ||
else: | ||
return self.num_test | ||
|
||
def num_episodes(self): | ||
return self.num_examples() | ||
|
||
def setup_data(self, fold): | ||
# N words appearing in a random order | ||
self.rng = random.Random(42) | ||
full_corpus = [ | ||
list(x) for x in itertools.permutations(self.words, self.example_size) | ||
] | ||
self.rng.shuffle(full_corpus) | ||
|
||
it = iter(full_corpus) | ||
self.train = list(itertools.islice(it, self.num_train)) | ||
self.val = list(itertools.islice(it, self.num_test)) | ||
self.test = list(itertools.islice(it, self.num_test)) | ||
|
||
# check we have enough data | ||
assert (len(self.train) == self.num_train), len(self.train) | ||
assert (len(self.val) == self.num_test), len(self.val) | ||
assert (len(self.test) == self.num_test), len(self.test) | ||
|
||
# check every word appear in the training set | ||
assert len(set(itertools.chain(*self.train)) - set(self.words)) == 0 | ||
|
||
# select which set we're using | ||
if fold == "train": | ||
self.corpus = self.train | ||
elif fold == "valid": | ||
self.corpus = self.val | ||
elif fold == "test": | ||
self.corpus = self.test | ||
|
||
# make sure the corpus is actually text strings | ||
self.corpus = [' '.join(x) for x in self.corpus] | ||
|
||
for i, text in enumerate(self.corpus): | ||
cands = [] | ||
for j in range(NUM_CANDIDATES): | ||
offset = (i + j) % len(self.corpus) | ||
cands.append(self.corpus[offset]) | ||
yield (text, [text], 0, cands), True | ||
|
||
|
||
class NocandidateTeacher(CandidateTeacher): | ||
""" | ||
Strips the candidates so the model can't see any options. Good for testing | ||
simple generative models. | ||
""" | ||
def setup_data(self, fold): | ||
raw = super().setup_data(fold) | ||
for (t, a, r, c), e in raw: | ||
yield (t, a), e | ||
|
||
|
||
class DefaultTeacher(CandidateTeacher): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (c) 2017-present, Facebook, Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. An additional grant | ||
# of patent rights can be found in the PATENTS file in the same directory. | ||
|
||
import unittest | ||
import io | ||
import contextlib | ||
import tempfile | ||
import os | ||
import shutil | ||
|
||
from parlai.scripts.train_model import TrainLoop, setup_args | ||
|
||
SKIP_TESTS = False | ||
try: | ||
import fairseq # noqa: F401 | ||
except ImportError: | ||
SKIP_TESTS = True | ||
|
||
|
||
BATCH_SIZE = 64 | ||
NUM_EPOCHS = 5 | ||
LR = 1e-2 | ||
|
||
|
||
def _mock_train(**args): | ||
outdir = tempfile.mkdtemp() | ||
parser = setup_args() | ||
parser.set_defaults( | ||
model_file=os.path.join(outdir, "model"), | ||
**args, | ||
) | ||
stdout = io.StringIO() | ||
with contextlib.redirect_stdout(stdout): | ||
tl = TrainLoop(parser.parse_args(print_args=False)) | ||
valid, test = tl.train() | ||
|
||
shutil.rmtree(outdir) | ||
return stdout.getvalue(), valid, test | ||
|
||
|
||
class TestFairseq(unittest.TestCase): | ||
"""Checks that fairseq can learn some very basic tasks.""" | ||
|
||
@unittest.skipIf(SKIP_TESTS, "Fairseq not installed") | ||
def test_labelcands(self): | ||
stdout, valid, test = _mock_train( | ||
task='integration_tests:CandidateTeacher', | ||
model='fairseq', | ||
arch='lstm_wiseman_iwslt_de_en', | ||
lr=LR, | ||
batchsize=BATCH_SIZE, | ||
num_epochs=NUM_EPOCHS, | ||
rank_candidates=True, | ||
skip_generation=True, | ||
) | ||
|
||
self.assertTrue( | ||
valid['hits@1'] > 0.95, | ||
"valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout) | ||
) | ||
self.assertTrue( | ||
test['hits@1'] > 0.95, | ||
"test hits@1 = {}\nLOG:\n{}".format(test['hits@1'], stdout) | ||
) | ||
|
||
@unittest.skipIf(SKIP_TESTS, "Fairseq not installed") | ||
def test_generation(self): | ||
stdout, valid, test = _mock_train( | ||
task='integration_tests:NocandidateTeacher', | ||
model='fairseq', | ||
arch='lstm_wiseman_iwslt_de_en', | ||
lr=LR, | ||
batchsize=BATCH_SIZE, | ||
num_epochs=NUM_EPOCHS, | ||
rank_candidates=False, | ||
skip_generation=False, | ||
) | ||
|
||
self.assertTrue( | ||
valid['ppl'] < 1.2, | ||
"valid ppl = {}\nLOG:\n{}".format(valid['ppl'], stdout) | ||
) | ||
self.assertTrue( | ||
test['ppl'] < 1.2, | ||
"test ppl = {}\nLOG:\n{}".format(test['ppl'], stdout) | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters