forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_dict.py
64 lines (53 loc) · 2.33 KB
/
test_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from parlai.core.dict import find_ngrams
import unittest
class TestDictionary(unittest.TestCase):
    """Basic tests on the built-in parlai Dictionary."""

    def test_find_ngrams(self):
        """Check that find_ngrams merges known uni-, bi-, and trigrams.

        Each case calls ``find_ngrams(ngrams, tokens, n)`` and checks the
        exact token list it returns. Each expected list satisfies three
        properties:

        * every selected run of words from ``ngrams`` appears as one
          space-joined token;
        * words not covered by a match remain single tokens;
        * no word is dropped or reordered.
        """
        tokens = ['hello', 'world', 'buddy', 'ol', 'boy']
        full_text = ' '.join(tokens)
        ngrams = {'hello world', 'ol boy'}

        # A fresh copy of ``tokens`` is passed on every call, as the original
        # test did with separate literals, so the later cases are unaffected
        # if find_ngrams ever modifies its input.

        # Only bigrams are known: both get merged and 'buddy' stands alone.
        res = find_ngrams(ngrams, list(tokens), 2)
        self.assertEqual(' '.join(res), full_text)
        self.assertEqual(list(res), ['hello world', 'buddy', 'ol boy'])

        # With n=3 the expected output takes the trigram 'world buddy ol'
        # instead of the overlapping bigrams, leaving 'hello' and 'boy' alone.
        ngrams.add('world buddy ol')
        res = find_ngrams(ngrams, list(tokens), 3)
        self.assertEqual(' '.join(res), full_text)
        self.assertEqual(list(res), ['hello', 'world buddy ol', 'boy'])

        # Adding 'hello world buddy': the expected output takes it over the
        # overlapping 'world buddy ol', and the leftover 'ol boy' is still
        # merged as a bigram.
        ngrams.add('hello world buddy')
        res = find_ngrams(ngrams, list(tokens), 3)
        self.assertEqual(' '.join(res), full_text)
        self.assertEqual(list(res), ['hello world buddy', 'ol boy'])

    def test_basic_parse(self):
        """Check that the dictionary correctly adds and parses a short sentence.

        Observing 'hello world' and calling act() should add exactly two
        entries. Parsing the sentence should then map its two words to the
        first two indices after the entries that existed beforehand. This
        must hold for the default ``vec_type`` of parse() and for
        ``vec_type=list`` and ``vec_type=tuple``.
        """
        from parlai.core.dict import DictionaryAgent
        from parlai.core.params import ParlaiParser

        argparser = ParlaiParser()
        DictionaryAgent.add_cmdline_args(argparser)
        # Parse an explicit empty argument list. The default (sys.argv)
        # would pick up the test runner's own flags, e.g.
        # ``python test_dict.py -v`` or pytest options, and argparse
        # would exit.
        opt = argparser.parse_args([])
        dictionary = DictionaryAgent(opt)
        num_builtin = len(dictionary)

        dictionary.observe({'text': 'hello world'})
        dictionary.act()
        self.assertEqual(len(dictionary) - num_builtin, 2)

        # The first case relies on parse()'s default vec_type. The other two
        # request a specific container type explicitly.
        cases = (
            ('default', {}),
            ('list', {'vec_type': list}),
            ('tuple', {'vec_type': tuple}),
        )
        for label, kwargs in cases:
            vec = dictionary.parse('hello world', **kwargs)
            self.assertEqual(len(vec), 2, msg=label)
            self.assertEqual(vec[0], num_builtin, msg=label)
            self.assertEqual(vec[1], num_builtin + 1, msg=label)
if __name__ == '__main__':
    # When this file is executed directly, discover and run every
    # TestCase defined in this module.
    unittest.main()