From 59e6cb54f0eeacd55aa4857efb533d4d4bfe17e1 Mon Sep 17 00:00:00 2001
From: guillaume-be
Date: Mon, 14 Sep 2020 18:33:58 +0200
Subject: [PATCH] updated ALBERT documentation

---
 examples/bart.rs                          |  7 +--
 src/albert/albert_model.rs                | 59 +++++++++++++++++------
 src/albert/encoder.rs                     |  4 ++
 src/bart/bart_model.rs                    | 22 ++++-----
 src/gpt2/gpt2_model.rs                    | 14 +++---
 src/marian/marian_model.rs                |  6 +--
 src/pipelines/sequence_classification.rs  |  2 +-
 src/pipelines/zero_shot_classification.rs |  2 +-
 src/t5/attention.rs                       |  7 ++-
 tests/bart.rs                             |  4 +-
 10 files changed, 75 insertions(+), 52 deletions(-)

diff --git a/examples/bart.rs b/examples/bart.rs
index ac1897235..c8749963f 100644
--- a/examples/bart.rs
+++ b/examples/bart.rs
@@ -78,10 +78,7 @@ fn main() -> anyhow::Result<()> {
 
     // Print masked tokens
     println!("{:?}", model_output.encoder_hidden_state);
-    println!("{:?}", model_output.decoder_hidden_state);
-    println!(
-        "{:?}",
-        model_output.decoder_hidden_state.double_value(&[0, 0, 0])
-    );
+    println!("{:?}", model_output.decoder_output);
+    println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
     Ok(())
 }
diff --git a/src/albert/albert_model.rs b/src/albert/albert_model.rs
index 0978bcdc4..de4b56fa1 100644
--- a/src/albert/albert_model.rs
+++ b/src/albert/albert_model.rs
@@ -400,9 +400,10 @@ impl AlbertForMaskedLM {
     ///
     /// # Returns
     ///
-    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
-    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
-    /// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    /// * `AlbertMaskedLMOutput` containing:
+    ///   - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
+    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
     ///
     /// # Example
     ///
@@ -543,9 +544,10 @@ impl AlbertForSequenceClassification {
     ///
     /// # Returns
     ///
-    /// * `output` - `Tensor` of shape (*batch size*, *num_labels*)
-    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
-    /// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    /// * `AlbertSequenceClassificationOutput` containing:
+    ///   - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
+    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
     ///
     /// # Example
     ///
@@ -684,9 +686,10 @@ impl AlbertForTokenClassification {
     ///
     /// # Returns
     ///
-    /// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
-    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
-    /// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    /// * `AlbertTokenClassificationOutput` containing:
+    ///   - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
+    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
     ///
     /// # Example
     ///
@@ -814,10 +817,11 @@ impl AlbertForQuestionAnswering {
     ///
     /// # Returns
     ///
-    /// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
-    /// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
-    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
-    /// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    /// * `AlbertQuestionAnsweringOutput` containing:
+    ///   - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the start of the answer
+    ///   - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the end of the answer
+    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
     ///
     /// # Example
     ///
@@ -957,9 +961,10 @@ impl AlbertForMultipleChoice {
     ///
     /// # Returns
     ///
-    /// * `output` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
-    /// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
-    /// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    /// * `AlbertSequenceClassificationOutput` containing:
+    ///   - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
+    ///   - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
+    ///   - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
     ///
     /// # Example
     ///
@@ -1060,34 +1065,56 @@ impl AlbertForMultipleChoice {
     }
 }
 
+/// Container for the ALBERT model output.
 pub struct AlbertOutput {
+    /// Last hidden states from the model
     pub hidden_state: Tensor,
+    /// Pooled output (hidden state for the first token)
     pub pooled_output: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
 
+/// Container for the ALBERT masked LM model output.
 pub struct AlbertMaskedLMOutput {
+    /// Logits for the vocabulary items at each sequence position
     pub prediction_scores: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
 
+/// Container for the ALBERT sequence classification model output.
 pub struct AlbertSequenceClassificationOutput {
+    /// Logits for each input (sequence) for each target class
     pub logits: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
 
+/// Container for the ALBERT token classification model output.
 pub struct AlbertTokenClassificationOutput {
+    /// Logits for each sequence item (token) for each target class
     pub logits: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
 
+/// Container for the ALBERT question answering model output.
 pub struct AlbertQuestionAnsweringOutput {
+    /// Logits for the start position for each token of each input sequence
    pub start_logits: Tensor,
+    /// Logits for the end position for each token of each input sequence
     pub end_logits: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
diff --git a/src/albert/encoder.rs b/src/albert/encoder.rs
index 7aaade00a..cd9a97e34 100644
--- a/src/albert/encoder.rs
+++ b/src/albert/encoder.rs
@@ -259,8 +259,12 @@ impl AlbertTransformer {
     }
 }
 
+/// Container holding the ALBERT transformer output
 pub struct AlbertTransformerOutput {
+    /// Last hidden states of the transformer
     pub hidden_state: Tensor,
+    /// Hidden states for all intermediate layers
     pub all_hidden_states: Option<Vec<Tensor>>,
+    /// Attention weights for all intermediate layers. As layers in ALBERT can be made of a number of sub-layers, a vector of vectors is used to store all of the attentions
     pub all_attentions: Option<Vec<Vec<Tensor>>>,
 }
diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs
index 06fa030d2..a1f835193 100644
--- a/src/bart/bart_model.rs
+++ b/src/bart/bart_model.rs
@@ -405,7 +405,7 @@ impl BartModel {
             train,
         );
         BartModelOutput {
-            decoder_hidden_state: decoder_output.hidden_state,
+            decoder_output: decoder_output.hidden_state,
             encoder_hidden_state: encoder_output.hidden_state,
             cache: decoder_output.next_decoder_cache,
             all_decoder_hidden_states: decoder_output.all_hidden_states,
@@ -535,10 +535,10 @@ impl BartForConditionalGeneration {
         );
 
         let lm_logits = base_model_output
-            .decoder_hidden_state
+            .decoder_output
             .linear::<Tensor>(&self.base_model.embeddings.ws, None);
         BartModelOutput {
-            decoder_hidden_state: lm_logits,
+            decoder_output: lm_logits,
             ..base_model_output
         }
     }
@@ -731,19 +731,15 @@ impl BartForSequenceClassification {
         let eos_mask = input_ids.eq(self.eos_token_id);
         let reshape = eos_mask.sum1(&[1], true, Int64);
         let sentence_representation = base_model_output
-            .decoder_hidden_state
+            .decoder_output
             .permute(&[2, 0, 1])
             .masked_select(&eos_mask)
             .view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
             .transpose(0, 1)
             .view((
-                base_model_output.decoder_hidden_state.size()[0],
+                base_model_output.decoder_output.size()[0],
                 -1,
-                *base_model_output
-                    .decoder_hidden_state
-                    .size()
-                    .last()
-                    .unwrap(),
+                *base_model_output.decoder_output.size().last().unwrap(),
             ))
             .select(1, -1);
 
@@ -751,7 +747,7 @@ impl BartForSequenceClassification {
             .classification_head
             .forward_t(&sentence_representation, train);
         BartModelOutput {
-            decoder_hidden_state: logits,
+            decoder_output: logits,
             encoder_hidden_state: base_model_output.encoder_hidden_state,
             cache: None,
             all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
@@ -863,7 +859,7 @@ impl LMHeadModel for BartForConditionalGeneration {
         };
 
         let lm_logits = base_model_output
-            .decoder_hidden_state
+            .decoder_output
             .linear::<Tensor>(&self.base_model.embeddings.ws, None);
         Ok(LMModelOutput {
             lm_logits,
@@ -876,7 +872,7 @@ impl LMHeadModel for BartForConditionalGeneration {
 }
 
 pub struct BartModelOutput {
-    pub decoder_hidden_state: Tensor,
+    pub decoder_output: Tensor,
     pub encoder_hidden_state: Tensor,
     pub cache: Option<Vec<(Option<Tensor>, Option<Tensor>)>>,
     pub all_decoder_hidden_states: Option<Vec<Tensor>>,
diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs
index 1e09cc687..339c61a0e 100644
--- a/src/gpt2/gpt2_model.rs
+++ b/src/gpt2/gpt2_model.rs
@@ -476,7 +476,7 @@ impl Gpt2Model {
         }
 
         Ok(Gpt2ModelOutput {
-            hidden_state: hidden_state.apply(&self.ln_f),
+            output: hidden_state.apply(&self.ln_f),
             cache: all_presents,
             all_hidden_states,
             all_attentions,
@@ -623,7 +623,7 @@ impl LMHeadModel for GPT2LMHeadModel {
         _decoder_input_ids: &Option<Tensor>,
         train: bool,
     ) -> Result {
-        let model_output = match layer_past {
+        let base_model_output = match layer_past {
            Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
                input_ids,
                &layer_past,
@@ -645,19 +645,19 @@ impl LMHeadModel for GPT2LMHeadModel {
             _ => Err("Cache not compatible with GPT2 model"),
         }?;
 
-        let lm_logits = model_output.hidden_state.apply(&self.lm_head);
+        let lm_logits = base_model_output.output.apply(&self.lm_head);
         Ok(LMModelOutput {
             lm_logits,
             encoder_hidden_state: None,
-            cache: Cache::GPT2Cache(model_output.cache),
-            all_hidden_states: model_output.all_hidden_states,
-            all_attentions: model_output.all_attentions,
+            cache: Cache::GPT2Cache(base_model_output.cache),
+            all_hidden_states: base_model_output.all_hidden_states,
+            all_attentions: base_model_output.all_attentions,
         })
     }
 }
 
 pub struct Gpt2ModelOutput {
-    pub hidden_state: Tensor,
+    pub output: Tensor,
     pub cache: Option<Vec<Tensor>>,
     pub all_hidden_states: Option<Vec<Tensor>>,
     pub all_attentions: Option<Vec<Tensor>>,
diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs
index 88d81b9a8..6cdb2365e 100644
--- a/src/marian/marian_model.rs
+++ b/src/marian/marian_model.rs
@@ -358,10 +358,10 @@ impl MarianForConditionalGeneration {
         );
 
         let lm_logits = base_model_output
-            .decoder_hidden_state
+            .decoder_output
             .linear::<Tensor>(&self.base_model.embeddings.ws, None);
         BartModelOutput {
-            decoder_hidden_state: lm_logits,
+            decoder_output: lm_logits,
             ..base_model_output
         }
     }
@@ -482,7 +482,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
         };
 
         let lm_logits = base_model_output
-            .decoder_hidden_state
+            .decoder_output
             .linear::<Tensor>(&self.base_model.embeddings.ws, None)
             + &self.final_logits_bias;
         Ok(LMModelOutput {
diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs
index b5519b7da..c4b0554dd 100644
--- a/src/pipelines/sequence_classification.rs
+++ b/src/pipelines/sequence_classification.rs
@@ -301,7 +301,7 @@ impl SequenceClassificationOption {
                         None,
                         train,
                     )
-                    .decoder_hidden_state
+                    .decoder_output
             }
             Self::Bert(ref model) => {
                 model
diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs
index 0c3506d0f..39e5353dc 100644
--- a/src/pipelines/zero_shot_classification.rs
+++ b/src/pipelines/zero_shot_classification.rs
@@ -332,7 +332,7 @@ impl ZeroShotClassificationOption {
                         None,
                         train,
                     )
-                    .decoder_hidden_state
+                    .decoder_output
             }
             Self::Bert(ref model) => {
                 model
diff --git a/src/t5/attention.rs b/src/t5/attention.rs
index f4c8177a4..8bf4f13e4 100644
--- a/src/t5/attention.rs
+++ b/src/t5/attention.rs
@@ -199,10 +199,9 @@ impl T5Attention {
                 temp_value = temp_value.slice(2, length - 1, length, 1);
             };
             if let Some(attention_mask) = attention_mask {
-                Some(temp_value + attention_mask)
-            } else {
-                Some(temp_value)
-            }
+                temp_value = temp_value + attention_mask
+            };
+            Some(temp_value)
         } else {
             None
         };
diff --git a/tests/bart.rs b/tests/bart.rs
index 7999d6b6d..2f2ba1a98 100644
--- a/tests/bart.rs
+++ b/tests/bart.rs
@@ -65,9 +65,9 @@ fn bart_lm_model() -> anyhow::Result<()> {
     let model_output =
         bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
 
-    assert_eq!(model_output.decoder_hidden_state.size(), vec!(1, 6, 1024));
+    assert_eq!(model_output.decoder_output.size(), vec!(1, 6, 1024));
     assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024));
-    assert!((model_output.decoder_hidden_state.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
+    assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
     Ok(())
 }
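
Illustrative note, not part of the patch: a minimal sketch of how the new `AlbertMaskedLMOutput` container documented above would be consumed after a forward pass. The helper name `best_token_ids` and the `rust_bert::albert::AlbertMaskedLMOutput` import path are assumptions for the example; the `prediction_scores` field and its (*batch size*, *sequence_length*, *vocab_size*) shape come from the documentation added in this patch.

use rust_bert::albert::AlbertMaskedLMOutput; // assumed re-export path
use tch::Tensor;

/// Hypothetical helper: returns the most likely vocabulary id at every position,
/// reading the `prediction_scores` field of the output container introduced above.
fn best_token_ids(output: &AlbertMaskedLMOutput) -> Tensor {
    // `prediction_scores` has shape (batch size, sequence length, vocab size);
    // argmax over the last (vocabulary) dimension yields (batch size, sequence length).
    output.prediction_scores.argmax(-1, false)
}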
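
Similarly, a small sketch (again an assumption, not part of the patch) of the renamed BART field on the consumer side, mirroring the updated `examples/bart.rs` and `tests/bart.rs`: callers now read `decoder_output` where they previously read `decoder_hidden_state`.

use rust_bert::bart::BartModelOutput; // assumed re-export path

/// Hypothetical helper: reads the first decoder activation from the renamed field.
fn first_decoder_value(model_output: &BartModelOutput) -> f64 {
    // `decoder_output` replaces the former `decoder_hidden_state` field of `BartModelOutput`.
    model_output.decoder_output.double_value(&[0, 0, 0])
}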