From 9c17256447b91cf8483c856cb15e95ed30ace538 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 29 May 2020 13:46:08 +0200 Subject: [PATCH] [Longformer] Multiple choice for longformer (#4645) * add multiple choice for longformer * add models to docs * adapt docstring * add test to longformer * add longformer for mc in init and modeling auto * fix tests --- docs/source/model_doc/albert.rst | 14 +++ docs/source/model_doc/longformer.rst | 15 ++++ docs/source/model_doc/roberta.rst | 7 ++ src/transformers/__init__.py | 1 + src/transformers/modeling_auto.py | 2 + src/transformers/modeling_bert.py | 24 ++--- src/transformers/modeling_longformer.py | 115 ++++++++++++++++++++++-- src/transformers/modeling_roberta.py | 18 ++-- src/transformers/modeling_tf_albert.py | 20 ++--- src/transformers/modeling_tf_bert.py | 24 ++--- src/transformers/modeling_xlnet.py | 22 ++--- tests/test_modeling_longformer.py | 28 ++++++ 12 files changed, 227 insertions(+), 63 deletions(-) diff --git a/docs/source/model_doc/albert.rst b/docs/source/model_doc/albert.rst index 8b78a336b54e..057187e3d06a 100644 --- a/docs/source/model_doc/albert.rst +++ b/docs/source/model_doc/albert.rst @@ -94,3 +94,17 @@ TFAlbertForSequenceClassification .. autoclass:: transformers.TFAlbertForSequenceClassification :members: + + +TFAlbertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFAlbertForMultipleChoice + :members: + + +TFAlbertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFAlbertForQuestionAnswering + :members: diff --git a/docs/source/model_doc/longformer.rst b/docs/source/model_doc/longformer.rst index 7e8e816410ec..07d0898ccf3b 100644 --- a/docs/source/model_doc/longformer.rst +++ b/docs/source/model_doc/longformer.rst @@ -74,3 +74,18 @@ LongformerForQuestionAnswering .. autoclass:: transformers.LongformerForQuestionAnswering :members: + + +LongformerForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LongformerForMultipleChoice + :members: + + +LongformerForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.LongformerForTokenClassification + :members: + diff --git a/docs/source/model_doc/roberta.rst b/docs/source/model_doc/roberta.rst index 07e511228a86..31b39998160c 100644 --- a/docs/source/model_doc/roberta.rst +++ b/docs/source/model_doc/roberta.rst @@ -74,6 +74,13 @@ RobertaForSequenceClassification :members: +RobertaForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.RobertaForMultipleChoice + :members: + + RobertaForTokenClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 33907741f96c..6c392478bd44 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -326,6 +326,7 @@ LongformerModel, LongformerForMaskedLM, LongformerForSequenceClassification, + LongformerForMultipleChoice, LongformerForTokenClassification, LongformerForQuestionAnswering, LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index cc8604f560b3..11a8281963fd 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -104,6 +104,7 @@ from .modeling_longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerForMaskedLM, + LongformerForMultipleChoice, LongformerForQuestionAnswering, LongformerForSequenceClassification, LongformerForTokenClassification, @@ -297,6 +298,7 @@ [ (CamembertConfig, CamembertForMultipleChoice), (XLMRobertaConfig, XLMRobertaForMultipleChoice), + (LongformerConfig, LongformerForMultipleChoice), (RobertaConfig, RobertaForMultipleChoice), (BertConfig, BertForMultipleChoice), (XLNetConfig, XLNetForMultipleChoice), diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 1e31b5c402a5..29aea1b0a687 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -543,7 +543,7 @@ def _init_weights(self, module): BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -551,19 +551,19 @@ def _init_weights(self, module): :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -632,7 +632,7 @@ def _prune_heads(self, heads_to_prune): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -759,7 +759,7 @@ def __init__(self, config): def get_output_embeddings(self): return self.cls.predictions.decoder - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -859,7 +859,7 @@ def __init__(self, config): def get_output_embeddings(self): return self.cls.predictions.decoder - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -992,7 +992,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1081,7 +1081,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1177,7 +1177,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -1278,7 +1278,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1375,7 +1375,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 70e5fbf903bf..5baf056a10f7 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -411,7 +411,7 @@ def forward( LONGFORMER_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.LonmgformerTokenizer`. @@ -419,7 +419,7 @@ def forward( :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens). Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for task-specific finetuning because it makes the model more flexible at representing the task. For example, @@ -431,13 +431,13 @@ def forward( ``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -537,7 +537,7 @@ def _pad_to_window_size( return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -641,7 +641,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -729,7 +729,7 @@ def __init__(self, config): self.longformer = LongformerModel(config) self.classifier = LongformerClassificationHead(config) - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -866,7 +866,7 @@ def _get_question_end_index(self, input_ids): return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids, @@ -993,7 +993,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1070,3 +1070,100 @@ def forward( outputs = (loss,) + outputs return outputs # (loss), scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Longformer Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + LONGFORMER_START_DOCSTRING, +) +class LongformerForMultipleChoice(BertPreTrainedModel): + config_class = LongformerConfig + pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + + self.longformer = LongformerModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + self.init_weights() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + def forward( + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + inputs_embeds=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + loss (:obj:`torch.FloatTensor`` of shape ``(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import LongformerTokenizer, LongformerForTokenClassification + import torch + + tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096') + model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096') + choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] + input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, classification_scores = outputs[:2] + + """ + num_choices = input_ids.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + outputs = self.longformer( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + outputs = (loss,) + outputs + + return outputs # (loss), reshaped_logits, (hidden_states), (attentions) diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 9e1460c83009..2d085e3a8aec 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -95,7 +95,7 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): ROBERTA_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.RobertaTokenizer`. @@ -103,19 +103,19 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -175,7 +175,7 @@ def __init__(self, config): def get_output_embeddings(self): return self.lm_head.decoder - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -286,7 +286,7 @@ def __init__(self, config): self.roberta = RobertaModel(config) self.classifier = RobertaClassificationHead(config) - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -379,7 +379,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -479,7 +479,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -598,7 +598,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids, diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 186b0ae3288a..da7c3d458f7f 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -628,7 +628,7 @@ def call( ALBERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.AlbertTokenizer`. @@ -636,19 +636,19 @@ def call( :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional, defaults to :obj:`None`): + attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -676,7 +676,7 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.albert = TFAlbertMainLayer(config, name="albert") - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -734,7 +734,7 @@ def __init__(self, config, *inputs, **kwargs): def get_output_embeddings(self): return self.albert.embeddings - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -795,7 +795,7 @@ def __init__(self, config, *inputs, **kwargs): def get_output_embeddings(self): return self.albert.embeddings - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -852,7 +852,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -908,7 +908,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -983,7 +983,7 @@ def dummy_inputs(self): """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} - @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def call( self, inputs, diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py index b2dd660f995b..48ad5656c790 100644 --- a/src/transformers/modeling_tf_bert.py +++ b/src/transformers/modeling_tf_bert.py @@ -621,7 +621,7 @@ class TFBertPreTrainedModel(TFPreTrainedModel): BERT_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -629,19 +629,19 @@ class TFBertPreTrainedModel(TFPreTrainedModel): :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`__ - position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. @@ -669,7 +669,7 @@ def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.bert = TFBertMainLayer(config, name="bert") - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Returns: @@ -726,7 +726,7 @@ def __init__(self, config, *inputs, **kwargs): def get_output_embeddings(self): return self.bert.embeddings - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -782,7 +782,7 @@ def __init__(self, config, *inputs, **kwargs): def get_output_embeddings(self): return self.bert.embeddings - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -832,7 +832,7 @@ def __init__(self, config, *inputs, **kwargs): self.bert = TFBertMainLayer(config, name="bert") self.nsp = TFBertNSPHead(config, name="nsp___cls") - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -888,7 +888,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -954,7 +954,7 @@ def dummy_inputs(self): """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def call( self, inputs, @@ -1065,7 +1065,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: @@ -1122,7 +1122,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) - @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def call(self, inputs, **kwargs): r""" Return: diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 5aeb69fca0e7..9dfbae4f6f9a 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -506,7 +506,7 @@ def _init_weights(self, module): XLNET_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`transformers.BertTokenizer`. @@ -514,7 +514,7 @@ def _init_weights(self, module): :func:`transformers.PreTrainedTokenizer.encode_plus` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. @@ -535,13 +535,13 @@ def _init_weights(self, module): Mask to indicate the output tokens to use. If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token. Only used during pretraining for partial prediction or for sequential decoding (generation). - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token. The classifier token should be represented by a ``2``. `What are token type IDs? <../glossary.html#token-type-ids>`_ - input_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): + input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding. Kept for compatibility with the original code base. @@ -688,7 +688,7 @@ def relative_positional_encoding(self, qlen, klen, bsz=None): pos_emb = pos_emb.to(self.device) return pos_emb - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -971,7 +971,7 @@ def prepare_inputs_for_generation(self, input_ids, past, **kwargs): return inputs - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1091,7 +1091,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1196,7 +1196,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1305,7 +1305,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) def forward( self, input_ids=None, @@ -1418,7 +1418,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, @@ -1544,7 +1544,7 @@ def __init__(self, config): self.init_weights() - @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) + @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids=None, diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index dbf5e39ab363..00a67d716f8d 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -32,6 +32,7 @@ LongformerForSequenceClassification, LongformerForTokenClassification, LongformerForQuestionAnswering, + LongformerForMultipleChoice, ) @@ -228,6 +229,29 @@ def create_and_check_longformer_for_token_classification( self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]) self.check_loss_output(result) + def create_and_check_longformer_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = LongformerForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + loss, logits = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -298,6 +322,10 @@ def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs) + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs) + class LongformerModelIntegrationTest(unittest.TestCase): @slow