#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Tests TOD world metrics in the full script, *including* having the script properly set
up the agents on its own.

Uses a few of the API call + goal-hit metrics as the metric handlers to test proper
functionality.
"""

import copy
import unittest

from parlai.core.metrics import dict_report
from parlai.core.opt import Opt
from parlai.core.tod.tod_core import SerializationHelpers
import parlai.core.tod.tod_test_utils.test_agents as test_agents
from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY
import parlai.scripts.tod_world_script as tod_world_script

# Ignore lint on following line; want to have registered classes show up for tests
import projects.tod_simulator.world_metrics.extended_world_metrics  # noqa: F401
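
# Number of episodes the test agents expose; also used to check that running with
# num_episodes = -1 processes every available episode.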
NUM_EPISODES = 35

TEST_SETUP = {
    "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent",
    "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent",
    "user_model": "parlai.core.tod.tod_test_utils.test_agents:UserUttAgent",
    "system_model": "parlai.core.tod.tod_test_utils.test_agents:ApiCallAndSysUttAgent",
    "api_resp_model": "fixed_response",
    test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES,
}
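
# Same setup, but with the user and system replaced by fixed-response agents, so no
# API calls are made and no goals should ever be hit.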
TEST_SETUP_BROKEN_USER_SYSTEM = {
    "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent",
    "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent",
    "user_model": "fixed_response",
    "system_model": "fixed_response",
    "api_resp_model": "fixed_response",
    test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES,
}
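
# Variants of the two setups above that ground on an empty API schema rather than the
# real one; goal-hit metrics should be unaffected (see the test docstrings below).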
TEST_SETUP_EMPTY_APISCHEMA = copy.deepcopy(TEST_SETUP)
TEST_SETUP_EMPTY_APISCHEMA[
    "api_schema_grounding_model"
] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent"

TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA = copy.deepcopy(
    TEST_SETUP_BROKEN_USER_SYSTEM
)
TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA[
    "api_schema_grounding_model"
] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent"

DATATYPE = "valid"


class TestTodWorldScript(tod_world_script.TodWorldScript):
    """
    Wraps the script so tests can check its logic; also makes it easier to work with
    the underlying World.
    """

    def __init__(self, opt: Opt):
        opt["datatype"] = DATATYPE
        # None of the below matter, but they need to be set to keep other code happy.
        opt["log_keep_fields"] = "all"
        opt["display_examples"] = False

        super().__init__(opt)

    def _setup_world(self):
        # Swap the metrics handlers for the test registry so the tests exercise a
        # known set of handlers.
        world = super()._setup_world()
        for i in range(len(world.batch_tod_world_metrics)):
            world.batch_tod_world_metrics[i].handlers = [
                x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY
            ]
        return world

    def _save_outputs(self, opt, world, logger, episode_metrics):
        # Stash everything on the script instance so tests can inspect it directly,
        # overriding the script's normal output saving.
        self.world = world
        self.logger = logger
        self.episode_metrics = episode_metrics
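
# For reference, a minimal sketch of how the wrapped script is driven (this mirrors
# the _check_all_goals_hit_by_opt_and_batchsize helper below):
#
#     opt = copy.deepcopy(TEST_SETUP)
#     opt["batchsize"] = 1
#     opt["num_episodes"] = 1
#     script = TestTodWorldScript(opt)
#     script.run()
#     report = dict_report(script.world.report())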

class TodMetricsInScriptTests(unittest.TestCase):
    def test_all_goals_hit_all_success(self):
        """
        For a setup where all the goals should be successfully hit, are they?
        """
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP, batchsize=1, num_episodes=32, target_all_goals_hit=1
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP, batchsize=32, num_episodes=8, target_all_goals_hit=1
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP, batchsize=32, num_episodes=33, target_all_goals_hit=1
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP,
            batchsize=32,
            num_episodes=-1,
            target_all_goals_hit=1,
            target_metrics_length=NUM_EPISODES,
        )

    def test_all_goals_hit_all_fail(self):
        """
        For a setup where all the goals should *not* be successfully hit, do they
        fail?
        """
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM,
            batchsize=1,
            num_episodes=1,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM,
            batchsize=1,
            num_episodes=32,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM,
            batchsize=32,
            num_episodes=32,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM,
            batchsize=32,
            num_episodes=33,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM,
            batchsize=32,
            num_episodes=-1,
            target_all_goals_hit=0,
            target_metrics_length=NUM_EPISODES,
        )

    def test_all_goals_hit_all_success_emptySchema(self):
        """
        Check that an empty API schema has no impact on goal success.

        (Necessary because the original, more exhaustive implementation of goal
        success used the schema to separate required from optional opts; make sure
        this doesn't impact anything broader.)
        """
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_EMPTY_APISCHEMA,
            batchsize=1,
            num_episodes=1,
            target_all_goals_hit=1,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_EMPTY_APISCHEMA,
            batchsize=1,
            num_episodes=32,
            target_all_goals_hit=1,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=32,
            target_all_goals_hit=1,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=33,
            target_all_goals_hit=1,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=-1,
            target_all_goals_hit=1,
            target_metrics_length=NUM_EPISODES,
        )

    def test_all_goals_hit_all_fail_emptySchema(self):
        """
        Make sure an empty schema has no impact on goal success.

        (Necessary because the original, more exhaustive implementation of goal
        success used the schema to separate required from optional opts; make sure
        this doesn't impact anything broader.)
        """
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA,
            batchsize=1,
            num_episodes=1,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA,
            batchsize=1,
            num_episodes=32,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=32,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=33,
            target_all_goals_hit=0,
        )
        self._check_all_goals_hit_by_opt_and_batchsize(
            TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA,
            batchsize=32,
            num_episodes=-1,
            target_all_goals_hit=0,
            target_metrics_length=NUM_EPISODES,
        )
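
    # Helper: run the script with the given agent setup, then check both the
    # aggregate `all_goals_hit` value and the number of per-episode metric entries.
    # num_episodes == -1 means "run every available episode", hence callers pass
    # target_metrics_length=NUM_EPISODES in that case.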
    def _check_all_goals_hit_by_opt_and_batchsize(
        self,
        opt,
        batchsize,
        num_episodes,
        target_all_goals_hit,
        target_metrics_length=None,
    ):
        opt = copy.deepcopy(opt)
        opt["batchsize"] = batchsize
        opt["num_episodes"] = num_episodes
        report, metrics = self._run_opt_get_report(opt)
        self.assertEqual(report.get("all_goals_hit"), target_all_goals_hit)
        metrics_comp_length = num_episodes
        if target_metrics_length:
            metrics_comp_length = target_metrics_length
        self.assertEqual(len(metrics), metrics_comp_length)

    def _run_opt_get_report(self, opt):
        script = TestTodWorldScript(opt)
        script.run()

        def get_episode_report(goal, episode_metric):
            metrics_dict = dict_report(episode_metric.report())
            metrics_dict["goal"] = goal
            return metrics_dict

        return (
            dict_report(script.world.report()),
            [get_episode_report(g, e) for g, e in script.episode_metrics],
        )
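
    # Each episode's goal message is serialized with a "GOALS: " prefix; the test
    # below strips that prefix, deserializes the goals, and checks that the number
    # of goals matches the number of API call attempts recorded for the episode.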
    def test_apiCallAttempts_usingGold(self):
        opt = copy.deepcopy(TEST_SETUP)
        opt["batchsize"] = 1
        opt["num_episodes"] = -1
        _, metrics = self._run_opt_get_report(opt)
        for metric in metrics:
            self.assertEqual(
                len(
                    SerializationHelpers.str_to_goals(
                        metric["goal"]["text"][len("GOALS: ") :]
                    )
                ),
                metric["call_attempts"],
            )


if __name__ == "__main__":
    unittest.main()