[ProphetNet] Correct Doc string example (huggingface#7944)
* correct xlm prophetnet auto model and examples

* fix line-break docs
patrickvonplaten authored Oct 21, 2020
1 parent e174bfe commit 9b6610f
Showing 3 changed files with 27 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/transformers/modeling_auto.py
@@ -335,7 +335,7 @@
  (CTRLConfig, CTRLLMHeadModel),
  (ReformerConfig, ReformerModelWithLMHead),
  (BertGenerationConfig, BertGenerationDecoder),
- (ProphetNetConfig, XLMProphetNetForCausalLM),
+ (XLMProphetNetConfig, XLMProphetNetForCausalLM),
  (ProphetNetConfig, ProphetNetForCausalLM),
  ]
)
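Why this one-line change matters: before the fix, both entries in the causal-LM auto-mapping used ProphetNetConfig as their key, and a Python OrderedDict keeps only the last value for a duplicate key, so XLMProphetNetForCausalLM was effectively unreachable. After the fix, each config maps to its own model; and since the mapping is consulted in order, the subclass config has to precede its parent. A minimal sketch with stub classes (the resolve() walk is illustrative, not the library's exact internals):

# Illustrative sketch of the mapping fix (class stubs, not the real library):
from collections import OrderedDict

class ProphetNetConfig: ...
class XLMProphetNetConfig(ProphetNetConfig): ...   # subclasses the base config, as in transformers

class ProphetNetForCausalLM: ...
class XLMProphetNetForCausalLM: ...

# Before the fix both entries were keyed on ProphetNetConfig; the duplicate
# key means the XLM model silently disappears from the mapping:
broken = OrderedDict([
    (ProphetNetConfig, XLMProphetNetForCausalLM),
    (ProphetNetConfig, ProphetNetForCausalLM),
])
assert broken[ProphetNetConfig] is ProphetNetForCausalLM  # XLM entry lost

# After the fix each config has its own key; with an ordered, isinstance-style
# walk the more specific subclass entry must be listed first:
fixed = OrderedDict([
    (XLMProphetNetConfig, XLMProphetNetForCausalLM),
    (ProphetNetConfig, ProphetNetForCausalLM),
])

def resolve(config, mapping):
    # Picks the first entry whose config class matches, so subclasses
    # would be shadowed by their parent if listed after it.
    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class
    raise ValueError("no causal-LM head registered for this config")

assert resolve(XLMProphetNetConfig(), fixed) is XLMProphetNetForCausalLM
assert resolve(ProphetNetConfig(), fixed) is ProphetNetForCausalLM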
19 changes: 13 additions & 6 deletions src/transformers/modeling_prophetnet.py
@@ -1931,14 +1931,21 @@ def forward(
  >>> logits = outputs.logits
  >>> # Model can also be used with EncoderDecoder framework
- >>> from transformers import BertTokenizer, EncoderDecoderModel
+ >>> from transformers import BertTokenizer, EncoderDecoderModel, ProphetNetTokenizer
  >>> import torch
- >>> tokenizer = BertTokenizer.from_pretrained('bert-uncased-large')
- >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-uncased-large", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> outputs = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"], return_dict=True)
+ >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
+ >>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
+ >>> ARTICLE = (
+ ...     "the us state department said wednesday it had received no "
+ ...     "formal word from bolivia that it was expelling the us ambassador there "
+ ...     "but said the charges made against him are `` baseless ."
+ ... )
+ >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
+ >>> labels = tokenizer_dec("us rejects charges against its ambassador in bolivia", return_tensors="pt").input_ids
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:], return_dict=True)
  >>> loss = outputs.loss
  """
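Beyond fixing the misspelled checkpoint name ('bert-uncased-large' → 'bert-large-uncased'), the rewritten example switches to proper sequence-to-sequence teacher forcing: decoder_input_ids=labels[:, :-1] feeds the decoder summary tokens 0..n-2, while labels=labels[:, 1:] scores its predictions against tokens 1..n-1. A toy sketch of that one-token shift (tensor values are made up, not real token ids):

import torch

# Stand-in ids for a tokenized target summary (made up for illustration).
labels = torch.tensor([[5, 12, 7, 93, 41, 2]])

decoder_input_ids = labels[:, :-1]  # what the decoder reads at each step
target_ids = labels[:, 1:]          # what it is asked to predict at each step

# Position i of the input is scored against position i of the target,
# so every decoder step learns to predict the *next* summary token.
assert decoder_input_ids.shape == target_ids.shape
print(decoder_input_ids.tolist())  # [[5, 12, 7, 93, 41]]
print(target_ids.tolist())         # [[12, 7, 93, 41, 2]]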
19 changes: 13 additions & 6 deletions src/transformers/modeling_xlm_prophetnet.py
@@ -144,14 +144,21 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
  >>> logits = outputs.logits
  >>> # Model can also be used with EncoderDecoder framework
- >>> from transformers import BertTokenizer, EncoderDecoderModel
+ >>> from transformers import EncoderDecoderModel, XLMProphetNetTokenizer, XLMRobertaTokenizer
  >>> import torch
- >>> tokenizer = BertTokenizer.from_pretrained('bert-uncased-large')
- >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-uncased-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
- >>> outputs = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"])
+ >>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
+ >>> ARTICLE = (
+ ...     "the us state department said wednesday it had received no "
+ ...     "formal word from bolivia that it was expelling the us ambassador there "
+ ...     "but said the charges made against him are `` baseless ."
+ ... )
+ >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
+ >>> labels = tokenizer_dec("us rejects charges against its ambassador in bolivia", return_tensors="pt").input_ids
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:], return_dict=True)
  >>> loss = outputs.loss
  """
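Once such an encoder-decoder pair has been fine-tuned, summaries come out of generate(). A hedged sketch reusing the checkpoints from the example above; setting decoder_start_token_id (and which special-token id to use) is an assumption on my part, not something this commit configures, so verify it against your checkpoint:

from transformers import EncoderDecoderModel, XLMProphetNetTokenizer, XLMRobertaTokenizer

tokenizer_enc = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large")
tokenizer_dec = XLMProphetNetTokenizer.from_pretrained("microsoft/xprophetnet-large-wiki100-cased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased"
)

# Assumption: generation needs an explicit decoder start token; here we borrow
# the decoder tokenizer's separator id, which may not be right for every setup.
model.config.decoder_start_token_id = tokenizer_dec.sep_token_id
model.config.pad_token_id = tokenizer_dec.pad_token_id

article = ("the us state department said wednesday it had received no formal word "
           "from bolivia that it was expelling the us ambassador there")
input_ids = tokenizer_enc(article, return_tensors="pt").input_ids
summary_ids = model.generate(input_ids, max_length=32, num_beams=4)
print(tokenizer_dec.decode(summary_ids[0], skip_special_tokens=True))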
