Skip to content

Commit

Permalink
Use DataCollatorForSeq2Seq in run_summarization in all cases (huggingface#10856)
Browse files Browse the repository at this point in the history

Co-authored-by: Eliza <[email protected]>
  • Loading branch information
elsanns and Eliza authored Mar 22, 2021
1 parent a8d4d67 commit 9f8fa4e
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions examples/seq2seq/run_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
HfArgumentParser,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
set_seed,
)
from transformers.file_utils import is_offline_mode
Expand Down Expand Up @@ -466,15 +465,12 @@ def preprocess_function(examples):

# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if data_args.pad_to_max_length:
data_collator = default_data_collator
else:
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if training_args.fp16 else None,
)

# Metric
metric = load_metric("rouge")
Expand Down

0 comments on commit 9f8fa4e

Please sign in to comment.