tokenizer.py
import unicodedata

from .bert import TOKEN_CLS, TOKEN_SEP, TOKEN_UNK


class Tokenizer(object):
    """A BERT WordPiece tokenizer built from a token-to-index vocabulary."""

    def __init__(self,
                 token_dict,
                 token_cls=TOKEN_CLS,
                 token_sep=TOKEN_SEP,
                 token_unk=TOKEN_UNK,
                 pad_index=0,
                 cased=False):
        """Initialize the tokenizer.

        :param token_dict: A dict that maps tokens to indices.
        :param token_cls: The classification token.
        :param token_sep: The separator token.
        :param token_unk: The token used for unknown (out-of-vocabulary) tokens.
        :param pad_index: The index used for padding.
        :param cased: Whether to keep the original case.
        """
        self._token_dict = token_dict
        self._token_cls = token_cls
        self._token_sep = token_sep
        self._token_unk = token_unk
        self._pad_index = pad_index
        self._cased = cased

    @staticmethod
    def _truncate(first_tokens, second_tokens=None, max_len=None):
        """Truncate the token lists in place so the packed sequence fits in `max_len`."""
        if max_len is None:
            return
        if second_tokens is not None:
            while True:
                total_len = len(first_tokens) + len(second_tokens)
                if total_len <= max_len - 3:  # 3 for [CLS] .. tokens_a .. [SEP] .. tokens_b .. [SEP]
                    break
                if len(first_tokens) > len(second_tokens):
                    first_tokens.pop()
                else:
                    second_tokens.pop()
        else:
            del first_tokens[max_len - 2:]  # 2 for [CLS] .. tokens .. [SEP]

    def _pack(self, first_tokens, second_tokens=None):
        """Add the special tokens and return the packed tokens with both segment lengths."""
        first_packed_tokens = [self._token_cls] + first_tokens + [self._token_sep]
        if second_tokens is not None:
            second_packed_tokens = second_tokens + [self._token_sep]
            return first_packed_tokens + second_packed_tokens, len(first_packed_tokens), len(second_packed_tokens)
        else:
            return first_packed_tokens, len(first_packed_tokens), 0

    def _convert_tokens_to_ids(self, tokens):
        """Map tokens to indices, falling back to the unknown-token index."""
        unk_id = self._token_dict.get(self._token_unk)
        return [self._token_dict.get(token, unk_id) for token in tokens]

    def tokenize(self, first, second=None):
        """Split one or two texts into tokens, including the special tokens."""
        first_tokens = self._tokenize(first)
        second_tokens = self._tokenize(second) if second is not None else None
        tokens, _, _ = self._pack(first_tokens, second_tokens)
        return tokens

    def encode(self, first, second=None, max_len=None):
        """Convert one or two texts to (token_ids, segment_ids), padded to `max_len` if given."""
        first_tokens = self._tokenize(first)
        second_tokens = self._tokenize(second) if second is not None else None
        self._truncate(first_tokens, second_tokens, max_len)
        tokens, first_len, second_len = self._pack(first_tokens, second_tokens)
        token_ids = self._convert_tokens_to_ids(tokens)
        segment_ids = [0] * first_len + [1] * second_len
        if max_len is not None:
            pad_len = max_len - first_len - second_len
            token_ids += [self._pad_index] * pad_len
            segment_ids += [0] * pad_len
        return token_ids, segment_ids

    def _tokenize(self, text):
        """Basic tokenization: normalization, punctuation/CJK splitting, then WordPiece."""
        if not self._cased:
            # Strip accents and lowercase when the tokenizer is uncased.
            text = unicodedata.normalize('NFD', text)
            text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
            text = text.lower()
        spaced = ''
        for ch in text:
            if self._is_punctuation(ch) or self._is_cjk_character(ch):
                spaced += ' ' + ch + ' '
            elif self._is_space(ch):
                spaced += ' '
            elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
                continue
            else:
                spaced += ch
        tokens = []
        for word in spaced.strip().split():
            tokens += self._word_piece_tokenize(word)
        return tokens

    def _word_piece_tokenize(self, word):
        """Greedy longest-match-first WordPiece split; sub-word pieces get a '##' prefix."""
        if word in self._token_dict:
            return [word]
        tokens = []
        start, stop = 0, 0
        while start < len(word):
            stop = len(word)
            while stop > start:
                sub = word[start:stop]
                if start > 0:
                    sub = '##' + sub
                if sub in self._token_dict:
                    break
                stop -= 1
            if start == stop:
                # No vocabulary match even for a single character; keep it and move on.
                stop += 1
            tokens.append(sub)
            start = stop
        return tokens

    @staticmethod
    def _is_punctuation(ch):
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _is_cjk_character(ch):
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_space(ch):
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_control(ch):
        return unicodedata.category(ch).startswith('C')
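

# Minimal usage sketch with a hypothetical toy vocabulary: it only illustrates
# the expected `token_dict` format and the shape of `tokenize`/`encode` output,
# passing the special tokens explicitly rather than relying on the defaults
# imported from `.bert`. Because of the relative import above, run this as part
# of its package (e.g. `python -m <package>.tokenizer`).
if __name__ == '__main__':
    toy_vocab = {
        '[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3,
        'hello': 4, 'world': 5, 'un': 6, '##aff': 7, '##able': 8,
    }
    tokenizer = Tokenizer(toy_vocab, token_cls='[CLS]', token_sep='[SEP]', token_unk='[UNK]')

    # WordPiece splits an out-of-vocabulary word into known sub-word pieces.
    print(tokenizer.tokenize('unaffable'))  # ['[CLS]', 'un', '##aff', '##able', '[SEP]']

    # `encode` returns padded ids plus segment ids marking the two sentences.
    token_ids, segment_ids = tokenizer.encode('hello world', 'unaffable', max_len=12)
    print(token_ids)    # [2, 4, 5, 3, 6, 7, 8, 3, 0, 0, 0, 0]
    print(segment_ids)  # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]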