Skip to content

Commit

Permalink
[examples/translation] support mBART-50 and M2M100 fine-tuning (huggi…
Browse files Browse the repository at this point in the history
…ngface#11170)

* keep a list of multilingual tokenizers

* add forced_bos_token argument
  • Loading branch information
patil-suraj authored Apr 9, 2021
1 parent fb41f9f commit c161dd5
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions examples/seq2seq/run_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser,
M2M100Tokenizer,
MBart50Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Expand All @@ -50,6 +53,9 @@

logger = logging.getLogger(__name__)

# A list of all multilingual tokenizers which require src_lang and tgt_lang attributes.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast, M2M100Tokenizer]


@dataclass
class ModelArguments:
Expand Down Expand Up @@ -191,6 +197,14 @@ class DataTrainingArguments:
source_prefix: Optional[str] = field(
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
)
# Optional token string for the first generated token; it defaults to None.
# The help text is built from adjacent string literals, so each piece must carry its own
# separating space or the rendered --help output runs sentences together.
forced_bos_token: Optional[str] = field(
    default=None,
    metadata={
        "help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`. "
        "Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
        "needs to be the target language token. (Usually it is the target language token)"
    },
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
Expand Down Expand Up @@ -325,9 +339,6 @@ def main():

# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
assert (
data_args.target_lang is not None and data_args.source_lang is not None
), "mBart requires --target_lang and --source_lang"
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
Expand All @@ -352,11 +363,21 @@ def main():

# For translation we set the codes of our source and target languages (only useful for mBART, the others will
# ignore those attributes).
if isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
if data_args.source_lang is not None:
tokenizer.src_lang = data_args.source_lang
if data_args.target_lang is not None:
tokenizer.tgt_lang = data_args.target_lang
if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
assert data_args.target_lang is not None and data_args.source_lang is not None, (
f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --source_lang and "
"--target_lang arguments."
)

tokenizer.src_lang = data_args.source_lang
tokenizer.tgt_lang = data_args.target_lang

# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
# When the argument is not given, the id is None.
forced_bos_token_id = (
    tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
)
# The attribute must be spelled `forced_bos_token_id`; assigning to a misspelled name silently
# creates an unrelated attribute on the config and the forced token would never be applied.
model.config.forced_bos_token_id = forced_bos_token_id

# Get the language codes for input/target.
source_lang = data_args.source_lang.split("_")[0]
Expand Down

0 comments on commit c161dd5

Please sign in to comment.