Skip to content

Commit

Permalink
updated relative shift
Browse files — browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 18, 2020
1 parent 7db7b2b commit 536b0b0
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion — examples/xlnet.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -65,7 +65,7 @@ fn main() -> anyhow::Result<()> {
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

input_tensor.print();
// Forward pass
let model_output = no_grad(|| {
xlnet_model
Expand Down
5 changes: 2 additions & 3 deletions — src/xlnet/attention.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -157,9 +157,9 @@ impl XLNetRelativeAttention {
fn rel_shift_bnij(&self, x: &Tensor, klen: i64) -> Tensor {
let shape = x.size();
x.reshape(&[shape[0], shape[1], shape[3], shape[2]])
.narrow(2, 1, shape[1] - 1)
.narrow(2, 1, shape[3] - 1)
.reshape(&[shape[0], shape[1], shape[2], shape[3] - 1])
.index_select(1, &Tensor::arange(klen, (Kind::Int64, x.device())))
.index_select(3, &Tensor::arange(klen, (Kind::Int64, x.device())))
}

fn rel_attention_core(
Expand Down Expand Up @@ -188,7 +188,6 @@ impl XLNetRelativeAttention {
}
None => Tensor::zeros(&[1], (Kind::Float, ac.device())),
};

let mut attention_score = (ac + bd + ef) * self.scale;
if let Some(value) = attention_mask {
attention_score = attention_score - value.permute(&[2, 3, 0, 1]) * 1e30;
Expand Down
19 changes: 10 additions & 9 deletions — src/xlnet/xlnet.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -283,7 +283,6 @@ impl XLNetModel {
) -> Tensor {
let frequency_sequence = Tensor::arange2(0, self.d_model, 2, (Kind::Float, device));
let inverse_frequency = 1f64 / Tensor::pow2(10000f64, &(frequency_sequence / self.d_model));

let (begin, end) = match self.attention_type {
AttentionType::bi => (k_len, -q_len),
AttentionType::uni => (k_len, -1),
Expand Down Expand Up @@ -340,13 +339,16 @@ impl XLNetModel {
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value
.transpose(0, 1)
.contiguous()
.apply_t(&self.word_embeddings, train),
input_value.size(),
),
None => {
let size = input_value.size();
(
input_value
.transpose(0, 1)
.contiguous()
.apply_t(&self.word_embeddings, train),
vec![size[1], size[0]],
)
}
},
None => match input_embeds {
Some(embeds) => {
Expand Down Expand Up @@ -522,7 +524,6 @@ impl XLNetModel {
target_mapping.as_ref(),
train,
);
panic!();
output_h = temp.0;
output_g = temp.1;
let attention_probas_h = temp.2;
Expand Down

0 comments on commit 536b0b0

Please sign in to comment.