Preserve spaces in GPT-2 tokenizers (huggingface#2778)
* Preserve spaces in GPT-2 tokenizers

Preserves spaces after special tokens in GPT-2 and inheriting (RoBERTa)
tokenizers, enabling correct BPE encoding. Automatically inserts a space
in front of the first token in the encode function when adding special tokens.

* Add tokenization preprocessing method

* Add framework argument to pipeline factory

Also fixes a pipeline test issue: each test input is now treated as a
distinct sequence.
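
A minimal sketch of the behavior this change targets (assuming a transformers
checkout at this commit and the "gpt2" vocabulary files; the tokens shown in
comments are indicative, not captured output):

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Without a prefix space, the first word is encoded as a word-initial token.
    print(tokenizer.tokenize("Hello world"))                         # e.g. ['Hello', 'Ġworld']

    # With add_prefix_space=True, a space is prepended before BPE runs, so the
    # first word is encoded the same way it would be mid-sentence.
    print(tokenizer.tokenize("Hello world", add_prefix_space=True))  # e.g. ['ĠHello', 'Ġworld']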
joeddav authored Feb 13, 2020
1 parent 0ed630f commit f1e8a51
Showing 7 changed files with 141 additions and 41 deletions.
3 changes: 2 additions & 1 deletion src/transformers/pipelines.py
@@ -1001,6 +1001,7 @@ def pipeline(
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
modelcard: Optional[Union[str, ModelCard]] = None,
framework: Optional[str] = None,
**kwargs
) -> Pipeline:
"""
@@ -1021,7 +1022,7 @@
if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

framework = get_framework(model)
framework = framework or get_framework(model)

targeted_task = SUPPORTED_TASKS[task]
task, model_class = targeted_task["impl"], targeted_task[framework]
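
A usage sketch for the new argument (the model name is illustrative and assumes
a TensorFlow build of that checkpoint is available):

    from transformers import pipeline

    # The framework was previously always inferred from the model; it can now be
    # forced explicitly, e.g. to run the TensorFlow implementation of a task.
    nlp = pipeline(
        task="sentiment-analysis",
        model="distilbert-base-uncased-finetuned-sst-2-english",
        framework="tf",
    )
    print(nlp("HuggingFace is solving NLP one commit at a time."))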
16 changes: 7 additions & 9 deletions src/transformers/tokenization_gpt2.py
@@ -191,15 +191,8 @@ def bpe(self, token):
self.cache[token] = word
return word

def _tokenize(self, text, add_prefix_space=False):
""" Tokenize a string.
Args:
- add_prefix_space (boolean, default False):
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
"""
if add_prefix_space:
text = " " + text

def _tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
@@ -248,6 +241,11 @@ def save_vocabulary(self, save_directory):

return vocab_file, merge_file

def prepare_for_tokenization(self, text, **kwargs):
if "add_prefix_space" in kwargs and kwargs["add_prefix_space"]:
return " " + text
return text


class GPT2TokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
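
A sketch of where the relocated option now takes effect (assuming the "gpt2"
vocabulary files; outputs in comments are indicative):

    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Prefix-space handling has moved out of _tokenize and into the
    # prepare_for_tokenization hook, which runs once on the full input string
    # before any splitting on added/special tokens.
    print(tokenizer.prepare_for_tokenization("Hello", add_prefix_space=True))  # ' Hello'
    print(tokenizer.prepare_for_tokenization("Hello"))                         # 'Hello'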
9 changes: 9 additions & 0 deletions src/transformers/tokenization_roberta.py
@@ -154,3 +154,12 @@ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
if "add_prefix_space" in kwargs:
add_prefix_space = kwargs["add_prefix_space"]
else:
add_prefix_space = add_special_tokens
if add_prefix_space and not text[0].isspace():
text = " " + text
return text
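
A sketch of the RoBERTa default introduced here (assuming the "roberta-base"
vocabulary files; outputs in comments are indicative):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # With special tokens, a prefix space is added automatically unless the
    # text already starts with whitespace.
    print(tokenizer.prepare_for_tokenization("Hello", add_special_tokens=True))  # ' Hello'
    # Without special tokens, the text is left untouched by default.
    print(tokenizer.prepare_for_tokenization("Hello"))                           # 'Hello'
    # An explicit add_prefix_space overrides the default in either direction.
    print(tokenizer.prepare_for_tokenization("Hello", add_prefix_space=True))    # ' Hello'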
22 changes: 17 additions & 5 deletions src/transformers/tokenization_utils.py
@@ -662,9 +662,12 @@ def tokenize(self, text, **kwargs):
Take care of added tokens.
text: The sequence to be encoded.
**kwargs: passed to the child `self.tokenize()` method
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
**kwargs: passed to the `prepare_for_tokenization` preprocessing method.
"""
all_special_tokens = self.all_special_tokens
text = self.prepare_for_tokenization(text, **kwargs)

def lowercase_text(t):
# convert non-special tokens to lowercase
@@ -679,7 +682,7 @@ def split_on_token(tok, text):
result = []
split_text = text.split(tok)
for i, sub_text in enumerate(split_text):
sub_text = sub_text.strip()
sub_text = sub_text.rstrip()
if i == 0 and not sub_text:
result += [tok]
elif i == len(split_text) - 1:
@@ -697,7 +700,7 @@ def split_on_tokens(tok_list, text):
if not text.strip():
return []
if not tok_list:
return self._tokenize(text, **kwargs)
return self._tokenize(text)

tokenized_text = []
text_list = [text]
@@ -713,7 +716,7 @@ def split_on_tokens(tok_list, text):
return list(
itertools.chain.from_iterable(
(
self._tokenize(token, **kwargs) if token not in self.unique_added_tokens_encoder else [token]
self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
for token in tokenized_text
)
)
@@ -802,6 +805,8 @@ def encode(
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
**kwargs: passed to the `self.tokenize()` method
"""
encoded_inputs = self.encode_plus(
@@ -865,6 +870,8 @@ def encode_plus(
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
@@ -895,7 +902,8 @@ def encode_plus(

def get_input_ids(text):
if isinstance(text, str):
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
return self.convert_tokens_to_ids(text)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
@@ -1215,6 +1223,10 @@ def prepare_for_model(

return encoded_inputs

def prepare_for_tokenization(self, text, **kwargs):
""" Performs any necessary transformations before tokenization """
return text

def truncate_sequences(
self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
):
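
At the encode/encode_plus level the new kwarg simply flows through tokenize()
into prepare_for_tokenization. A sketch (assuming the "roberta-base" vocabulary
files; the resulting ids depend on the vocabulary and are not shown):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # For RoBERTa, add_special_tokens=True implies a prefix space by default...
    with_prefix = tokenizer.encode("Hello world", add_special_tokens=True)
    # ...but it can be switched off explicitly, which several tests below rely on.
    without_prefix = tokenizer.encode("Hello world", add_special_tokens=True, add_prefix_space=False)

    # The first word id differs between the two encodings ('ĠHello' vs 'Hello').
    print(with_prefix != without_prefix)  # expected: True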
40 changes: 24 additions & 16 deletions tests/test_pipelines.py
@@ -94,7 +94,7 @@ def _test_mono_column_pipeline(
for key in output_keys:
self.assertIn(key, mono_result[0])

multi_result = nlp(valid_inputs)
multi_result = [nlp(input) for input in valid_inputs]
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], (dict, list))

@@ -129,7 +129,7 @@ def test_tf_ner(self):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

@require_torch
@@ -147,7 +147,7 @@ def test_tf_sentiment_analysis(self):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

@require_torch
@@ -163,7 +163,7 @@ def test_tf_feature_extraction(self):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})

@require_torch
@@ -176,14 +176,18 @@ def test_fill_mask(self):
invalid_inputs = [None]
expected_multi_result = [
[
{"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
{"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
],
[
{"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
{
"score": 0.19764970242977142,
"sequence": "<s>The largest city in France is Lyon</s>",
"sequence": "<s> The largest city in France is Paris</s>",
"score": 0.3185044229030609,
"token": 2201,
},
{
"sequence": "<s> The largest city in France is Lyon</s>",
"score": 0.21112334728240967,
"token": 12790,
},
],
@@ -209,20 +213,24 @@ def test_tf_fill_mask(self):
invalid_inputs = [None]
expected_multi_result = [
[
{"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
{"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
],
[
{"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
{
"score": 0.19764970242977142,
"sequence": "<s>The largest city in France is Lyon</s>",
"sequence": "<s> The largest city in France is Paris</s>",
"score": 0.3185044229030609,
"token": 2201,
},
{
"sequence": "<s> The largest city in France is Lyon</s>",
"score": 0.21112334728240967,
"token": 12790,
},
],
]
for tokenizer, model, config in TF_FILL_MASK_FINETUNED_MODELS:
nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, topk=2)
nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, framework="tf", topk=2)
self._test_mono_column_pipeline(
nlp,
valid_inputs,
@@ -293,5 +301,5 @@ def test_tf_question_answering(self):
]

for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
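
A usage sketch matching the updated fill-mask expectations (the checkpoint name
is an assumption; the scores and token ids in the tests above are specific to
the checkpoints those tests use):

    from transformers import pipeline

    nlp = pipeline(task="fill-mask", model="distilroberta-base", topk=2)
    for result in nlp("My name is <mask>"):
        # Decoded sequences now keep the prefix space, e.g.
        # "<s> My name is John</s>" rather than "<s>My name is John</s>".
        print(result["sequence"], result["score"], result["token"])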
54 changes: 44 additions & 10 deletions tests/test_tokenization_common.py
@@ -204,7 +204,7 @@ def test_add_special_tokens(self):
encoded = tokenizer.encode(text, add_special_tokens=False)

input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
output_encoded = tokenizer.encode(" " + output_text, add_special_tokens=False)
special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
assert encoded == input_encoded + special_token_id + output_encoded

@@ -264,7 +264,7 @@ def test_number_of_added_tokens(self):
seq_1 = "With these inputs."

sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)

# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
@@ -280,7 +280,12 @@ def test_maximum_encoding_length_single_input(self):
num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(
seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride, return_overflowing_tokens=True,
seq_0,
max_length=total_length - 2,
add_special_tokens=True,
stride=stride,
return_overflowing_tokens=True,
add_prefix_space=False,
)

truncated_sequence = information["input_ids"]
@@ -301,7 +306,7 @@ def test_maximum_encoding_length_pair_input(self):
sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)

sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
tokenizer.encode(seq_0, add_special_tokens=False), tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
)
@@ -314,6 +319,7 @@
stride=stride,
truncation_strategy="only_second",
return_overflowing_tokens=True,
add_prefix_space=False,
)
information_first_truncated = tokenizer.encode_plus(
seq_0,
@@ -323,6 +329,7 @@
stride=stride,
truncation_strategy="only_first",
return_overflowing_tokens=True,
add_prefix_space=False,
)

truncated_sequence = information["input_ids"]
@@ -342,11 +349,39 @@ def test_encode_input_type(self):

tokens = tokenizer.tokenize(sequence)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True, add_prefix_space=False)

self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)

def test_swap_special_token(self):
tokenizer = self.get_tokenizer()

mask = "<mask>"
sequence = "Encode this sequence"
sequence_masked_0 = "Encode <mask> sequence"
sequence_masked_1 = "<mask> this sequence"

# Add tokens so that masked token isn't split
tokenizer.add_tokens(sequence.split())
tokenizer.add_special_tokens({"mask_token": mask})
mask_ind = tokenizer.convert_tokens_to_ids(mask)
encoded = tokenizer.encode(sequence, add_special_tokens=False)

# Test first masked sequence
encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
mask_loc = encoded_masked.index(mask_ind)
encoded_masked[mask_loc] = encoded[mask_loc]

self.assertEqual(encoded_masked, encoded)

# Test second masked sequence
encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
mask_loc = encoded_masked.index(mask_ind)
encoded_masked[mask_loc] = encoded[mask_loc]

self.assertEqual(encoded_masked, encoded)

def test_special_tokens_mask(self):
tokenizer = self.get_tokenizer()

@@ -356,7 +391,7 @@ def test_special_tokens_mask(self):
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, add_special_tokens=True, return_special_tokens_mask=True
sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
@@ -369,11 +404,10 @@
self.assertEqual(encoded_sequence, filtered_sequence)

# Testing inputs pairs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
sequence_1, add_special_tokens=False
)
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]