tch 0.12.0 Update (guillaume-be#379)
* Fix 0.12 breaking changes

* Fix Clippy warnings

* Updated changelog
guillaume-be authored May 11, 2023
1 parent 9fd7983 commit 5f9500c
Showing 53 changed files with 323 additions and 271 deletions.
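
Most of the files below change in the same mechanical way: tch 0.12 makes shape-style arguments generic, so the `&[...]` slice borrows used with 0.11 become plain `[...]` arrays, and bare `None` arguments (for example einsum's contraction-path argument) now need an explicit type such as `None::<i64>`. A minimal standalone sketch of that call-site migration follows; the toy shapes and the `main` wrapper are illustrative and not part of this commit.

use tch::{Device, Kind, Tensor};

fn main() {
    // tch 0.11 style was Tensor::rand(&[2, 3], ...); with tch 0.12 the size
    // argument accepts a plain array, so the extra borrow goes away.
    let t = Tensor::rand([2, 3], (Kind::Float, Device::Cpu));

    // Optional arguments that previously took a bare `None` now need a
    // concrete type parameter, as with einsum's path argument here.
    let row_sums = Tensor::einsum("ij->i", &[t], None::<i64>);
    row_sums.print();
}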
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ All notable changes to this project will be documented in this file. The format
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
- (BREAKING) Simplified the generation traits (removal of LMHeadModel and elimination of unnecessary specification for LanguageGenerator)
- Upgraded to `torch` 2.0 (via `tch` 0.11.0).
- Upgraded to `torch` 2.0 (via `tch` 0.12.0).

## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -70,7 +70,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "8.1"
tch = "0.11.0"
tch = "0.12.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
@@ -88,6 +88,6 @@ anyhow = "1"
csv = "1"
criterion = "0.4"
tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.11.0"
torch-sys = "0.12.0"
tempfile = "3"
itertools = "0.10"
4 changes: 2 additions & 2 deletions benches/tensor_operations_benchmark.rs
@@ -21,8 +21,8 @@ fn bench_tensor_ops(c: &mut Criterion) {
unsafe {
torch_sys::dummy_cuda_dependency();
}
let input = Tensor::rand(&[32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand(&[512, 512], (Kind::Float, Device::cuda_if_available()));
let input = Tensor::rand([32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand([512, 512], (Kind::Float, Device::cuda_if_available()));

let _ = &input.matmul(&weights);
c.bench_function("Matrix multiply ", |b| {
2 changes: 1 addition & 1 deletion src/albert/albert_model.rs
@@ -257,7 +257,7 @@ impl AlbertModel {
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
Some(Tensor::ones(input_shape, (Kind::Int64, device)))
} else {
None
};
4 changes: 2 additions & 2 deletions src/albert/attention.rs
@@ -130,8 +130,8 @@ impl AlbertSelfAttention {
self.hidden_size,
));

let context: Tensor =
Tensor::einsum("bfnd,ndh->bfh", &[context, w], None) + self.dense.bs.as_ref().unwrap();
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w], None::<i64>)
+ self.dense.bs.as_ref().unwrap();
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);

if !self.output_attentions {
2 changes: 1 addition & 1 deletion src/bart/attention.rs
@@ -176,7 +176,7 @@ impl BartAttention {
.bmm(&value_states)
.view([bs, self.num_heads, target_length, self.head_dim])
.transpose(1, 2)
.reshape(&[bs, target_length, embed_dim])
.reshape([bs, target_length, embed_dim])
.apply(&self.out_proj);

(attention_output, saved_attention_weights, new_layer_state)
10 changes: 5 additions & 5 deletions src/bart/bart_model.rs
@@ -270,7 +270,7 @@ pub(crate) fn _make_causal_mask(
let target_length = input_ids_shape[1];

let mut mask = Tensor::full(
&[target_length, target_length],
[target_length, target_length],
get_min(dtype).unwrap(),
(dtype, device),
);
@@ -283,14 +283,14 @@
if past_key_values_length > 0 {
mask = Tensor::cat(
&[
Tensor::zeros(&[target_length, past_key_values_length], (dtype, device)),
Tensor::zeros([target_length, past_key_values_length], (dtype, device)),
mask,
],
-1,
);
}
mask.unsqueeze(0).unsqueeze(0).expand(
&[
[
batch_size,
1,
target_length,
@@ -306,7 +306,7 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kin
let expanded_mask = mask
.unsqueeze(1)
.unsqueeze(1)
.expand(&[batch_size, 1, target_length, source_length], true)
.expand([batch_size, 1, target_length, source_length], true)
.totype(dtype);
let inverted_mask: Tensor = 1 - expanded_mask;
inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
@@ -863,7 +863,7 @@ impl BartForSequenceClassification {
let reshape = eos_mask.sum_dim_intlist([1].as_slice(), true, input_ids.kind());
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])
.permute([2, 0, 1])
.masked_select(&eos_mask)
.view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
.transpose(0, 1)
4 changes: 2 additions & 2 deletions src/bert/bert_model.rs
@@ -370,7 +370,7 @@ impl<T: BertEmbedding> BertModel<T> {
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Kind::Int8, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat([
input_shape[0],
input_shape[1],
1,
@@ -407,7 +407,7 @@ impl<T: BertEmbedding> BertModel<T> {
let encoder_mask = match encoder_mask {
Some(value) => value.copy(),
None => Tensor::ones(
&[
[
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
2 changes: 1 addition & 1 deletion src/common/dropout.rs
@@ -43,7 +43,7 @@ impl XDropout {
impl ModuleT for XDropout {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
if train {
let mask = (Tensor::ones(&[1], (input.kind(), input.device()))
let mask = (Tensor::ones([1], (input.kind(), input.device()))
- input
.empty_like()
.bernoulli_float_(1_f64 - self.dropout_prob))
20 changes: 10 additions & 10 deletions src/deberta/attention.rs
@@ -37,7 +37,7 @@ pub trait DisentangledSelfAttention {
pub fn build_relative_position(query_size: i64, key_size: i64, device: Device) -> Tensor {
let q_ids = Tensor::arange(query_size, (Kind::Int64, device));
let k_ids = Tensor::arange(key_size, (Kind::Int64, device));
let rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.view([1, -1]).repeat(&[query_size, 1]);
let rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.view([1, -1]).repeat([query_size, 1]);
rel_pos_ids.slice(0, 0, query_size, 1).unsqueeze(0)
}

@@ -62,7 +62,7 @@ impl DebertaDisentangledSelfAttention {
let mut new_shape = x.size();
let _ = new_shape.pop();
new_shape.extend_from_slice(&[self.num_attention_heads, -1]);
x.view(new_shape.as_slice()).permute(&[0, 2, 1, 3])
x.view(new_shape.as_slice()).permute([0, 2, 1, 3])
}

fn linear(&self, weights: &Tensor, bias: Option<&Tensor>, x: &Tensor) -> Tensor {
@@ -81,7 +81,7 @@ impl DebertaDisentangledSelfAttention {
) -> Tensor {
let query_layer_size = query_layer.size();
c2p_pos.expand(
&[
[
query_layer_size[0],
query_layer_size[1],
query_layer_size[2],
@@ -101,7 +101,7 @@ impl DebertaDisentangledSelfAttention {
let mut key_layer_size = key_layer.size();
key_layer_size.reverse();
c2p_pos.expand(
&[
[
query_layer_size[0],
query_layer_size[1],
key_layer_size[1],
@@ -182,7 +182,7 @@ impl DebertaDisentangledSelfAttention {
)
.unsqueeze(0);

let mut score = Tensor::zeros(&[1], (query_layer.kind(), key_layer.device()));
let mut score = Tensor::zeros([1], (query_layer.kind(), key_layer.device()));

// content -> position
if let Some(pos_proj) = &self.pos_proj {
@@ -410,24 +410,24 @@ impl DisentangledSelfAttention for DebertaDisentangledSelfAttention {

if let Some(head_logits_proj) = &self.head_logits_proj {
attention_scores = attention_scores
.permute(&[0, 2, 3, 1])
.permute([0, 2, 3, 1])
.apply(head_logits_proj)
.permute(&[0, 3, 1, 2]);
.permute([0, 3, 1, 2]);
}

let mut attention_probs =
x_softmax(&attention_scores, attention_mask, -1).apply_t(&self.dropout, train);

if let Some(head_weights_proj) = &self.head_weights_proj {
attention_probs = attention_probs
.permute(&[0, 2, 3, 1])
.permute([0, 2, 3, 1])
.apply(head_weights_proj)
.permute(&[0, 3, 1, 2]);
.permute([0, 3, 1, 2]);
}

let context_layer = attention_probs
.matmul(&value_layer)
.permute(&[0, 2, 1, 3])
.permute([0, 2, 1, 3])
.contiguous();

let mut new_context_layer_shape = context_layer.size();
4 changes: 2 additions & 2 deletions src/deberta/embeddings.rs
@@ -127,15 +127,15 @@ where
let calc_position_ids = if position_ids.is_none() {
Some(
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.expand(&[1, -1], true),
.expand([1, -1], true),
)
} else {
None
};

let calc_token_type_ids = if token_type_ids.is_none() {
Some(Tensor::zeros(
&input_shape,
input_shape,
(Kind::Int64, input_embeddings.device()),
))
} else {
18 changes: 9 additions & 9 deletions src/deberta_v2/attention.rs
@@ -51,7 +51,7 @@ pub fn build_relative_position(
) -> Tensor {
let q_ids = Tensor::arange(query_size, (Kind::Int64, device));
let k_ids = Tensor::arange(key_size, (Kind::Int64, device));
let mut rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.tile(&[q_ids.size()[0], 1]);
let mut rel_pos_ids = q_ids.unsqueeze(-1) - k_ids.tile([q_ids.size()[0], 1]);
if (bucket_size > 0) & (max_position > 0) {
rel_pos_ids = make_log_bucket_position(&rel_pos_ids, bucket_size, max_position);
}
@@ -80,7 +80,7 @@ impl DebertaV2DisentangledSelfAttention {
let _ = new_shape.pop();
new_shape.extend_from_slice(&[self.num_attention_heads, -1]);
let x = x.view(new_shape.as_slice());
x.permute(&[0, 2, 1, 3])
x.permute([0, 2, 1, 3])
.contiguous()
.view([-1, x.size()[1], *x.size().last().unwrap()])
}
@@ -133,12 +133,12 @@ impl DebertaV2DisentangledSelfAttention {

let pos_query_layer = self
.transpose_for_scores(&relative_embeddings.apply(query_proj))
.repeat(&[query_layer.size()[0] / self.num_attention_heads, 1, 1]);
.repeat([query_layer.size()[0] / self.num_attention_heads, 1, 1]);
let pos_key_layer = self
.transpose_for_scores(&relative_embeddings.apply(key_proj))
.repeat(&[query_layer.size()[0] / self.num_attention_heads, 1, 1]);
.repeat([query_layer.size()[0] / self.num_attention_heads, 1, 1]);

let mut score = Tensor::zeros(&[1], (query_layer.kind(), query_layer.device()));
let mut score = Tensor::zeros([1], (query_layer.kind(), query_layer.device()));

let c2p_pos = if self.pos_att_type.has_type(PositionAttentionType::c2p)
| self.pos_att_type.has_type(PositionAttentionType::p2p)
@@ -149,7 +149,7 @@ impl DebertaV2DisentangledSelfAttention {
let c2p_att = c2p_att.gather(
-1,
&c2p_pos.squeeze_dim(0).expand(
&[
[
query_layer.size()[0],
query_layer.size()[1],
*relative_pos.size().last().unwrap(),
@@ -186,7 +186,7 @@ impl DebertaV2DisentangledSelfAttention {
.gather(
-1,
&p2c_pos.squeeze_dim(0).expand(
&[query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
[query_layer.size()[0], key_layer_size[1], key_layer_size[1]],
true,
),
true,
@@ -203,7 +203,7 @@ impl DebertaV2DisentangledSelfAttention {
let p2p_att = p2p_att.gather(
-1,
&c2p_pos.unwrap().expand(
&[
[
query_layer.size()[0],
query_layer.size()[1],
query_layer.size()[2],
@@ -402,7 +402,7 @@ impl DisentangledSelfAttention for DebertaV2DisentangledSelfAttention {
reverse_context_layer_size[1],
reverse_context_layer_size[0],
])
.permute(&[0, 2, 1, 3])
.permute([0, 2, 1, 3])
.contiguous();

let mut new_context_layer_shape = context_layer.size();
4 changes: 2 additions & 2 deletions src/deberta_v2/encoder.rs
@@ -78,10 +78,10 @@ impl ConvLayer {
train: bool,
) -> Tensor {
let out = hidden_states
.permute(&[0, 2, 1])
.permute([0, 2, 1])
.contiguous()
.apply(&self.conv)
.permute(&[0, 2, 1])
.permute([0, 2, 1])
.contiguous();
let reverse_mask: Tensor = 1 - input_mask;
let out = out.masked_fill(
2 changes: 1 addition & 1 deletion src/electra/electra_model.rs
@@ -266,7 +266,7 @@ impl ElectraModel {
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

let calc_mask = if mask.is_none() {
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
Some(Tensor::ones(input_shape, (Kind::Int64, device)))
} else {
None
};
4 changes: 3 additions & 1 deletion src/fnet/attention.rs
@@ -42,7 +42,7 @@ impl FNetFourierTransform {
}

pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
let self_outputs = hidden_states.fft_fft2(None, &[1, 2], "backward").real();
let self_outputs = hidden_states
.fft_fft2(None::<i64>, [1, 2], "backward")
.real();
(self_outputs + hidden_states).apply(&self.layer_norm)
}
}
6 changes: 3 additions & 3 deletions src/gpt2/attention.rs
@@ -71,7 +71,7 @@ impl Attention {
{
let p = p.borrow();

let bias = Tensor::ones(&[config.n_ctx, config.n_ctx], (Float, p.device()))
let bias = Tensor::ones([config.n_ctx, config.n_ctx], (Float, p.device()))
.tril(0)
.view((1, 1, config.n_ctx, config.n_ctx));

@@ -111,9 +111,9 @@ impl Attention {
fn split_heads(&self, x: &Tensor, k: bool) -> Tensor {
let x = x.view((x.size()[0], -1, self.n_head, self.dim_per_head));
if k {
x.permute(&[0, 2, 3, 1])
x.permute([0, 2, 3, 1])
} else {
x.permute(&[0, 2, 1, 3])
x.permute([0, 2, 1, 3])
}
}
