"""
from https://github.com/openai/gpt-2/, changed for chinese
"""
import json
import os

import sentencepiece as spm
"""
SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation
systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements
subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the
extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end
system that does not depend on language-specific pre/postprocessing.
https://github.com/google/sentencepiece
pip install sentencepiece
or git clone https://github.com/google/sentencepiece.git
python setup.py install
"""
def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`, a tuple of symbols."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
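
# Illustrative only: get_pairs(('你', '好', '世', '界')) returns
# {('你', '好'), ('好', '世'), ('世', '界')}.
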
class Encoder:
    def __init__(self, encoder, bpe_merges):
        self.encoder = encoder  # token -> id
        self.decoder = {v: k for k, v in self.encoder.items()}  # id -> token
        # Earlier merges get lower ranks, i.e. higher merge priority.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        self.max_len = 0

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            return token

        while True:
            # Merge the highest-priority (lowest-rank) adjacent pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
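
    # Illustrative only: with bpe_merges = [('h', 'e'), ('l', 'l')] (hypothetical),
    # the ranks are {('h', 'e'): 0, ('l', 'l'): 1} and bpe('hello') returns 'he ll o'.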

    def encode(self, text):
        # Unknown tokens fall back to id 1, presumably the unknown-token id
        # in this project's vocabularies.
        return [self.encoder.get(token, 1) for token in self.tokenize(text)]

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        return text

    def tokenize(self, text):
        return self.bpe(text).split(' ')

    def convert_tokens_to_ids(self, tokens):
        return [self.encoder.get(token, 1) for token in tokens]
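
# Continuing the illustrative example above (ids hypothetical): with
# encoder = {'he': 0, 'll': 3, 'o': 4} and bpe_merges = [('h', 'e'), ('l', 'l')],
# Encoder(encoder, bpe_merges).encode('hello') tokenizes to ['he', 'll', 'o']
# and returns [0, 3, 4].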


class Encoder_SP:
    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)

    def encode(self, text):
        """Encode a text string into a list of ids."""
        return self.sp.EncodeAsIds(text)

    def decode(self, tokens):
        """Decode a list of ids, e.g. [x1, x2, ...], back into a text string."""
        text = [int(token) for token in tokens]
        return self.sp.DecodeIds(text)

    def tokenize(self, text):
        return self.sp.EncodeAsPieces(text)

    def convert_tokens_to_ids(self, tokens):
        return [self.sp.PieceToId(token) for token in tokens]
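
# Illustrative usage, assuming a trained model file "chinese_bpe.model" (hypothetical):
#
#   enc = Encoder_SP("chinese_bpe.model")
#   ids = enc.encode("你好世界")
#   assert enc.decode(ids) == "你好世界"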


def get_encoder(encoder_file, bpe_file):
    # Dispatch on file type so this single entry point is also compatible
    # with SentencePiece models.
    filepath, filename = os.path.split(encoder_file)
    shortname, extension = os.path.splitext(filename)
    if extension == ".model" and bpe_file == "":
        return Encoder_SP(encoder_file)
    else:
        with open(encoder_file, 'r', encoding="utf-8") as f:
            encoder = json.load(f)
        with open(bpe_file, 'r', encoding="utf-8") as f:
            bpe_data = f.read()
        # The first line of a GPT-2-style merges file is a version header
        # and the last line is empty, hence the [1:-1] slice.
        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
        return Encoder(
            encoder=encoder,
            bpe_merges=bpe_merges,
        )
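
# Illustrative usage (file names hypothetical): a SentencePiece model is selected
# by its ".model" extension together with an empty bpe_file; otherwise a
# GPT-2-style JSON vocabulary plus merges file is expected.
#
#   sp_enc = get_encoder("chinese_bpe.model", "")
#   gpt2_enc = get_encoder("encoder.json", "vocab.bpe")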