[Longformer] Multiple choice for longformer (huggingface#4645)
* add multiple choice for longformer

* add models to docs

* adapt docstring

* add test to longformer

* add longformer for mc in init and modeling auto

* fix tests
patrickvonplaten authored May 29, 2020
1 parent 91487cb commit 9c17256
Showing 12 changed files with 227 additions and 63 deletions.
14 changes: 14 additions & 0 deletions docs/source/model_doc/albert.rst
@@ -94,3 +94,17 @@ TFAlbertForSequenceClassification

.. autoclass:: transformers.TFAlbertForSequenceClassification
:members:


TFAlbertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAlbertForMultipleChoice
:members:


TFAlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFAlbertForQuestionAnswering
:members:
15 changes: 15 additions & 0 deletions docs/source/model_doc/longformer.rst
@@ -74,3 +74,18 @@ LongformerForQuestionAnswering

.. autoclass:: transformers.LongformerForQuestionAnswering
:members:


LongformerForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LongformerForMultipleChoice
:members:


LongformerForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.LongformerForTokenClassification
:members:

7 changes: 7 additions & 0 deletions docs/source/model_doc/roberta.rst
@@ -74,6 +74,13 @@ RobertaForSequenceClassification
:members:


RobertaForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.RobertaForMultipleChoice
:members:


RobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -326,6 +326,7 @@
LongformerModel,
LongformerForMaskedLM,
LongformerForSequenceClassification,
LongformerForMultipleChoice,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
2 changes: 2 additions & 0 deletions src/transformers/modeling_auto.py
@@ -104,6 +104,7 @@
from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
LongformerForMaskedLM,
LongformerForMultipleChoice,
LongformerForQuestionAnswering,
LongformerForSequenceClassification,
LongformerForTokenClassification,
@@ -297,6 +298,7 @@
[
(CamembertConfig, CamembertForMultipleChoice),
(XLMRobertaConfig, XLMRobertaForMultipleChoice),
(LongformerConfig, LongformerForMultipleChoice),
(RobertaConfig, RobertaForMultipleChoice),
(BertConfig, BertForMultipleChoice),
(XLNetConfig, XLNetForMultipleChoice),
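With the (LongformerConfig, LongformerForMultipleChoice) entry added to the multiple-choice mapping, the auto machinery can resolve a Longformer checkpoint to the new head. A minimal sketch, assuming the AutoModelForMultipleChoice wrapper that reads this mapping and the "longformer-base-4096" checkpoint name used in the docstring example later in this diff:

# Sketch only: AutoModelForMultipleChoice and the checkpoint name are assumptions
# taken from the surrounding commit, not something this hunk itself adds.
from transformers import AutoModelForMultipleChoice

# Because (LongformerConfig, LongformerForMultipleChoice) is now in the
# multiple-choice mapping, this resolves to the new Longformer head.
model = AutoModelForMultipleChoice.from_pretrained("longformer-base-4096")
print(type(model).__name__)  # expected: LongformerForMultipleChoice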
24 changes: 12 additions & 12 deletions src/transformers/modeling_bert.py
@@ -543,27 +543,27 @@ def _init_weights(self, module):

BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.BertTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -632,7 +632,7 @@ def _prune_heads(self, heads_to_prune):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -759,7 +759,7 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -859,7 +859,7 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -992,7 +992,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1081,7 +1081,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1177,7 +1177,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1278,7 +1278,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1375,7 +1375,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
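The BERT_INPUTS_DOCSTRING changes above replace the hard-coded (batch_size, sequence_length) shape with a {0} placeholder, so each head can format in the shape it actually expects; the multiple-choice head passes (batch_size, num_choices, sequence_length). A simplified, self-contained sketch of the pattern — the toy decorator below only stands in for transformers.file_utils.add_start_docstrings_to_callable and is not the library's implementation:

INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.
"""

def add_start_docstrings_to_callable(*docstr):
    # Toy stand-in: prepend the given docstrings to the decorated callable's __doc__.
    def decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator

class ToyForMultipleChoice:
    @add_start_docstrings_to_callable(INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(self, input_ids=None):
        """Toy forward pass."""
        return input_ids

print(ToyForMultipleChoice.forward.__doc__)  # the shared docstring now shows the 3D shape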
115 changes: 106 additions & 9 deletions src/transformers/modeling_longformer.py
@@ -411,15 +411,15 @@ def forward(

LONGFORMER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.LongformerTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to decide the attention given to each token: local attention, global attention, or no attention (for padding tokens).
Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for
task-specific fine-tuning because it makes the model more flexible at representing the task. For example,
@@ -431,13 +431,13 @@ def forward(
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -537,7 +537,7 @@ def _pad_to_window_size(

return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -641,7 +641,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -729,7 +729,7 @@ def __init__(self, config):
self.longformer = LongformerModel(config)
self.classifier = LongformerClassificationHead(config)

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -866,7 +866,7 @@ def _get_question_end_index(self, input_ids):

return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids,
@@ -993,7 +993,7 @@ def __init__(self, config):

self.init_weights()

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1070,3 +1070,100 @@ def forward(
outputs = (loss,) + outputs

return outputs # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
"""Longformer Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForMultipleChoice(BertPreTrainedModel):
config_class = LongformerConfig
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "longformer"

def __init__(self, config):
super().__init__(config)

self.longformer = LongformerModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()

@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
token_type_ids=None,
attention_mask=None,
labels=None,
position_ids=None,
inputs_embeds=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import LongformerTokenizer, LongformerForMultipleChoice
import torch
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
num_choices = input_ids.shape[1]

flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.longformer(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
)
pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)

outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here

if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs

return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
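
Beyond the docstring example, a (batch_size, num_choices, sequence_length) batch is typically built by pairing one prompt with each candidate ending and padding the pairs to a common length. A minimal sketch, assuming the "longformer-base-4096" checkpoint name from the docstring above and the pad_to_max_length tokenizer argument available at the time of this commit:

import torch
from transformers import LongformerTokenizer, LongformerForMultipleChoice

tokenizer = LongformerTokenizer.from_pretrained("longformer-base-4096")
model = LongformerForMultipleChoice.from_pretrained("longformer-base-4096")

prompt = "The dog chased the ball and then"
choices = ["brought it back to its owner.", "the stock market closed higher."]

# Encode each (prompt, choice) pair to the same length so the choices stack
# into a single (1, num_choices, sequence_length) tensor.
encoded = [
    tokenizer.encode(prompt, choice, add_special_tokens=True, max_length=64, pad_to_max_length=True)
    for choice in choices
]
input_ids = torch.tensor(encoded).unsqueeze(0)  # batch size 1, 2 choices
labels = torch.tensor([0])                      # the first choice is the correct one

loss, classification_scores = model(input_ids, labels=labels)[:2]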