
Commit 6cfa408

siyuan.yang authored and sophon-leevi committed
fix(Qwen): add extra end of text id
Change-Id: Iad7edabcdb9e21d89bbd50e626d6d88222936145
(cherry picked from commit 16dbf8154427a142ac74471a49adff6d9be28af4)
1 parent 987121d commit 6cfa408

1 file changed, 7 insertions(+), 4 deletions(-)

sample/Qwen/python/qwen.py (+7 -4)
@@ -20,7 +20,10 @@ def __init__(self, bmodel_path, dev_ids, tokenizer_path) -> None:
         self.version = "1.1.0"
 
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
-        self.EOS = self.tokenizer.eos_token_id
+        ID_IM_END = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        ID_END = self.tokenizer.convert_tokens_to_ids("<|end|>")
+        EOF = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
+        self.EOS = [self.tokenizer.eos_token_id, ID_IM_END, ID_END, EOF]
         self.dev_ids = [int(x) for x in str(dev_ids).split(',')]
         self.handles = {dev: sail.Handle(dev) for dev in self.dev_ids}
         self.target = sail.Handle(self.dev_ids[0]).get_target()
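
Review note on the __init__ hunk: <|im_end|> and <|endoftext|> are genuine Qwen special tokens, but "<|end|>" looks borrowed from other chat templates and may be absent from the Qwen vocabulary. Depending on the tokenizer, convert_tokens_to_ids() then returns None or the unk id, and that value would silently become a stop token. A minimal defensive sketch (resolve_stop_ids is a hypothetical helper, not part of this repo):

    def resolve_stop_ids(tokenizer, candidates=("<|im_end|>", "<|end|>", "<|endoftext|>")):
        # Keep only stop-token ids the vocabulary actually defines.
        stop_ids = {tokenizer.eos_token_id}
        for tok in candidates:
            tid = tokenizer.convert_tokens_to_ids(tok)
            # Unknown tokens may map to None or to the unk id, depending on the tokenizer.
            if tid is not None and tid != tokenizer.unk_token_id:
                stop_ids.add(tid)
        stop_ids.discard(None)  # eos_token_id itself can be None for some tokenizers
        return stop_ids

    # Usage, mirroring the change above:
    #   self.EOS = resolve_stop_ids(self.tokenizer)

Returning a set also deduplicates ids, since eos_token_id typically aliases <|endoftext|> or <|im_end|> and would otherwise appear in the list twice.
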
@@ -345,7 +348,7 @@ def chat_stream(self, messages):
         first_end = time.time()
         full_word_tokens = []
         tok_num = 0
-        while(token != self.EOS and self.token_length < self.SEQLEN):
+        while token not in self.EOS and self.token_length < self.SEQLEN:
             full_word_tokens.append(token)
             word = self.tokenizer.decode(full_word_tokens)
             if "�" in word:
@@ -373,7 +376,7 @@ def chat_stream_for_api(self, params):
             return
         token = self.forward_first(tokens)
         full_word_tokens = []
-        while(token != self.EOS and self.token_length < self.SEQLEN):
+        while token not in self.EOS and self.token_length < self.SEQLEN:
             full_word_tokens.append(token)
             text = self.tokenizer.decode(full_word_tokens)
             if "�" in text:
@@ -397,7 +400,7 @@ def chat_for_api(self, params):
             return res_dict
         all_token = []
         token = self.forward_first(input_tokens)
-        while token != self.EOS and self.token_length < self.SEQLEN:
+        while token not in self.EOS and self.token_length < self.SEQLEN:
             all_token.append(token)
             token = self.forward_next()
         text = self.tokenizer.decode(all_token)
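
The three loop hunks are required, not cosmetic: once self.EOS is a list, the old scalar comparison token != self.EOS is always true (a Python int never equals a list), so generation would only stop when token_length reached SEQLEN. A toy sketch of the new stop condition in isolation, with stand-ins for forward_first/forward_next (the token ids and the fake token stream are assumptions for the demo, not values from this repo):

    SEQLEN = 8
    EOS = {151643, 151645}                   # assumed ids for <|endoftext|>, <|im_end|>

    stream = iter([11, 42, 7, 151645, 99])   # fake model output
    token = next(stream)                     # stands in for self.forward_first(tokens)
    token_length = 1
    out = []
    while token not in EOS and token_length < SEQLEN:
        out.append(token)
        token = next(stream)                 # stands in for self.forward_next()
        token_length += 1
    print(out)                               # [11, 42, 7] -- stops on <|im_end|>

Storing the ids in a set (or tuple) instead of a list would make the per-token membership test O(1) and drop duplicates, though with four entries the difference is negligible.
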
