Skip to content

Commit

Permalink
Implemented generation for XLNet
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 20, 2020
1 parent 6d6c7ea commit 8d0d222
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
23 changes: 20 additions & 3 deletions src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
&[effective_batch_size, sequence_length, sequence_length],
(Kind::Float, input_ids.device()),
);
let _ = perm_mask.narrow(2, sequence_length - 2, 1).fill_(1.0);
let _ = perm_mask.narrow(2, sequence_length - 1, 1).fill_(1.0);

let target_mapping = Tensor::zeros(
&[effective_batch_size, 1, sequence_length],
Expand All @@ -1523,13 +1523,30 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for

match past {
Cache::XLNetCache(past) => {
if past.is_some() {
if let Some(past) = past {
// let new_past = Vec::with_capacity(past.len());
let past = if let Some(first_past) = &past[0] {
let past_len = first_past.prev_content.size()[0];
past.iter()
.map(|old_layer_state| {
Some(LayerState {
prev_content: old_layer_state
.as_ref()
.unwrap()
.prev_content
.slice(0, 0, past_len - offset, 1),
})
})
.collect()
} else {
past
};
(
Some(input_ids),
Some(perm_mask),
None,
Some(target_mapping),
Cache::XLNetCache(past),
Cache::XLNetCache(Some(past)),
)
} else {
(
Expand Down
2 changes: 1 addition & 1 deletion src/xlnet/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl Clone for LayerState {

impl LayerState {
pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
self.prev_content = self.prev_content.index_select(0, new_indices);
self.prev_content = self.prev_content.index_select(1, new_indices);
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/xlnet/xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,7 @@ impl XLNetModel {
});
}

let non_tgt_mask = if let Some(attn_mask_value) = attn_mask {
attn_mask = Some(attn_mask_value.ge(0));
let non_tgt_mask = if let Some(attn_mask_value) = &attn_mask {
let mut non_tgt_mask = -Tensor::eye(q_len, (Kind::Int64, attn_mask_value.device()));
if m_len > 0 {
non_tgt_mask = Tensor::cat(
Expand All @@ -446,7 +445,7 @@ impl XLNetModel {
-1,
);
}
Some((attn_mask_value + non_tgt_mask.unsqueeze(-1).unsqueeze(-1)).ge(0))
Some((attn_mask_value + non_tgt_mask.unsqueeze(-1).unsqueeze(-1)).gt(0))
} else {
None
};
Expand Down

0 comments on commit 8d0d222

Please sign in to comment.