Computation of data masks
guillaume-be committed Sep 16, 2020
1 parent 6f907ff commit 161bc4c
Showing 1 changed file with 117 additions and 2 deletions.
119 changes: 117 additions & 2 deletions src/xlnet/xlnet.rs
@@ -15,7 +15,7 @@
use crate::common::dropout::Dropout;
use crate::xlnet::attention::LayerState;
use crate::xlnet::encoder::XLNetLayer;
-use crate::Config;
+use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@@ -147,7 +147,7 @@ pub struct XLNetModel {
}

impl XLNetModel {
-    pub fn new<'p, P>(p: P, config: &XLNetConfig, generation_mode: bool) -> XLNetModel
+    pub fn new<'p, P>(p: P, config: &XLNetConfig) -> XLNetModel
    where
        P: Borrow<nn::Path<'p>>,
    {
@@ -304,4 +304,119 @@ impl XLNetModel {
        self.positional_embedding(&forward_positions_sequence, &inverse_frequency, batch_size)
    }

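    /// Forward pass. At this point in the implementation, the method validates
    /// its inputs and computes the attention and data masks, then returns `Ok(())`.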
    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        old_layer_states: Option<Vec<Option<LayerState>>>,
        perm_mask: Option<Tensor>,
        target_mapping: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(), RustBertError> {
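        // Exactly one of `input_ids` or `input_embeds` must be provided. XLNet
        // processes tensors in (sequence length, batch size) order, hence the
        // transpose of the (batch size, sequence length) inputs.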
        let (input_embeddings, input_shape) = match input_ids {
            Some(input_value) => match input_embeds {
                Some(_) => {
                    return Err(RustBertError::ValueError(
                        "Only one of input ids or input embeddings may be set".into(),
                    ));
                }
                None => (
                    input_value
                        .transpose(0, 1)
                        .contiguous()
                        .apply_t(&self.word_embeddings, train),
                    input_value.size(),
                ),
            },
            None => match input_embeds {
                Some(embeds) => {
                    let size = vec![embeds.size()[1], embeds.size()[0]];
                    (embeds.transpose(0, 1).contiguous(), size)
                }
                None => {
                    return Err(RustBertError::ValueError(
                        "At least one of input ids or input embeddings must be set".into(),
                    ));
                }
            },
        };

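        // Align the optional masks with the (sequence length, batch size) layout:
        // 2D masks are transposed, while 3D masks are permuted to move the batch
        // dimension last.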
        let token_type_ids = match token_type_ids {
            Some(token_type_ids) => Some(token_type_ids.transpose(0, 1).contiguous()),
            None => None,
        };
        let attention_mask = match attention_mask {
            Some(attention_mask) => Some(attention_mask.transpose(0, 1).contiguous()),
            None => None,
        };
        let perm_mask = match perm_mask {
            Some(perm_mask) => Some(perm_mask.permute(&[1, 2, 0]).contiguous()),
            None => None,
        };
        let target_mapping = match target_mapping {
            Some(target_mapping) => Some(target_mapping.permute(&[1, 2, 0]).contiguous()),
            None => None,
        };

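        // `m_len` is the length of the cached memory (previous hidden states),
        // 0 when no cache is provided. Keys span both the memory and the
        // current tokens, so the key length is `q_len + m_len`.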
        let m_len = if let Some(mems) = &old_layer_states {
            if let Some(mem_0) = &mems[0] {
                mem_0.prev_content.size()[0]
            } else {
                0
            }
        } else {
            0
        };
        let (q_len, batch_size) = (input_shape[0], input_shape[1]);
        let k_len = q_len + m_len;

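        // Unidirectional attention starts from a look-ahead (causal) mask over
        // the extended key length; bidirectional attention needs no base mask.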
        let attn_mask = match self.attention_type {
            AttentionType::uni => Some(
                self.create_mask(q_len, m_len, input_embeddings.device())
                    .unsqueeze(-1)
                    .unsqueeze(-1),
            ),
            AttentionType::bi => None,
        };

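        // Invert the attention mask convention: the input uses 1 for tokens to
        // attend to, while the masking logic below uses 1 for hidden positions.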
        let input_mask: Option<Tensor> = if let Some(attention_mask) = attention_mask {
            Some(1.0 - attention_mask)
        } else {
            None
        };

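        // Combine the padding mask and the permutation mask into a single data
        // mask, broadcasting the padding mask over the query dimension.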
        let data_mask: Option<Tensor> = match (input_mask, perm_mask) {
            (Some(input_mask_value), Some(perm_mask_value)) => {
                Some(input_mask_value.unsqueeze(0) + perm_mask_value)
            }
            (Some(input_mask_value), None) => Some(input_mask_value.unsqueeze(0)),
            (None, Some(perm_mask_value)) => Some(perm_mask_value),
            (None, None) => None,
        };

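        // Memory positions are always visible: prepend zeros (not masked) along
        // the key dimension, then fold the data mask into the attention mask.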
        let (data_mask, attn_mask) = if let Some(data_mask) = data_mask {
            let data_mask = if m_len > 0 {
                let mems_mask = Tensor::zeros(
                    &[data_mask.size()[0], m_len, batch_size],
                    (Kind::Bool, data_mask.device()),
                );
                Tensor::cat(&[mems_mask, data_mask], 1)
            } else {
                data_mask
            };
            let attn_mask = Some(if let Some(attn_mask) = attn_mask {
                attn_mask + data_mask.unsqueeze(-1)
            } else {
                data_mask.unsqueeze(-1)
            });
            (Some(data_mask), attn_mask)
        } else {
            (data_mask, attn_mask)
        };

        Ok(())
    }
}
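
For reference, below is a minimal, hypothetical usage sketch of the new `forward_t` signature. It assumes that the `xlnet` module, `XLNetConfig`, and `XLNetModel` are publicly exported at this commit; the config path and tensor shapes are placeholders, and at this stage the method only computes the masks and returns `Ok(())`.

use rust_bert::xlnet::{XLNetConfig, XLNetModel};
use rust_bert::{Config, RustBertError};
use std::path::Path;
use tch::{nn, Device, Kind, Tensor};

fn main() -> Result<(), RustBertError> {
    // Placeholder path: a real setup would point at a downloaded XLNet config.
    let config = XLNetConfig::from_file(Path::new("path/to/config.json"));
    let vs = nn::VarStore::new(Device::Cpu);
    let model = XLNetModel::new(&vs.root(), &config);

    // A batch of 2 sequences of length 8, with every position attended to.
    let input_ids = Tensor::zeros(&[2, 8], (Kind::Int64, Device::Cpu));
    let attention_mask = Tensor::ones(&[2, 8], (Kind::Int64, Device::Cpu));

    model.forward_t(
        Some(&input_ids),
        Some(&attention_mask),
        None,  // old_layer_states: no cached memory
        None,  // perm_mask
        None,  // target_mapping
        None,  // token_type_ids
        None,  // input_embeds
        false, // train
    )?;
    Ok(())
}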
