Skip to content

Commit

Permalink
upadted
Browse files Browse the repository at this point in the history
Signed-off-by: Anhforth <[email protected]>
  • Loading branch information
Anhforth committed Feb 20, 2023
1 parent cd13fc5 commit 9dc6e10
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 13 deletions.
3 changes: 2 additions & 1 deletion examples/glm_blank_filling/glm_generate_samples.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

import sys
sys.path.append("/home/yanzhaodong/anhforth/FlagAI")
import torch
from flagai.model.glm_model import GLMModel
from flagai.data.tokenizer import Tokenizer
Expand Down
3 changes: 2 additions & 1 deletion examples/glm_blank_filling/glm_generate_samples_en.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

import sys
sys.path.append("/home/yanzhaodong/anhforth/FlagAI")
import torch
from flagai.model.glm_model import GLMModel
from flagai.data.tokenizer import Tokenizer
Expand Down
14 changes: 7 additions & 7 deletions examples/t5_title_generation/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
beam_size=3,
input_max_length=512,
out_max_length=100)
out_2 = predictor.predict_generate_randomsample(text,
input_max_length=512,
out_max_length=100,
repetition_penalty=1.5,
top_k=20,
top_p=0.8)
# out_2 = predictor.predict_generate_randomsample(text,
# input_max_length=512,
# out_max_length=100,
# repetition_penalty=1.5,
# top_k=20,
# top_p=0.8)

print(f"out_1 is {out_1}")
print(f"out_2 is {out_2}")
# print(f"out_2 is {out_2}")
11 changes: 10 additions & 1 deletion flagai/data/tokenizer/uni_tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self,
add_task_mask=True,
add_decoder_mask=False,
fix_command_token=True,
pre_tokenizer=None,
**kwargs):
super().__init__(**kwargs)
if self.tokenizer_class == "wp":
Expand All @@ -75,6 +76,9 @@ def __init__(self,
if self.tokenizer_model_name.lower().startswith('glm') or self.tokenizer_model_name.lower().startswith('alm'):
add_block_symbols=True
# self.is_clip = self.tokenizer_model_name.startswith('clip')
# if self.tokenizer_model_name.startswith('t5'):
# import jieba
# self.pre_tokenizer = lambda x: jieba.cut(x, HMM=False)
self.num_tokens = self.text_tokenizer.vocab_size
with open(self.special_tokens_map, encoding='utf8') as file: dct=json.load(file)
sp_tokens = [(k.replace("_token",""),v['content']) for k,v in dct.items()]
Expand Down Expand Up @@ -590,7 +594,8 @@ def encode_plus_non_glm(
truncation=True,
max_length=None,
):

if self.tokenizer_model_name.startswith('t5'):
assert second_text is None, "t5 does not support multi-sentence encoding"
def get_input_ids(text):
tokens = self.text_tokenizer.tokenize(text)
return self.text_tokenizer.convert_tokens_to_ids(tokens)
Expand Down Expand Up @@ -753,6 +758,10 @@ def tokenize_as_tensor(self, texts):
eot_token=eot_token)

def tokenize(self, text, maxlen=None, add_spatial_tokens=False):
"""
add_spatial_token: (bool) Add cls at the front and sep at the end
max_len: Truncate the length to max_len
"""
tokens = self.text_tokenizer.tokenize(text)

if add_spatial_tokens:
Expand Down
3 changes: 2 additions & 1 deletion flagai/model/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,9 @@ def t5_random_sample(model, tokenizer, text, input_max_length, out_max_length,
TopPLogitsProcessor(top_p=top_p),
]
list_processor = ListProcessor(lp)
from tqdm import trange
with torch.no_grad():
for step in range(out_max_length):
for step in trange(out_max_length):
scores = model(**{
"input_ids": token_ids,
"decoder_input_ids": input_decoder_ids
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def test_tokenizer_t5(self):
self.assertEqual(tokenizer.DecodeIds([3, 7704, 3832, 656, 140, 1095]),
'fried chicken makes me happy', 'DecodeIds Error')
self.assertEqual([(v.name, k,v.Id) for k,v in tokenizer.command_token_map.items()],
[('eos', '<|endoftext|>', 32000), ('sep', '[SEP]', 32001), ('cls', '[CLS]', 32002),
('MASK', '[MASK]', 32003), ('unk', '[UNK]', 32004)])
[('eos', '[PAD]', 0), ('cls', '[CLS]', 101), ('MASK', '[MASK]', 103),
('unk', '[UNK]', 100), ('sep', '[SEP]', 102)])

# def test_tokenizer_roberta(self):
# tokenizer = Tokenizer.from_pretrained('RoBERTa-base-ch')
Expand Down

0 comments on commit 9dc6e10

Please sign in to comment.