Skip to content

Commit

Permalink
updated BART documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 14, 2020
1 parent 59e6cb5 commit 40b33dd
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 27 deletions.
67 changes: 41 additions & 26 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,14 @@ impl BartModel {
///
/// # Returns
///
/// * `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `decoder_cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// - `encoder_hidden_state` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -476,12 +476,14 @@ impl BartForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
/// - `encoder_hidden_state` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -672,12 +674,14 @@ impl BartForSequenceClassification {
///
/// # Returns
///
/// * `logits` - `Tensor` of shape (*batch size*, *num_classes*) representing the logits for each class item and batch item
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *num_classes*) representing the activations for each class and batch item
/// - `encoder_hidden_state` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
Expand Down Expand Up @@ -775,13 +779,13 @@ impl LMHeadModel for BartForConditionalGeneration {
///
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `past` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// * `hidden_states` - None
/// * `attentions` - None
/// - `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// - `all_hidden_states` - None
/// - `all_attentions` - None
///
/// # Example
///
Expand Down Expand Up @@ -871,12 +875,23 @@ impl LMHeadModel for BartForConditionalGeneration {
}
}

/// Container holding a BART model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
pub struct BartModelOutput {
    /// Hidden state of the last layer of the decoder, or logits for a custom head
    /// module after the decoder (e.g. for classification or language modeling tasks)
    pub decoder_output: Tensor,
    /// Hidden state of the last layer of the encoder
    pub encoder_hidden_state: Tensor,
    /// Cached outputs of the model (attention layer keys and values, one pair per decoder
    /// layer) populated when the model is used for generation; `None` otherwise
    pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
    /// Hidden states for all layers of the decoder (present when the model was
    /// configured to output hidden states)
    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the decoder (present when the model was
    /// configured to output attentions)
    pub all_decoder_attentions: Option<Vec<Tensor>>,
    /// Hidden states for all layers of the encoder (present when the model was
    /// configured to output hidden states)
    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all layers of the encoder (present when the model was
    /// configured to output attentions)
    pub all_encoder_attentions: Option<Vec<Tensor>>,
}
6 changes: 6 additions & 0 deletions src/bart/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,16 @@ impl BartDecoder {
}
}

/// Container holding a BART decoder output
pub struct BartDecoderOutput {
    /// Hidden state of the last decoder layer
    pub hidden_state: Tensor,
    /// Padding mask for the encoder positions to attend to
    pub encoder_padding_mask: Option<Tensor>,
    /// Cached outputs of the model (attention layer keys and values, one pair per
    /// decoder layer) populated when the model is used for generation
    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
    /// Hidden states for all intermediate decoder layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate decoder layers
    pub all_attentions: Option<Vec<Tensor>>,
}
4 changes: 4 additions & 0 deletions src/bart/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,12 @@ impl BartEncoder {
}
}

/// Container holding a BART encoder output
pub struct BartEncoderOutput {
    /// Hidden state of the last encoder layer
    pub hidden_state: Tensor,
    /// Hidden states for all intermediate encoder layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate encoder layers
    pub all_attentions: Option<Vec<Tensor>>,
}
6 changes: 6 additions & 0 deletions src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2616,10 +2616,16 @@ pub trait LMHeadModel {
) -> Result<LMModelOutput, &'static str>;
}

/// Container holding a language model output for generation tasks
pub struct LMModelOutput {
    /// Logits for each vocabulary item and position
    pub lm_logits: Tensor,
    /// Encoder hidden state (re-used across decoding steps for encoder/decoder architectures;
    /// `None` for decoder-only models)
    pub encoder_hidden_state: Option<Tensor>,
    /// Cached state (attention keys and values) for improved efficiency during decoding
    pub cache: Cache,
    /// Hidden states for all intermediate model layers
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Attention weights for all intermediate model layers
    pub all_attentions: Option<Vec<Tensor>>,
}
2 changes: 1 addition & 1 deletion src/t5/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl T5Attention {
if let Some(attention_mask) = attention_mask {
temp_value = temp_value + attention_mask
};
temp_value
Some(temp_value)
} else {
None
};
Expand Down

0 comments on commit 40b33dd

Please sign in to comment.