Skip to content

Commit

Permalink
updated ALBERT documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 14, 2020
1 parent 882ec17 commit 59e6cb5
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 52 deletions.
7 changes: 2 additions & 5 deletions examples/bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ fn main() -> anyhow::Result<()> {

// Print masked tokens
println!("{:?}", model_output.encoder_hidden_state);
println!("{:?}", model_output.decoder_hidden_state);
println!(
"{:?}",
model_output.decoder_hidden_state.double_value(&[0, 0, 0])
);
println!("{:?}", model_output.decoder_output);
println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
Ok(())
}
59 changes: 43 additions & 16 deletions src/albert/albert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,10 @@ impl AlbertForMaskedLM {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertMaskedLMOutput` containing:
/// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -543,9 +544,10 @@ impl AlbertForSequenceClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -684,9 +686,10 @@ impl AlbertForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -814,10 +817,11 @@ impl AlbertForQuestionAnswering {
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertQuestionAnsweringOutput` containing:
/// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -957,9 +961,10 @@ impl AlbertForMultipleChoice {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -1060,34 +1065,56 @@ impl AlbertForMultipleChoice {
}
}

/// Container for the ALBERT model output.
///
/// Bundles the final hidden state and pooled output together with the
/// optional per-layer hidden states and attention weights.
pub struct AlbertOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Pooled output (hidden state for the first token)
pub pooled_output: Tensor,
/// Hidden states for all intermediate layers, one `Tensor` per layer.
/// NOTE(review): presumably `None` unless the model is configured to output hidden states - confirm.
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers. The outer vector is indexed by layer,
/// the inner vector by the inner group (sub-layer) within that layer.
/// NOTE(review): presumably `None` unless the model is configured to output attentions - confirm.
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}

/// Container for the ALBERT masked LM model output.
pub struct AlbertMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position,
/// with shape (*batch size*, *sequence_length*, *vocab_size*)
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers (length *num_hidden_layers*),
/// each with shape (*batch size*, *sequence_length*, *hidden_size*)
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers (length *num_hidden_layers*,
/// nested length *inner_group_num*)
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}

/// Container for the ALBERT sequence classification model output.
///
/// The multiple choice model documents this same container as its return type,
/// in which case `logits` holds one score per alternative.
pub struct AlbertSequenceClassificationOutput {
/// Logits for each input (sequence) for each target class,
/// with shape (*batch size*, *num_labels*) for sequence classification
pub logits: Tensor,
/// Hidden states for all intermediate layers (length *num_hidden_layers*),
/// each with shape (*batch size*, *sequence_length*, *hidden_size*)
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers (length *num_hidden_layers*,
/// nested length *inner_group_num*)
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}

/// Container for the ALBERT token classification model output.
pub struct AlbertTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class,
/// with shape (*batch size*, *sequence_length*, *num_labels*)
pub logits: Tensor,
/// Hidden states for all intermediate layers (length *num_hidden_layers*),
/// each with shape (*batch size*, *sequence_length*, *hidden_size*)
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers (length *num_hidden_layers*,
/// nested length *inner_group_num*)
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}

/// Container for the ALBERT question answering model output.
pub struct AlbertQuestionAnsweringOutput {
/// Logits for the start of the answer, one per token of each input sequence,
/// with shape (*batch size*, *sequence_length*)
pub start_logits: Tensor,
/// Logits for the end of the answer, one per token of each input sequence,
/// with shape (*batch size*, *sequence_length*)
pub end_logits: Tensor,
/// Hidden states for all intermediate layers (length *num_hidden_layers*),
/// each with shape (*batch size*, *sequence_length*, *hidden_size*)
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers (length *num_hidden_layers*,
/// nested length *inner_group_num*)
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
4 changes: 4 additions & 0 deletions src/albert/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,12 @@ impl AlbertTransformer {
}
}

/// Container holding the ALBERT transformer output
pub struct AlbertTransformerOutput {
/// Last hidden states of the transformer
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers. As layers in ALBERT can be made of a number
/// of sub-layers, a vector of vectors is used to store all of the attentions
/// (outer index: layer, inner index: sub-layer).
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
22 changes: 9 additions & 13 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ impl BartModel {
train,
);
BartModelOutput {
decoder_hidden_state: decoder_output.hidden_state,
decoder_output: decoder_output.hidden_state,
encoder_hidden_state: encoder_output.hidden_state,
cache: decoder_output.next_decoder_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
Expand Down Expand Up @@ -535,10 +535,10 @@ impl BartForConditionalGeneration {
);

let lm_logits = base_model_output
.decoder_hidden_state
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
BartModelOutput {
decoder_hidden_state: lm_logits,
decoder_output: lm_logits,
..base_model_output
}
}
Expand Down Expand Up @@ -731,27 +731,23 @@ impl BartForSequenceClassification {
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum1(&[1], true, Int64);
let sentence_representation = base_model_output
.decoder_hidden_state
.decoder_output
.permute(&[2, 0, 1])
.masked_select(&eos_mask)
.view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
.transpose(0, 1)
.view((
base_model_output.decoder_hidden_state.size()[0],
base_model_output.decoder_output.size()[0],
-1,
*base_model_output
.decoder_hidden_state
.size()
.last()
.unwrap(),
*base_model_output.decoder_output.size().last().unwrap(),
))
.select(1, -1);

let logits = self
.classification_head
.forward_t(&sentence_representation, train);
BartModelOutput {
decoder_hidden_state: logits,
decoder_output: logits,
encoder_hidden_state: base_model_output.encoder_hidden_state,
cache: None,
all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
Expand Down Expand Up @@ -863,7 +859,7 @@ impl LMHeadModel for BartForConditionalGeneration {
};

let lm_logits = base_model_output
.decoder_hidden_state
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
Expand All @@ -876,7 +872,7 @@ impl LMHeadModel for BartForConditionalGeneration {
}

pub struct BartModelOutput {
pub decoder_hidden_state: Tensor,
pub decoder_output: Tensor,
pub encoder_hidden_state: Tensor,
pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
pub all_decoder_hidden_states: Option<Vec<Tensor>>,
Expand Down
14 changes: 7 additions & 7 deletions src/gpt2/gpt2_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ impl Gpt2Model {
}

Ok(Gpt2ModelOutput {
hidden_state: hidden_state.apply(&self.ln_f),
output: hidden_state.apply(&self.ln_f),
cache: all_presents,
all_hidden_states,
all_attentions,
Expand Down Expand Up @@ -623,7 +623,7 @@ impl LMHeadModel for GPT2LMHeadModel {
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<LMModelOutput, &'static str> {
let model_output = match layer_past {
let base_model_output = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
input_ids,
&layer_past,
Expand All @@ -645,19 +645,19 @@ impl LMHeadModel for GPT2LMHeadModel {
_ => Err("Cache not compatible with GPT2 model"),
}?;

let lm_logits = model_output.hidden_state.apply(&self.lm_head);
let lm_logits = base_model_output.output.apply(&self.lm_head);
Ok(LMModelOutput {
lm_logits,
encoder_hidden_state: None,
cache: Cache::GPT2Cache(model_output.cache),
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
cache: Cache::GPT2Cache(base_model_output.cache),
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}

pub struct Gpt2ModelOutput {
pub hidden_state: Tensor,
pub output: Tensor,
pub cache: Option<Vec<Tensor>>,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
Expand Down
6 changes: 3 additions & 3 deletions src/marian/marian_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ impl MarianForConditionalGeneration {
);

let lm_logits = base_model_output
.decoder_hidden_state
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
BartModelOutput {
decoder_hidden_state: lm_logits,
decoder_output: lm_logits,
..base_model_output
}
}
Expand Down Expand Up @@ -482,7 +482,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
};

let lm_logits = base_model_output
.decoder_hidden_state
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+ &self.final_logits_bias;
Ok(LMModelOutput {
Expand Down
2 changes: 1 addition & 1 deletion src/pipelines/sequence_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl SequenceClassificationOption {
None,
train,
)
.decoder_hidden_state
.decoder_output
}
Self::Bert(ref model) => {
model
Expand Down
2 changes: 1 addition & 1 deletion src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl ZeroShotClassificationOption {
None,
train,
)
.decoder_hidden_state
.decoder_output
}
Self::Bert(ref model) => {
model
Expand Down
7 changes: 3 additions & 4 deletions src/t5/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,9 @@ impl T5Attention {
temp_value = temp_value.slice(2, length - 1, length, 1);
};
if let Some(attention_mask) = attention_mask {
Some(temp_value + attention_mask)
} else {
Some(temp_value)
}
temp_value = temp_value + attention_mask
};
temp_value
} else {
None
};
Expand Down
4 changes: 2 additions & 2 deletions tests/bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ fn bart_lm_model() -> anyhow::Result<()> {
let model_output =
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);

assert_eq!(model_output.decoder_hidden_state.size(), vec!(1, 6, 1024));
assert_eq!(model_output.decoder_output.size(), vec!(1, 6, 1024));
assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024));
assert!((model_output.decoder_hidden_state.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
Ok(())
}

Expand Down

0 comments on commit 59e6cb5

Please sign in to comment.