Added unit test for train_model.py
Added repeat.py to create RepeatTeacher for basic debugging purposes
Emily Dinan committed Dec 20, 2017
1 parent 0b25e63 commit 98a59e4
Showing 3 changed files with 136 additions and 6 deletions.
11 changes: 5 additions & 6 deletions examples/train_model.py
@@ -27,12 +27,11 @@
from parlai.core.worlds import create_task
from parlai.core.params import ParlaiParser
from parlai.core.utils import Timer
from parlai.core.metrics import compute_time_metrics
import build_dict
from examples.build_dict import build_dict
import math

def setup_args():
    parser = ParlaiParser(True, True)
def setup_args(model_args=None):
    parser = ParlaiParser(True, True, model_argv=model_args)
    train = parser.add_argument_group('Training Loop Arguments')
    train.add_argument('-et', '--evaltask',
                       help=('task to use for valid/test (defaults to the '
@@ -120,7 +119,7 @@ def __init__(self, parser):
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        print("[ building dictionary first... ]")
        build_dict.build_dict(opt)
        build_dict(opt)
        # Create model and assign it to the specified task
        self.agent = create_agent(opt)
        self.world = create_task(opt, self.agent)
@@ -201,7 +200,7 @@ def train(self):
        while True:
            world.parley()
            self.parleys += 1

            if world.get_total_epochs() >= self.max_num_epochs:
                self.log()
                print('[ num_epochs completed:{} time elapsed:{}s ]'.format(
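The new model_args hook above lets a caller register a model's command-line flags before parse_args runs, which is what the test added below relies on. A minimal sketch of that usage, assuming it is executed from the ParlAI repository root; the flag values are illustrative, taken from the test file rather than prescribed by this diff:

# Sketch (assumed usage): pre-seed the parser with the model name so that
# model-specific flags such as memnn's --embedding-size are registered
# before parsing. parse_args(..., print_args=False) mirrors the test below.
from examples.train_model import setup_args

parser = setup_args(model_args=['--model', 'memnn'])
opt = parser.parse_args([
    '--model', 'memnn',
    '--task', 'tests.tasks.repeat:RepeatTeacher:10',
    '--dict-file', '/tmp/repeat',
], print_args=False)
print(opt['model'], opt['task'])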
23 changes: 23 additions & 0 deletions tests/tasks/repeat.py
@@ -0,0 +1,23 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Class which creates a dummy dataset for testing purposes.
Used in test_train_model.py
"""
from parlai.core.teachers import DialogTeacher

import copy

class RepeatTeacher(DialogTeacher):
    def __init__(self, opt, shared=None):
        opt = copy.deepcopy(opt)
        opt['datafile'] = 'unused_path'
        task = opt.get('task', 'tests.tasks.repeat:RepeatTeacher:50')
        self.data_length = int(task.split(':')[2])
        super().__init__(opt)

    def setup_data(self, unused_path):
        for i in range(self.data_length):
            yield ((str(i), [str(i)]), True)
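For reference, the ':50' suffix of the task string is how the teacher learns its dataset length, and each yielded item is a one-turn episode whose label simply repeats the query. A small standalone sketch mirroring the parsing and yield format above (it does not touch ParlAI itself; the task string is illustrative):

# Sketch: the length is the third ':'-separated field of the task name, and
# each example is ((text, [label]), new_episode) with the label equal to the text.
task = 'tests.tasks.repeat:RepeatTeacher:10'   # illustrative task string
data_length = int(task.split(':')[2])          # -> 10, same parsing as __init__ above

examples = [((str(i), [str(i)]), True) for i in range(data_length)]
assert examples[3] == (('3', ['3']), True)     # query '3', label ['3'], new episode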
108 changes: 108 additions & 0 deletions tests/test_train_model.py
@@ -0,0 +1,108 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

from examples.train_model import TrainLoop, setup_args
from parlai.core.agents import create_agent
from parlai.core.utils import Timer
from parlai.core.worlds import create_task

import ast
import importlib
import unittest
import sys


class TestTrainModel(unittest.TestCase):
    """Basic tests on the train_model.py example."""

    def setup_test_args(self):
        parser = setup_args(model_args=['--model', 'memnn'])
        # using memnn, so we want to check if torch is downloaded
        torch_downloaded = importlib.find_loader('torch')
        self.assertTrue(torch_downloaded is not None, "Torch not downloaded")
        return parser

    def test_output(self):
        class TestTrainLoop(TrainLoop):
            args = [
                '--model', 'memnn',
                '--task', 'tests.tasks.repeat:RepeatTeacher:10',
                '--dict-file', '/tmp/repeat',
                '-bs', '1',
                '-vtim', '5',
                '-vp', '2',
                '--embedding-size', '8',
                '--no-cuda'
            ]

            def __init__(self, parser):
                opt = parser.parse_args(self.args, print_args=False)
                self.agent = create_agent(opt)
                self.world = create_task(opt, self.agent)
                self.train_time = Timer()
                self.validate_time = Timer()
                self.log_time = Timer()
                self.save_time = Timer()
                print('[ training... ]')
                self.parleys = 0
                self.max_num_epochs = opt['num_epochs'] \
                    if opt['num_epochs'] > 0 else float('inf')
                self.max_train_time = opt['max_train_time'] \
                    if opt['max_train_time'] > 0 else float('inf')
                self.log_every_n_secs = opt['log_every_n_secs'] \
                    if opt['log_every_n_secs'] > 0 else float('inf')
                self.val_every_n_secs = opt['validation_every_n_secs'] \
                    if opt['validation_every_n_secs'] > 0 else float('inf')
                self.save_every_n_secs = opt['save_every_n_secs'] \
                    if opt['save_every_n_secs'] > 0 else float('inf')
                self.best_valid = 0
                self.impatience = 0
                self.saved = False
                self.valid_world = None
                self.opt = opt

        class display_output(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def __str__(self):
                return "".join(self.data)

        old_out = sys.stdout
        output = display_output()
        try:
            sys.stdout = output
            TestTrainLoop(self.setup_test_args()).train()
        finally:
            # restore sys.stdout
            sys.stdout = old_out

        str_output = str(output)

        self.assertTrue(len(str_output) > 0, "Output is empty")
        self.assertTrue("[ training... ]" in str_output,
                        "Did not reach training step")
        self.assertTrue("[ running eval: valid ]" in str_output,
                        "Did not reach validation step")
        self.assertTrue("valid:{'total': 10," in str_output,
                        "Did not complete validation")
        self.assertTrue("[ running eval: test ]" in str_output,
                        "Did not reach evaluation step")
        self.assertTrue("test:{'total': 10," in str_output,
                        "Did not complete evaluation")

        list_output = str_output.split("\n")
        for line in list_output:
            if "test:{" in line:
                score = ast.literal_eval(line.split("test:", 1)[1])
                self.assertTrue(score['accuracy'] == 1,
                                "Accuracy did not reach 1")

if __name__ == '__main__':
    unittest.main()
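Because the module ends with unittest.main(), it can be run directly as a script; a hedged sketch of driving it through the stock unittest loader instead, assuming the working directory is the ParlAI repository root so that examples and tests.tasks resolve on the import path:

# Sketch (assumed invocation): load and run the new test module by name.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName('tests.test_train_model')
unittest.TextTestRunner(verbosity=2).run(suite)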
