forked from facebookresearch/MUSE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdictionary.py
72 lines (59 loc) · 1.83 KB
/
dictionary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from logging import getLogger
# Module-level logger. getLogger() called with no name returns the root logger.
# NOTE(review): it is not referenced by the Dictionary class in this module.
logger = getLogger()
class Dictionary(object):
    """
    Bidirectional mapping between integer indices and the words of one language.

    `id2word` maps every index in `range(len(id2word))` to a word and
    `word2id` must be its exact inverse; `check_valid` enforces this invariant.
    """

    # `__eq__` is defined below, which already makes instances unhashable;
    # stating it explicitly documents that this is intentional.
    __hash__ = None

    def __init__(self, id2word, word2id, lang):
        """
        Build a dictionary from both directions of the mapping and validate it.

        `id2word` is indexed by integer ids (`prune` additionally requires it
        to be a mapping with `.items()`), `word2id` maps words back to ids, and
        `lang` is a language identifier that takes part in `__eq__`.

        Raises AssertionError if the two mappings are not mutual inverses.
        """
        assert len(id2word) == len(word2id)
        self.id2word = id2word
        self.word2id = word2id
        self.lang = lang
        self.check_valid()

    def __len__(self):
        """
        Returns the number of words in the dictionary.
        """
        return len(self.id2word)

    def __getitem__(self, i):
        """
        Returns the word of the specified index.
        """
        return self.id2word[i]

    def __contains__(self, w):
        """
        Returns whether a word is in the dictionary.
        """
        return w in self.word2id

    def __eq__(self, y):
        """
        Compare the dictionary with another one.

        Two dictionaries are equal when they share the same language and hold
        the same word at every index. Both sides are validated first.
        For an operand that is not a Dictionary, NotImplemented is returned so
        Python falls back to its default comparison (e.g. `d == None` is
        False) instead of failing with AttributeError.
        """
        if not isinstance(y, Dictionary):
            return NotImplemented
        self.check_valid()
        y.check_valid()
        if len(self.id2word) != len(y):
            return False
        return self.lang == y.lang and all(
            self.id2word[i] == y[i] for i in range(len(y))
        )

    def check_valid(self):
        """
        Check that the dictionary is valid.

        Raises AssertionError unless both mappings have the same size and
        `word2id[id2word[i]] == i` holds for every index `i`.
        """
        assert len(self.id2word) == len(self.word2id)
        for i in range(len(self.id2word)):
            word = self.id2word[i]
            # Use .get() so a word absent from word2id is reported as an
            # invalid dictionary (AssertionError) rather than a bare KeyError.
            assert self.word2id.get(word) == i, (
                "word %r at index %d has no matching word2id entry" % (word, i)
            )

    def index(self, word):
        """
        Returns the index of the specified word.

        This is a direct `word2id[word]` lookup, so an unknown word raises
        whatever that mapping raises (KeyError for a dict).
        """
        return self.word2id[word]

    def prune(self, max_vocab):
        """
        Limit the vocabulary size.

        Keeps only the words whose index is below `max_vocab`, rebuilds
        `word2id` from the pruned `id2word`, and re-validates. Requires
        `id2word` to be a mapping (it is read through `.items()`).
        """
        assert max_vocab >= 1
        self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
        self.word2id = {v: k for k, v in self.id2word.items()}
        self.check_valid()