diff --git a/codegeex/megatron/tokenizer/tokenizer.py b/codegeex/megatron/tokenizer/tokenizer.py
index c92ac1f..da55e2a 100644
--- a/codegeex/megatron/tokenizer/tokenizer.py
+++ b/codegeex/megatron/tokenizer/tokenizer.py
@@ -15,6 +15,7 @@
 """CodeGeeX tokenizers."""
 
+import numpy as np
 from abc import ABC
 from abc import abstractmethod
 
 
@@ -22,44 +23,18 @@ from transformers import AutoTokenizer
 
 
-def encode_whitespaces(text, start_extra_id: int, max_len: int):
-    """Encode whitespaces to extra tokens in GPT-J.
+def encode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10):
+    """Encode whitespaces to extra tokens.
 
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
     """
+    for i in np.arange(max_len, 1, -1):
+        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
+    return text
+
-    def push_acc_space(acc_len: int, text: str):
-        if acc_len == 0:
-            return text
-        if acc_len == 1:
-            return text + " "
-        assert (
-            acc_len <= max_len
-        ), f"Max whitespace run length {max_len}, but found {acc_len}"
-        extra_id = start_extra_id - 2 + acc_len
-        extra_token = f"<|extratoken_{extra_id}|>"
-        return text + extra_token
-
-    acc_len = 0
-    res = ""
-    for ch in text:
-        if ch == " ":
-            acc_len += 1
-            if acc_len == max_len:
-                res = push_acc_space(acc_len, res)
-                acc_len = 0
-        else:
-            res = push_acc_space(acc_len, res)
-            acc_len = 0
-            res = res + ch
-
-    res = push_acc_space(acc_len, res)
-
-    return res
-
-
-def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
+def decode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10):
     """Decode the whitespace-encoded strings produced by encode_whitespace.
 
     >>> text = 'a\\n  b\\n   c'
diff --git a/codegeex/tokenizer/tokenizer.py b/codegeex/tokenizer/tokenizer.py
index 1b9fa28..d2f4f48 100644
--- a/codegeex/tokenizer/tokenizer.py
+++ b/codegeex/tokenizer/tokenizer.py
@@ -5,7 +5,7 @@
 
 
 def encode_whitespaces(text: str, start_extra_id: int, max_len: int):
-    """ Encode whitespaces to extra tokens in GPT-J.
+    """ Encode whitespaces to extra tokens.
 
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
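
For reference, here is a minimal, self-contained sketch of the replacement-based encoding added above, paired with a hypothetical inverse. The real `decode_whitespaces` body falls outside the hunks shown in this patch, so the decode below is an assumption about the inverse mapping, not the repository's implementation.

import numpy as np


def encode_whitespaces(text: str, start_extra_id: int = 10, max_len: int = 10) -> str:
    # Same idea as the new implementation above: replace longer space runs first,
    # so a run of N spaces (2 <= N <= max_len) maps to <|extratoken_{start_extra_id + N - 2}|>.
    for i in np.arange(max_len, 1, -1):
        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
    return text


def decode_whitespaces(text: str, start_extra_id: int = 10, max_len: int = 10) -> str:
    # Hypothetical inverse: map each extra token back to its whitespace run.
    for i in range(2, max_len + 1):
        text = text.replace(f"<|extratoken_{start_extra_id + i - 2}|>", " " * i)
    return text


encoded = encode_whitespaces("a\n  b\n   c")
assert encoded == "a\n<|extratoken_10|>b\n<|extratoken_11|>c"
assert decode_whitespaces(encoded) == "a\n  b\n   c"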