Skip to content

Commit

Permalink
Making Embedding.weight public (tracel-ai#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
unrenormalizable authored Dec 22, 2023
1 parent fceb036 commit 40ec289
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 85 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ Cargo.lock
.idea
.vscode
.fleet
.vs
4 changes: 3 additions & 1 deletion burn-core/src/nn/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ pub struct EmbeddingConfig {
/// `N(0, 1)`
#[derive(Module, Debug)]
pub struct Embedding<B: Backend> {
weight: Param<Tensor<B, 2>>,
/// The learnable weights of the module of shape [n_embedding, d_model] initialized
/// from a normal distribution `N(0, 1)`.
pub weight: Param<Tensor<B, 2>>,
}

impl EmbeddingConfig {
Expand Down
2 changes: 1 addition & 1 deletion burn-fusion/src/ops/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}
}

let tensor_first = tensors.get(0).unwrap();
let tensor_first = tensors.first().unwrap();
let client = tensor_first.client.clone();

// Calculate the output shape
Expand Down
2 changes: 1 addition & 1 deletion burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
}
}

let tensor_first = tensors.get(0).unwrap();
let tensor_first = tensors.first().unwrap();
let client = tensor_first.client.clone();

// Calculate the output shape
Expand Down
2 changes: 1 addition & 1 deletion burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}
}

let tensor_first = tensors.get(0).unwrap();
let tensor_first = tensors.first().unwrap();
let client = tensor_first.client.clone();

// Calculate the output shape
Expand Down
10 changes: 5 additions & 5 deletions burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub fn flatten_config(curr: &Node) -> (usize, usize) {
}

// extract the shape of the input tensor
let tensor = match curr.inputs.get(0).unwrap().clone().ty {
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
Expand Down Expand Up @@ -262,7 +262,7 @@ pub fn gather_config(curr: &Node) -> usize {
}

// extract the shape of the input tensor
let tensor = match curr.inputs.get(0).unwrap().clone().ty {
let tensor = match curr.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
Expand Down Expand Up @@ -355,7 +355,7 @@ pub fn log_softmax_config(node: &Node) -> usize {
}

// extract the shape of the input tensor
let tensor = match node.inputs.get(0).unwrap().clone().ty {
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
Expand Down Expand Up @@ -390,7 +390,7 @@ pub fn softmax_config(node: &Node) -> usize {
}

// extract the shape of the input tensor
let tensor = match node.inputs.get(0).unwrap().clone().ty {
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
Expand All @@ -417,7 +417,7 @@ pub fn concat_config(node: &Node) -> usize {
let mut axis: i64 = 1;

// extract the shape of the input tensor
let tensor = match node.inputs.get(0).unwrap().clone().ty {
let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};
Expand Down
Loading

0 comments on commit 40ec289

Please sign in to comment.