Commit

Merge pull request facebookresearch#462 from facebookresearch/train_model_test

Added unit test for train_model.py
emilydinan authored Dec 20, 2017
2 parents 68a7ff1 + 50001aa commit 76a070c
Showing 3 changed files with 104 additions and 6 deletions.
11 changes: 5 additions & 6 deletions examples/train_model.py
@@ -27,12 +27,11 @@
from parlai.core.worlds import create_task
from parlai.core.params import ParlaiParser
from parlai.core.utils import Timer
from parlai.core.metrics import compute_time_metrics
import build_dict
from examples.build_dict import build_dict
import math

def setup_args():
    parser = ParlaiParser(True, True)
def setup_args(model_args=None):
    parser = ParlaiParser(True, True, model_argv=model_args)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
@@ -120,7 +119,7 @@ def __init__(self, parser):
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            print("[ building dictionary first... ]")
            build_dict.build_dict(opt)
            build_dict(opt)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
@@ -201,7 +200,7 @@ def train(self):
            while True:
                world.parley()
                self.parleys += 1

                if world.get_total_epochs() >= self.max_num_epochs:
                    self.log()
                    print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
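For orientation, here is a minimal sketch of how the reworked setup_args(model_args=...) entry point can be driven programmatically; it simply mirrors the pattern used by tests/test_train_model.py later in this commit, and the memnn/RepeatTeacher argument values are illustrative rather than additional changes.

# Sketch only: programmatic use of the updated setup_args/TrainLoop entry point,
# following the pattern in tests/test_train_model.py below.
from examples.train_model import TrainLoop, setup_args

parser = setup_args(model_args=['--model', 'memnn'])  # model args parsed up front
parser.set_defaults(
    model='memnn',
    task='tests.tasks.repeat:RepeatTeacher:10',  # dummy task added in this commit
    dict_file='/tmp/repeat',
    batchsize='1',
)
TrainLoop(parser).train()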
23 changes: 23 additions & 0 deletions tests/tasks/repeat.py
@@ -0,0 +1,23 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Class which creates a dummy dataset for testing purposes.
Used in test_train_model.py
"""
from parlai.core.teachers import DialogTeacher

import copy

class RepeatTeacher(DialogTeacher):
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        opt['datafile'] = 'unused_path'
        task = opt.get('task', 'tests.tasks.repeat:RepeatTeacher:50')
        self.data_length = int(task.split(':')[2])
        super().__init__(opt)

    def setup_data(self, unused_path):
        for i in range(self.data_length):
            yield ((str(i), [str(i)]), True)
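As a quick illustration (not part of the commit), the episode format this teacher produces can be previewed without ParlAI by reproducing the task-string parsing and the setup_data generator above; each yielded item is ((text, labels), new_episode_flag).

# Illustrative only: what RepeatTeacher.setup_data emits for a task string like
# 'tests.tasks.repeat:RepeatTeacher:3', where ':3' sets data_length.
task = 'tests.tasks.repeat:RepeatTeacher:3'
data_length = int(task.split(':')[2])
for i in range(data_length):
    print(((str(i), [str(i)]), True))
# (('0', ['0']), True)
# (('1', ['1']), True)
# (('2', ['2']), True)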
76 changes: 76 additions & 0 deletions tests/test_train_model.py
@@ -0,0 +1,76 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from examples.train_model import TrainLoop, setup_args

import ast
import unittest
import sys


class TestTrainModel(unittest.TestCase):
"""Basic tests on the train_model.py example."""

def test_output(self):
class display_output(object):
def __init__(self):
self.data = []

def write(self, s):
self.data.append(s)

def __str__(self):
return "".join(self.data)

old_out = sys.stdout
output = display_output()

try:
sys.stdout = output
try:
import torch
except ImportError:
raise ImportError('Cannot import torch')
return
parser = setup_args(model_args=['--model', 'memnn'])
parser.set_defaults(
model='memnn',
task='tests.tasks.repeat:RepeatTeacher:10',
dict_file='/tmp/repeat',
batchsize='1',
validation_every_n_secs='5',
validation_patience='2',
embedding_size='8',
no_cuda=True
)
TrainLoop(parser).train()
finally:
# restore sys.stdout
sys.stdout = old_out

str_output = str(output)

self.assertTrue(len(str_output) > 0, "Output is empty")
self.assertTrue("[ training... ]" in str_output,
"Did not reach training step")
self.assertTrue("[ running eval: valid ]" in str_output,
"Did not reach validation step")
self.assertTrue("valid:{'total': 10," in str_output,
"Did not complete validation")
self.assertTrue("[ running eval: test ]" in str_output,
"Did not reach evaluation step")
self.assertTrue("test:{'total': 10," in str_output,
"Did not complete evaluation")

list_output = str_output.split("\n")
for line in list_output:
if "test:{" in line:
score = ast.literal_eval(line.split("test:", 1)[1])
self.assertTrue(score['accuracy'] == 1,
"Accuracy did not reach 1")

if __name__ == '__main__':
unittest.main()
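One way to run just this test locally (a sketch, assuming the ParlAI repository root is the working directory and on sys.path so that examples and tests import cleanly):

# Sketch: load and run only tests/test_train_model.py via the standard unittest API.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName('tests.test_train_model')
unittest.TextTestRunner(verbosity=2).run(suite)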
