Minor Bugfixes: Preprocessor now also works with models that do not use token_type_ids, PVPs now also work with empty inputs.
timoschick committed Jun 9, 2020
1 parent 4fd60b7 commit dc6cb21
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions preprocessor.py
@@ -66,14 +66,17 @@ def get_input_features(self, example: InputExample, **kwargs) -> InputFeatures:
             add_special_tokens=True,
             max_length=self.wrapper.config.max_seq_length,
         )
-        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
+        input_ids, token_type_ids = inputs["input_ids"], inputs.get("token_type_ids")

         attention_mask = [1] * len(input_ids)
         padding_length = self.wrapper.config.max_seq_length - len(input_ids)

         input_ids = input_ids + ([self.wrapper.tokenizer.pad_token_id] * padding_length)
         attention_mask = attention_mask + ([0] * padding_length)
-        token_type_ids = token_type_ids + ([0] * padding_length)
+        if not token_type_ids:
+            token_type_ids = [0] * self.wrapper.config.max_seq_length
+        else:
+            token_type_ids = token_type_ids + ([0] * padding_length)
         mlm_labels = [-1] * len(input_ids)

         assert len(input_ids) == self.wrapper.config.max_seq_length
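The preprocessor.py change has two parts: inputs.get("token_type_ids") yields None instead of raising a KeyError for tokenizers that emit no segment ids (such as DistilBERT's), and the padding step then falls back to a full-length zero vector. A minimal sketch of that fallback, with illustrative values (the max_seq_length and ids below are hypothetical, not taken from the repo):

# A minimal sketch of the padding fallback above; values are illustrative.
max_seq_length = 8
input_ids = [101, 2023, 2003, 102]  # unpadded ids from a tokenizer
token_type_ids = None               # what inputs.get("token_type_ids") yields
                                    # for a model without segment embeddings

padding_length = max_seq_length - len(input_ids)
if not token_type_ids:
    # no segment ids at all: use a full-length zero vector
    token_type_ids = [0] * max_seq_length
else:
    # segment ids exist: pad them to max_seq_length with zeros
    token_type_ids = token_type_ids + ([0] * padding_length)

assert len(token_type_ids) == max_seq_length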
4 changes: 2 additions & 2 deletions pvp.py
@@ -79,11 +79,11 @@ def encode(self, example: InputExample) -> Tuple[List[int], List[int]]:
         kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}

         parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
-        parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a]
+        parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]

         if parts_b:
             parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
-            parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b]
+            parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]

         self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)
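The pvp.py change appends "if x" to both list comprehensions so that empty text parts (for instance, an example whose text_b is an empty string) are dropped before encoding rather than being encoded into empty token lists. A minimal sketch with a stand-in encoder (encode below is hypothetical, not the repo's tokenizer.encode):

# A minimal sketch of the "if x" filter above; encode() stands in for
# tokenizer.encode(x, add_special_tokens=False, **kwargs).
def encode(text):
    return [len(tok) for tok in text.split()]  # dummy token ids

parts_a = ["An example premise.", "", "Is it entailed?"]  # "" mimics an empty field
parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
parts_a = [(encode(x), s) for x, s in parts_a if x]  # the empty part is dropped
print(parts_a)  # two encoded parts; "" would have become ([], False)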
