Skip to content

Commit

Permalink
Finalization of integration tests for XLNet
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 29, 2020
1 parent e28fd26 commit b0d04da
Show file tree
Hide file tree
Showing 2 changed files with 385 additions and 15 deletions.
19 changes: 8 additions & 11 deletions src/xlnet/xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ impl XLNetModel {
} else {
None
};
let mut all_attentions: Option<Vec<(Tensor, Tensor)>> = if self.output_attentions {
let mut all_attentions: Option<Vec<(Tensor, Option<Tensor>)>> = if self.output_attentions {
Some(vec![])
} else {
None
Expand Down Expand Up @@ -509,7 +509,7 @@ impl XLNetModel {
));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push((attention_probas_h.unwrap(), attention_probas_g.unwrap()));
attentions.push((attention_probas_h.unwrap(), attention_probas_g));
};
}
let hidden_state = if let Some(output_g_value) = output_g {
Expand Down Expand Up @@ -812,10 +812,7 @@ pub struct XLNetForMultipleChoice {
}

impl XLNetForMultipleChoice {
pub fn new<'p, P>(
p: P,
config: &XLNetConfig,
) -> Result<XLNetForSequenceClassification, RustBertError>
pub fn new<'p, P>(p: P, config: &XLNetConfig) -> Result<XLNetForMultipleChoice, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -827,7 +824,7 @@ impl XLNetForMultipleChoice {

let logits_proj = nn::linear(p / "logits_proj", config.d_model, 1, Default::default());

Ok(XLNetForSequenceClassification {
Ok(XLNetForMultipleChoice {
base_model,
sequence_summary,
logits_proj,
Expand Down Expand Up @@ -974,7 +971,7 @@ pub struct XLNetModelOutput {
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<(Tensor, Tensor)>>,
pub all_attentions: Option<Vec<(Tensor, Option<Tensor>)>>,
}

/// Container for the XLNet sequence classification model output.
Expand All @@ -986,7 +983,7 @@ pub struct XLNetSequenceClassificationOutput {
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<(Tensor, Tensor)>>,
pub all_attentions: Option<Vec<(Tensor, Option<Tensor>)>>,
}

/// Container for the XLNet token classification model output.
Expand All @@ -998,7 +995,7 @@ pub struct XLNetTokenClassificationOutput {
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<(Tensor, Tensor)>>,
pub all_attentions: Option<Vec<(Tensor, Option<Tensor>)>>,
}

/// Container for the XLNet question answering model output.
Expand All @@ -1012,5 +1009,5 @@ pub struct XLNetQuestionAnsweringOutput {
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<(Tensor, Option<Tensor>)>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<(Tensor, Tensor)>>,
pub all_attentions: Option<Vec<(Tensor, Option<Tensor>)>>,
}
Loading

0 comments on commit b0d04da

Please sign in to comment.