Skip to content

Commit

Permalink
feature(tensor): Add narrow op (tracel-ai#996)
Browse files Browse the repository at this point in the history
* Add narrow methods

* Revert "Add narrow methods"

This reverts commit 9371d87.

* Implement a shared version of narrow

* Correct test case

* Update book

* Improve tests
  • Loading branch information
dcvz authored Nov 24, 2023
1 parent e5e4771 commit f09baad
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.device()` | `tensor.device` |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
Expand Down
28 changes: 28 additions & 0 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,34 @@ where
check!(TensorCheck::dim_ops::<D>("iter_dim", dim));
DimIter::new(self, dim)
}

/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Panics
///
/// - If the dimension is greater than the number of dimensions of the tensor.
/// - If the given range exceeds the number of elements on the given dimension.
///
/// # Returns
///
/// A new tensor with the given dimension narrowed to the given range.
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
check!(TensorCheck::narrow(&self, dim, start, length));

let ranges: Vec<_> = (0..D)
.map(|i| {
if i == dim {
start..(start + length)
} else {
0..self.shape().dims[i]
}
})
.collect();

let ranges_array: [_; D] = ranges.try_into().unwrap();

self.slice(ranges_array)
}
}

/// Iterator given by (Tensor::iter_dim).
Expand Down
51 changes: 51 additions & 0 deletions burn-tensor/src/tensor/api/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,57 @@ impl TensorCheck {
check
}

pub(crate) fn narrow<B: Backend, const D: usize, K: BasicOps<B>>(
tensor: &Tensor<B, D, K>,
dim: usize,
start: usize,
length: usize,
) -> Self {
let mut check = Self::Ok;

if dim >= D {
check = check.register(
"Unsqueeze",
TensorError::new(format!(
"Can't unsqueeze at dimension {}, exceeds tensor dimensions (D={})",
dim, D
)),
);
}

if length == 0 {
check = check.register(
"Narrow",
TensorError::new(format!(
"Can't narrow at dimension {}, length must be greater than 0",
dim
)),
);
}

if start >= tensor.shape().dims[dim] {
check = check.register(
"Narrow",
TensorError::new(format!(
"Can't narrow at dimension {}, start exceeds tensor dimensions (D={})",
dim, D
)),
);
}

if start + length > tensor.shape().dims[dim] {
check = check.register(
"Narrow",
TensorError::new(format!(
"Can't narrow at dimension {}, start + length exceeds tensor dimensions (D={})",
dim, D
)),
);
}

check
}

pub(crate) fn reshape_args_usize<const D1: usize, const D2: usize>(
original: &Shape<D1>,
target: &Shape<D2>,
Expand Down
1 change: 1 addition & 0 deletions burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_matmul!();
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_narrow!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_one_hot!();
burn_tensor::testgen_powf!();
Expand Down
1 change: 1 addition & 0 deletions burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod mask;
mod matmul;
mod maxmin;
mod mul;
mod narrow;
mod neg;
mod one_hot;
mod powf;
Expand Down
59 changes: 59 additions & 0 deletions burn-tensor/src/tests/ops/narrow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#[burn_tensor_testgen::testgen(narrow)]
mod tests {
use super::*;
use burn_tensor::{Data, Shape, Tensor};

#[test]
fn test_narrow() {
let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]));

let output = tensor.clone().narrow(0, 0, 2);
assert_eq!(output.shape(), Shape::from([2, 3]));
output
.to_data()
.assert_approx_eq(&Data::from([[1., 2., 3.], [4., 5., 6.]]), 3);

let output = tensor.clone().narrow(1, 1, 2);
assert_eq!(output.shape(), Shape::from([3, 2]));
output
.to_data()
.assert_approx_eq(&Data::from([[2., 3.], [5., 6.], [8., 9.]]), 3);
}

#[test]
#[should_panic]
fn test_narrow_invalid_dim() {
let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]));

let output = tensor.narrow(2, 0, 2);
}

#[test]
#[should_panic]
fn test_narrow_invalid_start() {
let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]));

let output = tensor.narrow(0, 3, 2);
}

#[test]
#[should_panic]
fn test_narrow_invalid_zero_length() {
let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]));

let output = tensor.narrow(0, 1, 0);
}

#[test]
#[should_panic]
fn test_narrow_invalid_length() {
let tensor: Tensor<TestBackend, 2> =
Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]));

let output = tensor.narrow(0, 0, 4);
}
}

0 comments on commit f09baad

Please sign in to comment.