Skip to content

Commit

Permalink
SKTBrain#66 fix left padding to right padding
Browse files Browse the repository at this point in the history
  • Loading branch information
haven-jeon committed Oct 28, 2021
1 parent 8be1c06 commit ed965cc
Showing 1 changed file with 69 additions and 33 deletions.
102 changes: 69 additions & 33 deletions kobert_hf/kobert_tokenizer/kobert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,27 @@


class KoBERTTokenizer(XLNetTokenizer):
def __init__(
self,
vocab_file,
do_lower_case=False,
remove_space=True,
keep_accents=False,
bos_token="[CLS]",
eos_token="[SEP]",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
additional_special_tokens=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs
) -> None:
padding_side = "right"

def __init__(self,
vocab_file,
do_lower_case=False,
remove_space=True,
keep_accents=False,
bos_token="[CLS]",
eos_token="[SEP]",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
additional_special_tokens=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
**kwargs) -> None:
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
mask_token = AddedToken(mask_token,
lstrip=True, rstrip=False) if isinstance(
mask_token, str) else mask_token

self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

Expand All @@ -57,10 +59,11 @@ def __init__(
sp_model_kwargs=self.sp_model_kwargs,
**kwargs,
)

def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLNet sequence has the following format:
Expand All @@ -83,12 +86,17 @@ def build_inputs_with_special_tokens(
def _tokenize(self, text: str) -> List[str]:
"""Tokenize a string."""
text = self.preprocess_text(text)
pieces = self.sp_model.encode(text, out_type=str, **self.sp_model_kwargs)
pieces = self.sp_model.encode(text,
out_type=str,
**self.sp_model_kwargs)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(piece) > 1 and piece[-1] == str(
",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(
SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][
0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
else:
Expand All @@ -100,10 +108,37 @@ def _tokenize(self, text: str) -> List[str]:

return new_pieces

def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. An XLNet sequence has the following format:
- single sequence: ``<cls> X <sep> ``
- pair of sequences: ``<cls> A <sep> B <sep>``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return cls + token_ids_0 + sep
return cls + token_ids_0 + sep + token_ids_1 + sep

def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
sequence pair mask has the following format:
Expand All @@ -124,8 +159,9 @@ def create_token_type_ids_from_sequences(
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 +
sep) * [1]

0 comments on commit ed965cc

Please sign in to comment.