Integration tests (facebookresearch#1153)
* Unit Test agent.
* Give 20 random negative candidates to choose from.
* Make parameters changeable.
* Copyright.
* Add test for fairseq.
* Fix broken unit test.
* Shorten test so it runs in ~6 seconds.
* Update fairseq integration test.
* Quieter test output.
* Change how skipping is done in torchagent.
stephenroller authored Sep 12, 2018
1 parent be30b41 commit 24b49f1
Showing 9 changed files with 254 additions and 12 deletions.
2 changes: 1 addition & 1 deletion parlai/agents/fairseq/fairseq.py
@@ -10,7 +10,7 @@
 try:
     from fairseq import models, optim, criterions
 except ImportError:
-    raise RuntimeError(
+    raise ImportError(
         "Please run \"pip install -U 'git+https://github.com/pytorch/"
         "fairseq.git@v0.5.0#egg=fairseq'\""
     )
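Raising ImportError instead of RuntimeError lets callers treat a missing fairseq like any other absent optional dependency. A minimal sketch of the pattern this enables (test_fairseq.py below applies the same idea to the fairseq package itself; this assumes the module's agent class is named FairseqAgent):

SKIP_TESTS = False
try:
    # re-raises the ImportError above when fairseq is not installed
    from parlai.agents.fairseq.fairseq import FairseqAgent  # noqa: F401
except ImportError:
    SKIP_TESTS = True  # callers can skip gracefully instead of crashing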
5 changes: 5 additions & 0 deletions parlai/tasks/integration_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
122 changes: 122 additions & 0 deletions parlai/tasks/integration_tests/agents.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

"""
These agents contain a number of "unit test" corpora, or
fake corpora that ensure models can learn simple behavior easily.
They are useful as unit tests for the basic models.
The corpora are all randomly, but deterministically generated
"""

from parlai.core.teachers import DialogTeacher
import random
import itertools

# default parameters
VOCAB_SIZE = 7
EXAMPLE_SIZE = 4
NUM_CANDIDATES = 10
NUM_TRAIN = 500
NUM_TEST = 100


class CandidateTeacher(DialogTeacher):
    """
    Candidate teacher produces several candidates, one of which is a repeat
    of the input. A good ranker should easily identify the correct response.
    """
    def __init__(self, opt, shared=None,
                 vocab_size=VOCAB_SIZE,
                 example_size=EXAMPLE_SIZE,
                 num_candidates=NUM_CANDIDATES,
                 num_train=NUM_TRAIN,
                 num_test=NUM_TEST):
        """
        :param int vocab_size: size of the vocabulary
        :param int example_size: length of each example
        :param int num_candidates: number of label_candidates generated
        :param int num_train: size of the training set
        :param int num_test: size of the valid/test sets
        """
        self.opt = opt
        opt['datafile'] = opt['datatype'].split(':')[0]
        self.datafile = opt['datafile']

        self.vocab_size = vocab_size
        self.example_size = example_size
        self.num_candidates = num_candidates
        self.num_train = num_train
        self.num_test = num_test

        # set up the vocabulary
        self.words = list(map(str, range(self.vocab_size)))

        super().__init__(opt, shared)

    def num_examples(self):
        if self.datafile == 'train':
            return self.num_train
        else:
            return self.num_test

    def num_episodes(self):
        return self.num_examples()

    def setup_data(self, fold):
        # N words appearing in a random order
        self.rng = random.Random(42)
        full_corpus = [
            list(x) for x in itertools.permutations(self.words, self.example_size)
        ]
        self.rng.shuffle(full_corpus)

        it = iter(full_corpus)
        self.train = list(itertools.islice(it, self.num_train))
        self.val = list(itertools.islice(it, self.num_test))
        self.test = list(itertools.islice(it, self.num_test))

        # check we have enough data
        assert (len(self.train) == self.num_train), len(self.train)
        assert (len(self.val) == self.num_test), len(self.val)
        assert (len(self.test) == self.num_test), len(self.test)

        # check every word appears in the training set
        assert len(set(itertools.chain(*self.train)) - set(self.words)) == 0

        # select which set we're using
        if fold == "train":
            self.corpus = self.train
        elif fold == "valid":
            self.corpus = self.val
        elif fold == "test":
            self.corpus = self.test

        # make sure the corpus is actually text strings
        self.corpus = [' '.join(x) for x in self.corpus]

        for i, text in enumerate(self.corpus):
            cands = []
            for j in range(NUM_CANDIDATES):
                offset = (i + j) % len(self.corpus)
                cands.append(self.corpus[offset])
            yield (text, [text], 0, cands), True


class NocandidateTeacher(CandidateTeacher):
    """
    Strips the candidates so the model can't see any options. Good for testing
    simple generative models.
    """
    def setup_data(self, fold):
        raw = super().setup_data(fold)
        for (t, a, r, c), e in raw:
            yield (t, a), e


class DefaultTeacher(CandidateTeacher):
    pass
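As a quick sanity check of the new task, one can pull a single example through ParlAI's existing teacher factory; a minimal sketch, assuming a source checkout on the path (ParlaiParser and create_task_agent_from_taskname are the standard core APIs):

from parlai.core.agents import create_task_agent_from_taskname
from parlai.core.params import ParlaiParser

# build an opt for the train fold of the candidate task
opt = ParlaiParser().parse_args(
    ['--task', 'integration_tests:CandidateTeacher', '--datatype', 'train'],
    print_args=False,
)
teacher = create_task_agent_from_taskname(opt)[0]
msg = teacher.act()
# the text is a string of word ids such as '3 0 6 2'; the label repeats it
# exactly, so a perfect ranker scores hits@1 of 1.0 over the ten candidates
print(msg['text'], msg['labels'], len(msg['label_candidates']))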
9 changes: 9 additions & 0 deletions parlai/tasks/task_list.py
@@ -675,4 +675,13 @@
"Link to dataset: http://cocodataset.org/#download"
),
},
{
"id": "integration_tests",
"display_name": "Integration Tests",
"task": "integration_tests",
"tags": ["All", "Debug"],
"description": (
"Artificial tasks for ensuring models perform as expected"
),
},
]
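For reference, the "task" field resolves to the new module's teacher classes via ParlAI's usual naming convention; roughly (an illustrative mapping, not the loader's actual code):

# '-t integration_tests'                    -> parlai.tasks.integration_tests.agents:DefaultTeacher
# '-t integration_tests:CandidateTeacher'   -> parlai.tasks.integration_tests.agents:CandidateTeacher
# '-t integration_tests:NocandidateTeacher' -> parlai.tasks.integration_tests.agents:NocandidateTeacher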
2 changes: 1 addition & 1 deletion tests/test_dict.py
@@ -36,7 +36,7 @@ def test_basic_parse(self):

         argparser = ParlaiParser()
         DictionaryAgent.add_cmdline_args(argparser)
-        opt = argparser.parse_args()
+        opt = argparser.parse_args(print_args=False)
         dictionary = DictionaryAgent(opt)
         num_builtin = len(dictionary)

93 changes: 93 additions & 0 deletions tests/test_fairseq.py
@@ -0,0 +1,93 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

import unittest
import io
import contextlib
import tempfile
import os
import shutil

from parlai.scripts.train_model import TrainLoop, setup_args

SKIP_TESTS = False
try:
    import fairseq  # noqa: F401
except ImportError:
    SKIP_TESTS = True


BATCH_SIZE = 64
NUM_EPOCHS = 5
LR = 1e-2


def _mock_train(**args):
    outdir = tempfile.mkdtemp()
    parser = setup_args()
    parser.set_defaults(
        model_file=os.path.join(outdir, "model"),
        **args,
    )
    stdout = io.StringIO()
    with contextlib.redirect_stdout(stdout):
        tl = TrainLoop(parser.parse_args(print_args=False))
        valid, test = tl.train()

    shutil.rmtree(outdir)
    return stdout.getvalue(), valid, test


class TestFairseq(unittest.TestCase):
    """Checks that fairseq can learn some very basic tasks."""

    @unittest.skipIf(SKIP_TESTS, "Fairseq not installed")
    def test_labelcands(self):
        stdout, valid, test = _mock_train(
            task='integration_tests:CandidateTeacher',
            model='fairseq',
            arch='lstm_wiseman_iwslt_de_en',
            lr=LR,
            batchsize=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            rank_candidates=True,
            skip_generation=True,
        )

        self.assertTrue(
            valid['hits@1'] > 0.95,
            "valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout)
        )
        self.assertTrue(
            test['hits@1'] > 0.95,
            "test hits@1 = {}\nLOG:\n{}".format(test['hits@1'], stdout)
        )

    @unittest.skipIf(SKIP_TESTS, "Fairseq not installed")
    def test_generation(self):
        stdout, valid, test = _mock_train(
            task='integration_tests:NocandidateTeacher',
            model='fairseq',
            arch='lstm_wiseman_iwslt_de_en',
            lr=LR,
            batchsize=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            rank_candidates=False,
            skip_generation=False,
        )

        self.assertTrue(
            valid['ppl'] < 1.2,
            "valid ppl = {}\nLOG:\n{}".format(valid['ppl'], stdout)
        )
        self.assertTrue(
            test['ppl'] < 1.2,
            "test ppl = {}\nLOG:\n{}".format(test['ppl'], stdout)
        )


if __name__ == '__main__':
    unittest.main()
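Since the module keeps the unittest.main() hook, the suite can be run directly from a repository checkout, e.g.:

python tests/test_fairseq.py

Without fairseq installed, the skipIf guards mark both tests as skipped rather than failing.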
10 changes: 7 additions & 3 deletions tests/test_init.py
@@ -13,13 +13,17 @@ class TestInit(unittest.TestCase):

     def test_init_everywhere(self):
         from parlai.core.params import ParlaiParser
-        opt = ParlaiParser().parse_args()
-        for root, _subfolder, files in os.walk(os.path.join(opt['parlai_home'], 'parlai')):
+        opt = ParlaiParser().parse_args(print_args=False)
+        folders = os.walk(os.path.join(opt['parlai_home'], 'parlai'))
+        for root, _subfolder, files in folders:
             if not root.endswith('__pycache__'):
                 if os.path.basename(root) == 'html':
                     # skip mturk core's html folder--not a python module
                     continue
-                assert '__init__.py' in files, 'Dir {} is missing __init__.py'.format(root)
+                self.assertTrue(
+                    '__init__.py' in files,
+                    'Dir {} is missing __init__.py'.format(root)
+                )


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion tests/test_tasklist.py
@@ -14,7 +14,7 @@ class TestInit(unittest.TestCase):
     def test_tasklist(self):
         from parlai.tasks.task_list import task_list
         from parlai.core.params import ParlaiParser
-        opt = ParlaiParser().parse_args()
+        opt = ParlaiParser().parse_args(print_args=False)
 
         a = set((t['task'].split(':')[0] for t in task_list))

21 changes: 15 additions & 6 deletions tests/test_torch_agent.py
@@ -6,7 +6,13 @@

 import unittest
 from parlai.core.agents import Agent
-from parlai.core.torch_agent import TorchAgent, Output
+
+SKIP_TESTS = False
+try:
+    from parlai.core.torch_agent import TorchAgent, Output
+    import torch
+except ImportError:
+    SKIP_TESTS = True


class MockDict(Agent):
@@ -95,6 +101,7 @@ def test_share(self):
         shared = agent.share()
         self.assertTrue('dict' in shared)
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test__vectorize_text(self):
         """Test _vectorize_text and its different options."""
         agent = get_agent()
@@ -187,6 +194,7 @@ def test__vectorize_text(self):
         self.assertEqual(len(vec), 3)
         self.assertEqual(vec.tolist(), [MockDict.BEG_IDX, 1, 2])
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test__check_truncate(self):
         """Make sure we are truncating when needed."""
         agent = get_agent()
@@ -198,6 +206,7 @@ def test__check_truncate(self):
         self.assertEqual(agent._check_truncate(inp, 1).tolist(), [1])
         self.assertEqual(agent._check_truncate(inp, 0).tolist(), [])
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test_vectorize(self):
         """Test the vectorization of observations.
@@ -294,6 +303,7 @@ def test_vectorize(self):
         self.assertEqual([m.tolist() for m in out['memory_vecs']],
                          [[1], [1], [1]])
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test_batchify(self):
         """Make sure the batchify function sets up the right fields."""
         agent = get_agent(rank_candidates=True)
@@ -443,6 +453,7 @@ def test_batchify(self):
         for i, cs in enumerate(batch.candidate_vecs):
             self.assertEqual(len(cs), len(obs_cands[i]['label_candidates']))
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test_match_batch(self):
         """Make sure predictions are correctly aligned when available."""
         agent = get_agent()
@@ -671,6 +682,7 @@ def test_last_reply(self):
         self.assertEqual(agent.last_reply(use_label=False),
                          'It\'s okay! I\'m a leaf on the wind.')
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test_observe(self):
         """Make sure agent stores and returns observation."""
         agent = get_agent()
@@ -699,6 +711,7 @@ def test_observe(self):
         self.assertEqual(out['text'],
                          'I\'ll be back.\nI\'m back.\nI\'ll be back.')
 
+    @unittest.skipIf(SKIP_TESTS, "Torch not installed.")
     def test_batch_act(self):
         """Make sure batch act calls the right step."""
         agent = get_agent()
@@ -731,8 +744,4 @@ def test_batch_act(self):
 
 
 if __name__ == '__main__':
-    try:
-        import torch
-        unittest.main()
-    except ImportError as e:
-        print('Skipping TestTorchAgent, no pytorch.')
+    unittest.main()
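The net effect: the module now imports cleanly without torch, and individual tests are marked as skipped via the SKIP_TESTS flag set at import time, instead of the old __main__ guard silently bypassing the whole file.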
