@@ -20,7 +20,10 @@ def __init__(self, bmodel_path, dev_ids, tokenizer_path) -> None:
20
20
self .version = "1.1.0"
21
21
22
22
self .tokenizer = AutoTokenizer .from_pretrained (tokenizer_path , trust_remote_code = True )
23
- self .EOS = self .tokenizer .eos_token_id
23
+ ID_IM_END = self .tokenizer .convert_tokens_to_ids ("<|im_end|>" )
24
+ ID_END = self .tokenizer .convert_tokens_to_ids ("<|end|>" )
25
+ EOF = self .tokenizer .convert_tokens_to_ids ("<|endoftext|>" )
26
+ self .EOS = [self .tokenizer .eos_token_id , ID_IM_END , ID_END , EOF ]
24
27
self .dev_ids = [int (x ) for x in str (dev_ids ).split (',' )]
25
28
self .handles = {dev : sail .Handle (dev ) for dev in self .dev_ids }
26
29
self .target = sail .Handle (self .dev_ids [0 ]).get_target ()
@@ -345,7 +348,7 @@ def chat_stream(self, messages):
345
348
first_end = time .time ()
346
349
full_word_tokens = []
347
350
tok_num = 0
348
- while ( token != self .EOS and self .token_length < self .SEQLEN ) :
351
+ while token not in self .EOS and self .token_length < self .SEQLEN :
349
352
full_word_tokens .append (token )
350
353
word = self .tokenizer .decode (full_word_tokens )
351
354
if "�" in word :
@@ -373,7 +376,7 @@ def chat_stream_for_api(self, params):
373
376
return
374
377
token = self .forward_first (tokens )
375
378
full_word_tokens = []
376
- while ( token != self .EOS and self .token_length < self .SEQLEN ) :
379
+ while token not in self .EOS and self .token_length < self .SEQLEN :
377
380
full_word_tokens .append (token )
378
381
text = self .tokenizer .decode (full_word_tokens )
379
382
if "�" in text :
@@ -397,7 +400,7 @@ def chat_for_api(self, params):
397
400
return res_dict
398
401
all_token = []
399
402
token = self .forward_first (input_tokens )
400
- while token != self .EOS and self .token_length < self .SEQLEN :
403
+ while token not in self .EOS and self .token_length < self .SEQLEN :
401
404
all_token .append (token )
402
405
token = self .forward_next ()
403
406
text = self .tokenizer .decode (all_token )
0 commit comments