# check_tokenizers.py (forked from huggingface/transformers)
from collections import Counter
import datasets
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging
# Set the transformers logging verbosity to "info" for the whole run.
logging.set_verbosity_info()

# For every name registered in SLOW_TO_FAST_CONVERTERS, pair the class of that name
# with the class called `<name>Fast`, both looked up on the `transformers` package.
TOKENIZER_CLASSES = {
    name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS
}

# Text corpus fed to the tokenizers; loaded once at import time.
# NOTE(review): "test+validation" presumably requests both xnli splits together - confirm.
dataset = datasets.load_dataset("xnli", split="test+validation")

# Running counters. test_string updates them through `global`, and the __main__ loop
# resets them to zero before each checkpoint.
total = 0  # strings checked so far
perfect = 0  # slow and fast ids were identical
imperfect = 0  # ids differed, but check_details accepted the difference
wrong = 0  # ids differed and check_details rejected the difference
def check_diff(spm_diff, tok_diff, slow, fast):
    """Tell whether two differing spans of token ids are an acceptable mismatch.

    Args:
        spm_diff: ids produced by the slow tokenizer for the differing span.
        tok_diff: ids produced by the fast tokenizer for the differing span.
        slow: slow tokenizer; only its ``encode`` and ``decode`` are used.
        fast: fast tokenizer; only its ``encode`` and ``decode`` are used.

    Returns:
        True if one of the three tolerated patterns below applies, else False.
    """
    # AAA -> AA+A vs A+AA case: one span is the mirror image of the other.
    if spm_diff == list(reversed(tok_diff)):
        return True

    # Second order OK: same number of tokens and the same decoded text.
    # Barrich -> Barr + ich vs Bar + rich
    if len(spm_diff) == len(tok_diff):
        if fast.decode(spm_diff) == fast.decode(tok_diff):
            return True

    # Type 3 error.
    # Snehagatha ->
    #       Sne, h, aga, th, a
    #       Sne, ha, gat, ha
    # Encoding the wrong with sp does not even recover what spm gave us
    # It fits tokenizer however...
    slow_roundtrip = slow.encode(slow.decode(spm_diff))
    fast_roundtrip = fast.encode(fast.decode(spm_diff))
    return slow_roundtrip != spm_diff and slow_roundtrip == fast_roundtrip
def check_LTR_mark(line, idx, fast):
    """Tell whether the token at ``idx``, or the one just before it, is exactly "\\u200f".

    ``line`` is re-encoded with ``fast.encode_plus`` and the character offsets of the
    first returned encoding are used to slice the token text back out of ``line``.
    U+200F is the Unicode right-to-left mark, despite the function's name.

    Args:
        line: the text that was tokenized.
        idx: position of the token to inspect within the fast encoding of ``line``.
        fast: fast tokenizer; ``fast.encode_plus(line)[0].offsets`` must be a
            sequence whose entries are either ``None`` or ``(start, end)`` pairs.

    Returns:
        True if either inspected token covers exactly "\\u200f", otherwise False.
    """
    offsets = fast.encode_plus(line)[0].offsets

    def is_mark(position):
        # A position outside the encoding has no token. Without this guard,
        # `idx - 1` for idx == 0 would wrap around and inspect the *last* token,
        # and an idx past the end would raise IndexError.
        if not 0 <= position < len(offsets):
            return False
        span = offsets[position]
        return span is not None and line[span[0] : span[1]] == "\u200f"

    return is_mark(idx) or is_mark(idx - 1)
def check_details(line, spm_ids, tok_ids, slow, fast):
    """Decide whether a slow/fast tokenization mismatch is tolerable.

    The shared prefix and shared suffix of the two id sequences are stripped, and the
    remaining spans are judged with ``check_diff`` and ``check_LTR_mark``. Long spans
    are additionally split on a 3-token window that both sides contain, and each half
    is judged on its own (recursively). When nothing accepts the mismatch, the
    differing tokens are printed for inspection.

    Args:
        line: the text that produced both id sequences.
        spm_ids: ids from the slow tokenizer.
        tok_ids: ids from the fast tokenizer.
        slow: slow tokenizer, forwarded to ``check_diff``.
        fast: fast tokenizer, forwarded to the helpers and used to decode the
            ids printed in the failure report.

    Returns:
        True if the mismatch is accepted, False otherwise.
    """
    # Encoding can be the same with same result AAA -> A + AA vs AA + A
    # We can check that we use at least exactly the same number of tokens.

    # Length of the shared prefix, i.e. the index of the first differing token.
    # Counting explicitly (rather than reusing the loop index) keeps this defined for
    # empty inputs and correct when one sequence is a prefix of the other.
    first = 0
    for spm_id, tok_id in zip(spm_ids, tok_ids):
        if spm_id != tok_id:
            break
        first += 1

    # Length of the shared suffix, capped so it can never overlap the shared prefix.
    max_suffix = min(len(spm_ids), len(tok_ids)) - first
    suffix = 0
    for spm_id, tok_id in zip(reversed(spm_ids), reversed(tok_ids)):
        if suffix == max_suffix or spm_id != tok_id:
            break
        suffix += 1

    # Each side gets its own end index: when the sequences differ in length, the
    # shared suffix starts at a different position in each of them.
    spm_diff = spm_ids[first : len(spm_ids) - suffix]
    tok_diff = tok_ids[first : len(tok_ids) - suffix]

    if check_diff(spm_diff, tok_diff, slow, fast):
        return True

    # Only look at the token at `first` when both sequences actually have one there.
    # NOTE(review): in recursive calls `first` is relative to the sliced id lists while
    # check_LTR_mark indexes the encoding of the whole `line` - confirm this is intended.
    if max_suffix > 0 and check_LTR_mark(line, first, fast):
        return True

    if len(spm_diff) > 5:
        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
        spm_counts = Counter(spm_diff)
        tok_counts = Counter(tok_diff)

        # Ids that occur equally often on both sides of the differing span.
        removable_tokens = {
            token for token, count in spm_counts.items() if tok_counts.get(token, 0) == count
        }
        min_width = 3
        for i in range(len(spm_diff) - min_width):
            window = spm_diff[i : i + min_width]
            if not all(token in removable_tokens for token in window):
                continue
            # Try every position where the fast side contains the same window.
            for j in range(len(tok_diff) - min_width):
                if tok_diff[j : j + min_width] != window:
                    continue
                # The original passed the undefined names `sp, tok` here (NameError).
                if check_diff(spm_diff[:i], tok_diff[:j], slow, fast) and check_details(
                    line,
                    spm_diff[i:],
                    tok_diff[j:],
                    slow,
                    fast,
                ):
                    return True

    # Nothing accepted the mismatch: report the differing tokens.
    print(f"Spm: {[fast.decode([spm_id]) for spm_id in spm_diff]}")
    try:
        print(f"Tok: {[fast.decode([tok_id]) for tok_id in tok_diff]}")
    except Exception as err:
        # Best effort only: the report must not mask the mismatch itself.
        print(f"Tok: <could not decode: {err!r}>")

    wrong_text = fast.decode(spm_diff)
    print()
    print(wrong_text)
    return False
def test_string(slow, fast, text):
    """Compare slow and fast tokenization of ``text`` and update the global counters.

    ``total`` is always incremented, then exactly one of ``perfect`` (identical ids),
    ``imperfect`` (ids differ but ``check_details`` accepts it) or ``wrong``. A
    progress line is printed every 10000 strings.

    Args:
        slow: slow tokenizer; ``encode`` is used, and ``tokenize`` for the error message.
        fast: fast tokenizer; same methods as ``slow``.
        text: the string to tokenize.

    Raises:
        AssertionError: if the ids differ and ``check_details`` rejects the mismatch.
    """
    global perfect, imperfect, wrong, total

    slow_ids = slow.encode(text)
    fast_ids = fast.encode(text)
    total += 1

    if slow_ids == fast_ids:
        perfect += 1
        accepted = True
    elif check_details(text, slow_ids, fast_ids, slow, fast):
        imperfect += 1
        accepted = True
    else:
        wrong += 1
        accepted = False

    if total % 10000 == 0:
        print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")

    if not accepted:
        # Raised explicitly instead of via `assert`, so the check still fires when
        # Python runs with -O (which strips assert statements).
        raise AssertionError(
            f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"
        )
def test_tokenizer(slow, fast):
    """Run ``test_string`` on every premise and hypothesis string in ``dataset``.

    Args:
        slow: slow tokenizer, forwarded to ``test_string``.
        fast: fast tokenizer, forwarded to ``test_string``.
    """
    # Iterate rows directly so each example is fetched once instead of once per field.
    for example in dataset:
        # premise, all languages
        for text in example["premise"].values():
            test_string(slow, fast, text)

        # hypothesis, all languages
        for text in example["hypothesis"]["translation"]:
            test_string(slow, fast, text)
if __name__ == "__main__":
    # Check every checkpoint listed in `max_model_input_sizes` of each slow class.
    for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items():
        for checkpoint in list(slow_class.max_model_input_sizes.keys()):
            # Fresh counters for each checkpoint; test_string updates them via `global`.
            imperfect = perfect = wrong = total = 0

            print(f"========================== Checking {name}: {checkpoint} ==========================")
            slow = slow_class.from_pretrained(checkpoint, force_download=True)
            fast = fast_class.from_pretrained(checkpoint, force_download=True)
            test_tokenizer(slow, fast)

            # Share of strings whose slow and fast ids were identical, in percent.
            print(f"Accuracy {perfect * 100 / total:.2f}")