Skip to content

Commit

Permalink
Added unit tests for examples display_data.py and eval_model.py. Chan…
Browse files Browse the repository at this point in the history
…ged repeat_label agent to check for eval_labels
  • Loading branch information
Emily Dinan committed Dec 19, 2017
1 parent cc3657c commit 41175f0
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 22 deletions.
20 changes: 11 additions & 9 deletions examples/display_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,7 @@

import random


def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser()
parser.add_argument('-n', '--num-examples', default=10, type=int)
opt = parser.parse_args()

def display_data(opt):
# create repeat label agent and assign it to the specified task
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
Expand All @@ -40,5 +32,15 @@ def main():
break


def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser()
parser.add_argument('-n', '--num-examples', default=10, type=int)
opt = parser.parse_args()

display_data(opt)

if __name__ == '__main__':
main()
26 changes: 16 additions & 10 deletions examples/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,18 @@

import random

def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser(True, True)
parser.add_argument('-n', '--num-examples', default=100000000)
parser.add_argument('-d', '--display-examples', type='bool', default=False)
parser.set_defaults(datatype='valid')
opt = parser.parse_args(print_args=False)
def eval_model(opt, parser, printargs=True):
# Create model and assign it to the specified task
agent = create_agent(opt)
world = create_task(opt, agent)
# Show arguments after loading model
parser.opt = agent.opt
parser.print_args()
if (printargs):
parser.print_args()

# Show some example dialogs:
for k in range(int(opt['num_examples'])):
for _ in range(int(opt['num_examples'])):
world.parley()
print("---")
if opt['display_examples']:
Expand All @@ -45,5 +39,17 @@ def main():
break
world.shutdown()

def main():
random.seed(42)

# Get command line arguments
parser = ParlaiParser(True, True)
parser.add_argument('-n', '--num-examples', default=100000000)
parser.add_argument('-d', '--display-examples', type='bool', default=False)
parser.set_defaults(datatype='valid')
opt = parser.parse_args(print_args=False)

eval_model(opt, parser)

if __name__ == '__main__':
main()
5 changes: 2 additions & 3 deletions parlai/agents/repeat_label/repeat_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ def act(self):
return {'text': 'Nothing to repeat yet.'}
reply = {}
reply['id'] = self.getID()
if ('labels' in obs and obs['labels'] is not None
and len(obs['labels']) > 0):
labels = obs['labels']
labels = obs.get('labels', obs.get('eval_labels', None))
if labels:
if random.random() >= self.cantAnswerPercent:
if self.returnOneRandomAnswer:
reply['text'] = labels[random.randrange(len(labels))]
Expand Down
52 changes: 52 additions & 0 deletions tests/test_display_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from examples.display_data import display_data
from parlai.core.params import ParlaiParser

import sys
import unittest


class TestDisplayData(unittest.TestCase):
"""Basic tests on the display_data.py example."""

args = [
'--task', 'babi:task1k:1',
]
parser = ParlaiParser()
opt = parser.parse_args(args, print_args=False)
opt['num_examples'] = 1

def test_output(self):
"""Does display_data reach the end of the loop?"""

class display_output(object):
def __init__(self):
self.data = []

def write(self, s):
self.data.append(s)

def __str__(self):
return "".join(self.data)

old_out = sys.stdout
output = display_output()
try:
sys.stdout = output
display_data(self.opt)
finally:
# restore sys.stdout
sys.stdout = old_out

str_output = str(output)
self.assertTrue(len(str_output) > 0, "Output is empty")
self.assertTrue("[babi:task1k:1]:" in str_output,
"Babi task did not print")
self.assertTrue("~~" in str_output, "Example output did not complete")

if __name__ == '__main__':
unittest.main()
65 changes: 65 additions & 0 deletions tests/test_eval_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from examples.eval_model import eval_model
from parlai.core.params import ParlaiParser

import ast
import unittest
import sys


class TestEvalModel(unittest.TestCase):
"""Basic tests on the eval_model.py example."""

args = [
'--task', '#moviedd-reddit',
'--datatype', 'valid',
]

parser = ParlaiParser()
parser.set_defaults(datatype='valid')
opt = parser.parse_args(args, print_args=False)
opt['model'] = 'repeat_label'
opt['num_examples'] = 5
opt['display_examples'] = False

def test_output(self):
"""Test output of running eval_model"""
class display_output(object):
def __init__(self):
self.data = []

def write(self, s):
self.data.append(s)

def __str__(self):
return "".join(self.data)

old_out = sys.stdout
output = display_output()
try:
sys.stdout = output
eval_model(self.opt, self.parser, printargs=False)
finally:
# restore sys.stdout
sys.stdout = old_out

str_output = str(output)
self.assertTrue(len(str_output) > 0, "Output is empty")

# decode the output
scores = str_output.split("\n---\n")
for i in range(1, len(scores)):
score = ast.literal_eval(scores[i])
# check totals
self.assertTrue(score['total'] == i,
"Total is incorrect")
# accuracy should be one
self.assertTrue(score['accuracy'] == 1,
"accuracy != 1")

if __name__ == '__main__':
unittest.main()

0 comments on commit 41175f0

Please sign in to comment.