Update megatron tokenizer
Stanislas0 committed Mar 7, 2023
1 parent ebeb850 commit 2147a88
Showing 2 changed files with 9 additions and 34 deletions.
41 changes: 8 additions & 33 deletions codegeex/megatron/tokenizer/tokenizer.py
@@ -15,51 +15,26 @@

"""CodeGeeX tokenizers."""

import numpy as np
from abc import ABC
from abc import abstractmethod

from .gpt2_tokenization import GPT2Tokenizer
from transformers import AutoTokenizer


-def encode_whitespaces(text, start_extra_id: int, max_len: int):
-    """Encode whitespaces to extra tokens in GPT-J.
+def encode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10):
+    """Encode whitespaces to extra tokens.
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
     """
+    for i in np.arange(max_len, 1, -1):
+        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
+    return text


-    def push_acc_space(acc_len: int, text: str):
-        if acc_len == 0:
-            return text
-        if acc_len == 1:
-            return text + " "
-        assert (
-            acc_len <= max_len
-        ), f"Max whitespace run length {max_len}, but found {acc_len}"
-        extra_id = start_extra_id - 2 + acc_len
-        extra_token = f"<|extratoken_{extra_id}|>"
-        return text + extra_token
-
-    acc_len = 0
-    res = ""
-    for ch in text:
-        if ch == " ":
-            acc_len += 1
-            if acc_len == max_len:
-                res = push_acc_space(acc_len, res)
-                acc_len = 0
-        else:
-            res = push_acc_space(acc_len, res)
-            acc_len = 0
-            res = res + ch
-
-    res = push_acc_space(acc_len, res)
-
-    return res


-def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
+def decode_whitespaces(text: str, start_extra_id: int=10, max_len: int=10):
     """Decode the whitespace-encoded strings produced by encode_whitespace.
     >>> text = 'a\\n  b\\n   c'
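
For readers skimming the change: the new encoder maps each run of 2 to max_len spaces onto a single <|extratoken_N|> sentinel, replacing longer runs first. Below is a minimal, dependency-free sketch of the round trip (plain range instead of np.arange; the decoder body is collapsed in this diff, so the inverse here is an assumption based on the encoder's token layout, and the encode_ws/decode_ws names are illustrative, not the repository's):

def encode_ws(text: str, start_extra_id: int = 10, max_len: int = 10) -> str:
    # Replace the longest runs first so a long run is consumed whole
    # rather than being split by a shorter run's token.
    for i in range(max_len, 1, -1):
        text = text.replace(" " * i, f"<|extratoken_{start_extra_id + i - 2}|>")
    return text


def decode_ws(text: str, start_extra_id: int = 10, max_len: int = 10) -> str:
    # Assumed inverse: token id (start_extra_id - 2 + run_len) maps back
    # to run_len literal spaces.
    for run_len in range(2, max_len + 1):
        token = f"<|extratoken_{start_extra_id - 2 + run_len}|>"
        text = text.replace(token, " " * run_len)
    return text


original = "a\n  b\n   c"
encoded = encode_ws(original)
assert encoded == "a\n<|extratoken_10|>b\n<|extratoken_11|>c"
assert decode_ws(encoded) == original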
2 changes: 1 addition & 1 deletion codegeex/tokenizer/tokenizer.py
@@ -5,7 +5,7 @@


 def encode_whitespaces(text: str, start_extra_id: int, max_len: int):
-    """ Encode whitespaces to extra tokens in GPT-J.
+    """ Encode whitespaces to extra tokens.
     >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
     'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
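
A detail worth noting in both copies of encode_whitespaces: the replacement loop must walk run lengths from max_len down to 2. A quick hypothetical snippet (not from the commit) showing why the order matters:

text = "a\n   b"  # one run of three spaces

# Ascending order (wrong): the 2-space pattern matches inside the 3-space run,
# so the 3-space token never gets a chance to fire.
wrong = text.replace("  ", "<|extratoken_10|>").replace("   ", "<|extratoken_11|>")
# wrong == 'a\n<|extratoken_10|> b'

# Descending order (what the committed loop does): longest runs first.
right = text.replace("   ", "<|extratoken_11|>").replace("  ", "<|extratoken_10|>")
# right == 'a\n<|extratoken_11|>b'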
