Skip to content

Commit

Permalink
updated weights shape
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 18, 2020
1 parent abe177c commit 7db7b2b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 83 deletions.
92 changes: 46 additions & 46 deletions examples/xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

extern crate anyhow;

use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{
XLNetConfig, XLNetConfigResources, XLNetModelResources, XLNetVocabResources,
XLNetConfig, XLNetConfigResources, XLNetModel, XLNetModelResources, XLNetVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{Tokenizer, TruncationStrategy, Vocab, XLNetTokenizer};
use rust_tokenizers::{Tokenizer, TruncationStrategy, XLNetTokenizer};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> anyhow::Result<()> {
Expand All @@ -35,52 +35,52 @@ fn main() -> anyhow::Result<()> {
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let _weights_path = weights_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;

// Set-up masked LM model
let _device = Device::Cpu;
// let mut vs = nn::VarStore::new(device);
let _tokenizer: XLNetTokenizer =
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: XLNetTokenizer =
XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
let _config = XLNetConfig::from_file(config_path);
// let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
// vs.load(weights_path)?;
//
// // Define input
// let input = [
// "Looks like one [MASK] is missing",
// "It was a very nice and [MASK] day",
// ];
// let tokenized_input =
// tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
// let max_len = tokenized_input
// .iter()
// .map(|input| input.token_ids.len())
// .max()
// .unwrap();
// let tokenized_input = tokenized_input
// .iter()
// .map(|input| input.token_ids.clone())
// .map(|mut input| {
// input.extend(vec![0; max_len - input.len()]);
// input
// })
// .map(|input| Tensor::of_slice(&(input)))
// .collect::<Vec<_>>();
// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
//
// // Forward pass
// let (output, _, _) =
// no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
// println!("{:?}", output.double_value(&[0, 0, 0]));
// // Print masked tokens
// let index_1 = output.get(0).get(4).argmax(0, false);
// let index_2 = output.get(1).get(7).argmax(0, false);
// let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
// let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
//
// println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing"
// println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
let config = XLNetConfig::from_file(config_path);
let xlnet_model = XLNetModel::new(&vs.root() / "transformer", &config);
vs.load(weights_path)?;

// Define input
let input = ["Hello, world!"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

// Forward pass
let model_output = no_grad(|| {
xlnet_model
.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
None,
false,
)
.unwrap()
});
model_output.hidden_state.print();
Ok(())
}
12 changes: 2 additions & 10 deletions src/xlnet/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl XLNetRelativeAttention {
);
let seg_embed = p.var(
"seg_embed",
&[config.n_head, config.d_head],
&[2, config.n_head, config.d_head],
Init::KaimingUniform,
);

Expand Down Expand Up @@ -154,14 +154,6 @@ impl XLNetRelativeAttention {
}
}

fn rel_shift(&self, x: &Tensor, klen: i64) -> Tensor {
let shape = x.size();
x.reshape(&[shape[1], shape[0], shape[2], shape[3]])
.narrow(0, 1, shape[1] - 1)
.reshape(&[shape[0], shape[1] - 1, shape[2], shape[3]])
.index_select(1, &Tensor::arange(klen, (Kind::Int64, x.device())))
}

fn rel_shift_bnij(&self, x: &Tensor, klen: i64) -> Tensor {
let shape = x.size();
x.reshape(&[shape[0], shape[1], shape[3], shape[2]])
Expand Down Expand Up @@ -241,7 +233,7 @@ impl XLNetRelativeAttention {
attn_mask_g: Option<&Tensor>,
r: &Tensor,
seg_mat: Option<&Tensor>,
mut layer_state: Option<LayerState>,
layer_state: Option<LayerState>,
target_mapping: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) {
Expand Down
12 changes: 3 additions & 9 deletions src/xlnet/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ impl XLNetFeedForward {
);
let layer_2 = nn::linear(
p / "layer_2",
config.d_model,
config.d_inner,
config.d_model,
Default::default(),
);

Expand Down Expand Up @@ -87,7 +87,6 @@ impl XLNetFeedForward {
pub struct XLNetLayer {
rel_attn: XLNetRelativeAttention,
ff: XLNetFeedForward,
dropout: Dropout,
}

impl XLNetLayer {
Expand All @@ -99,12 +98,7 @@ impl XLNetLayer {

let rel_attn = XLNetRelativeAttention::new(p / "rel_attn", config);
let ff = XLNetFeedForward::new(p / "ff", config);
let dropout = Dropout::new(config.dropout);
XLNetLayer {
rel_attn,
ff,
dropout,
}
XLNetLayer { rel_attn, ff }
}

pub fn forward_t(
Expand All @@ -115,7 +109,7 @@ impl XLNetLayer {
attn_mask_g: Option<&Tensor>,
r: &Tensor,
seg_mat: Option<&Tensor>,
mut layer_state: Option<LayerState>,
layer_state: Option<LayerState>,
target_mapping: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) {
Expand Down
13 changes: 7 additions & 6 deletions src/xlnet/xlnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ pub struct XLNetVocabResources;
impl XLNetModelResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
pub const XLNET_BASE_V2: (&'static str, &'static str) = (
"xlnet-base-cased/model.ot",
"https://cdn.huggingface.co/xlnet-base-cased/rust_model.ot",
"xlnet-base-cased/model",
"https://cdn.huggingface.co/xlnet-base-cased-rust_model.ot",
);
}

impl XLNetConfigResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
pub const XLNET_BASE_V2: (&'static str, &'static str) = (
"xlnet-base-cased/config.json",
"xlnet-base-cased/config",
"https://cdn.huggingface.co/xlnet-base-cased-config.json",
);
}

impl XLNetVocabResources {
/// Shared under Apache 2.0 license by the XLNet Authors at https://github.com/zihangdai/xlnet. Modified with conversion to C-array format.
pub const XLNET_BASE_V2: (&'static str, &'static str) = (
"xlnet-base-cased/spiece.model",
"xlnet-base-cased/spiece",
"https://cdn.huggingface.co/xlnet-base-cased-spiece.model",
);
}
Expand Down Expand Up @@ -291,7 +291,7 @@ impl XLNetModel {
let mut forward_positions_sequence = Tensor::arange2(begin, end, -1, (Kind::Float, device));
match self.clamp_len {
Some(clamp_value) if clamp_value > 0 => {
let _ = forward_positions_sequence.clamp(-clamp_value, clamp_value);
let _ = forward_positions_sequence.clamp_(-clamp_value, clamp_value);
}
_ => {}
}
Expand All @@ -300,7 +300,7 @@ impl XLNetModel {
Tensor::arange2(-begin, -end, 1, (Kind::Float, device));
match self.clamp_len {
Some(clamp_value) if clamp_value > 0 => {
let _ = backward_positions_sequence.clamp(-clamp_value, clamp_value);
let _ = backward_positions_sequence.clamp_(-clamp_value, clamp_value);
}
_ => {}
}
Expand Down Expand Up @@ -522,6 +522,7 @@ impl XLNetModel {
target_mapping.as_ref(),
train,
);
panic!();
output_h = temp.0;
output_g = temp.1;
let attention_probas_h = temp.2;
Expand Down
24 changes: 12 additions & 12 deletions utils/download-dependencies_xlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))
# np.savez(target_path / 'model.npz', **nps)
#
# source = str(target_path / 'model.npz')
# target = str(target_path / 'model.ot')
#
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
#
# subprocess.call(
# ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
#
# os.remove(str(target_path / 'model.bin'))
# os.remove(str(target_path / 'model.npz'))

0 comments on commit 7db7b2b

Please sign in to comment.