forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_conversations.py
111 lines (94 loc) · 3.52 KB
/
test_conversations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import unittest
import parlai.utils.testing as testing_utils
from parlai.utils.conversations import Conversations
from parlai.core.params import ParlaiParser
class TestConversations(unittest.TestCase):
    """
    Tests Conversations utilities.

    Each test writes its fixtures into a scratch ``tmp_conversations``
    directory under the ParlAI datapath resolved in ``setUp``; ``tearDown``
    deletes that whole directory afterwards.
    """

    def setUp(self):
        # Resolve the default datapath from an empty command line, without
        # echoing the parsed arguments.
        self.datapath = ParlaiParser().parse_args([], print_args=False)['datapath']
        self.datapath = os.path.join(self.datapath, 'tmp_conversations')
        os.makedirs(self.datapath, exist_ok=True)

    def test_conversations(self):
        """
        Round-trip conversations through ``save_conversations`` and ``Conversations``.

        Verifies that the ``.jsonl`` and ``.metadata`` files are written, that
        speakers, opt and extra keyword arguments are persisted in the
        metadata, that ``read_conv_idx`` prints the expected transcript, and
        that individual turns can be indexed.
        """
        # Two top-level entries; the assertions below expect each one to be
        # loaded back as a single conversation whose inner acts become
        # consecutive turns.
        act_list = [
            [
                [
                    {'id': 'Emily', 'text': 'Hello, do you like this test?'},
                    {'id': 'Stephen', 'text': 'Why yes! I love this test!'},
                ],
                [
                    {'id': 'Emily', 'text': 'So will you stamp this diff?'},
                    {'id': 'Stephen', 'text': 'Yes, I will do it right now!'},
                ],
            ],
            [
                [
                    {
                        'id': 'A',
                        'text': 'Somebody once told me the world is gonna roll me',
                    },
                    {'id': 'B', 'text': 'I aint the sharpest tool in the shed'},
                ],
                [
                    {
                        'id': 'A',
                        'text': 'She was looking kind of dumb with her finger and her thumb',
                    },
                    {'id': 'B', 'text': 'In the shape of an L on her forehead'},
                ],
            ],
        ]
        self.opt = {
            'A': 'B',
            'C': 'D',
            'E': 'F',
        }
        self.convo_datapath = os.path.join(self.datapath, 'convo1')
        Conversations.save_conversations(
            act_list,
            self.convo_datapath,
            self.opt,
            self_chat=False,
            other_info='Blah blah blah',
        )
        # Use TestCase assertions instead of bare `assert`: bare asserts are
        # removed under `python -O`, which would silently skip these checks.
        self.assertTrue(os.path.exists(self.convo_datapath + '.jsonl'))
        self.assertTrue(os.path.exists(self.convo_datapath + '.metadata'))

        convos = Conversations(self.convo_datapath + '.jsonl')
        # test conversations loaded
        self.assertEqual(len(convos), 2)

        # test speakers saved
        speakers = {'Stephen', 'Emily', 'A', 'B'}
        self.assertEqual(set(convos.metadata.speakers), speakers)

        # test opt saved; subTest reports exactly which key mismatched
        for key in ['A', 'C', 'E']:
            with self.subTest(key=key):
                self.assertEqual(self.opt[key], convos.metadata.opt[key])

        # test kwargs: extra keyword arguments land in metadata.extra_data
        self.assertEqual({'other_info': 'Blah blah blah'}, convos.metadata.extra_data)

        # test reading conversations (read_conv_idx prints to stdout)
        with testing_utils.capture_output() as out:
            convos.read_conv_idx(0)
        str_version = (
            'Emily: Hello, do you like this test?\n'
            'Stephen: Why yes! I love this test!\n'
            'Emily: So will you stamp this diff?\n'
            'Stephen: Yes, I will do it right now!\n'
        )
        self.assertIn(str_version, out.getvalue())

        # test getting a specific turn
        first = convos[0]  # Conversation
        self.assertEqual(first[0].id, 'Emily')
        self.assertEqual(first[3].text, 'Yes, I will do it right now!')

    def tearDown(self):
        # remove conversations: deletes the whole scratch directory made in setUp
        shutil.rmtree(self.datapath)
if __name__ == '__main__':
    # Direct execution (`python test_conversations.py`): hand control to the
    # unittest command-line runner for the tests defined in this module.
    unittest.main(module='__main__')