#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Basic tests that ensure train_model.py behaves in predictable ways.
"""

import unittest

import parlai.utils.testing as testing_utils
from parlai.core.worlds import create_task
from parlai.core.params import ParlaiParser
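
# testing_utils.train_model runs a full (but tiny) training loop with the
# given options and returns the final validation and test reports.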


class TestTrainModel(unittest.TestCase):
    def test_fast_final_eval(self):
        valid, test = testing_utils.train_model(
            {
                'task': 'integration_tests',
                'validation_max_exs': 10,
                'model': 'repeat_label',
                'short_final_eval': True,
                'num_epochs': 1.0,
            }
        )
        # --short-final-eval caps the final validation and test runs at
        # --validation-max-exs examples, so both reports should see 10 exs
        self.assertEqual(valid['exs'], 10, 'Validation exs is wrong')
        self.assertEqual(test['exs'], 10, 'Test exs is wrong')
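
    # With --aggregate-micro true, multitask metrics are pooled over
    # examples: in ParlAI, adding two AverageMetric objects sums their
    # numerators and denominators, so the aggregate 'accuracy' should
    # equal task1_acc + task2_acc, as asserted below.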
    def test_multitasking_metrics_micro(self):
        valid, test = testing_utils.train_model(
            {
                'task': 'integration_tests:candidate,'
                'integration_tests:multiturnCandidate',
                'model': 'random_candidate',
                'num_epochs': 0.5,
                'aggregate_micro': True,
            }
        )
        task1_acc = valid['integration_tests:candidate/accuracy']
        task2_acc = valid['integration_tests:multiturnCandidate/accuracy']
        total_acc = valid['accuracy']
        self.assertEqual(
            total_acc, task1_acc + task2_acc, 'Task accuracy is averaged incorrectly'
        )
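
    # With --aggregate-micro false (macro aggregation), every task counts
    # equally regardless of how many examples it contributes, so with two
    # tasks the aggregate is the unweighted mean 0.5 * (task1 + task2).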
    def test_multitasking_metrics_macro(self):
        valid, test = testing_utils.train_model(
            {
                'task': 'integration_tests:candidate,'
                'integration_tests:multiturnCandidate',
                'model': 'random_candidate',
                'num_epochs': 0.5,
                'aggregate_micro': False,
            }
        )
        task1_acc = valid['integration_tests:candidate/accuracy']
        task2_acc = valid['integration_tests:multiturnCandidate/accuracy']
        total_acc = valid['accuracy']
        # metrics should be averaged equally across tasks
        self.assertEqual(
            total_acc,
            0.5 * (task1_acc.value() + task2_acc.value()),
            'Task accuracy is averaged incorrectly',
        )
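
    # Listing the same task twice should fail at world-creation time:
    # ParlAI asserts that teacher ids are unique, and the assertion
    # message checked here names the offending id.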
    def test_multitasking_id_overlap(self):
        with self.assertRaises(AssertionError) as context:
            pp = ParlaiParser()
            opt = pp.parse_args(['--task', 'integration_tests,integration_tests'])
            self.world = create_task(opt, None)
        self.assertIn(
            'teachers have overlap in id integration_tests.',
            str(context.exception),
        )


if __name__ == '__main__':
    unittest.main()