Skip to content

Commit

Permalink
feature(tensor): Add chunk op (tracel-ai#998)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcvz authored Nov 27, 2023
1 parent 2fdf9a3 commit 929b178
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
Expand Down
39 changes: 39 additions & 0 deletions burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,45 @@ where

self.slice(ranges_array)
}

/// Attempts to split the tensor along the given dimension into chunks.
/// May return less chunks than requested if the tensor size is not divisible by the number of chunks.
///
/// When the given dimension is evenly divisible by the number of chunks, the chunks will be of equal size.
/// Otherwise all chunks will be of equal size except for the last one.
///
/// # Panics
///
/// If the dimension is greater than the number of dimensions of the tensor.
///
/// # Returns
/// A vector of tensors.
pub fn chunk(self, chunks: usize, dim: usize) -> Vec<Self> {
check!(TensorCheck::dim_ops::<D>("chunk", dim));

let size = self.shape().dims[dim];
if size < chunks {
return (0..size).map(|i| self.clone().narrow(dim, i, 1)).collect();
}

let chunk_size = size / chunks;
let cnt_additional = size % chunks;
let mut tensors = Vec::with_capacity(chunks);

let mut sum_chunk_size = 0;
for i in 0..chunks {
let chunk_size = if i < cnt_additional {
chunk_size + 1
} else {
chunk_size
};

tensors.push(self.clone().narrow(dim, sum_chunk_size, chunk_size));
sum_chunk_size += chunk_size;
}

tensors
}
}

/// Iterator given by (Tensor::iter_dim).
Expand Down
1 change: 1 addition & 0 deletions burn-tensor/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_chunk!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_create_like!();
Expand Down
81 changes: 81 additions & 0 deletions burn-tensor/src/tests/ops/chunk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#[burn_tensor_testgen::testgen(chunk)]
mod tests {
use super::*;
use alloc::vec::Vec;
use burn_tensor::{Data, Int, Shape, Tensor};

fn test_chunk_evenly_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 0);
assert_eq!(tensors.len(), 6);

let expected = vec![
Data::from([0, 1]),
Data::from([2, 3]),
Data::from([4, 5]),
Data::from([6, 7]),
Data::from([8, 9]),
Data::from([10, 11]),
];

for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}

#[test]
fn test_chunk_not_evenly_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..11).chunk(6, 0);
assert_eq!(tensors.len(), 6);

let expected = vec![
Data::from([0, 1]),
Data::from([2, 3]),
Data::from([4, 5]),
Data::from([6, 7]),
Data::from([8, 9]),
Data::from([10]),
];

for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}

#[test]
fn test_chunk_not_divisible() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..6).chunk(7, 0);
assert_eq!(tensors.len(), 6);

let expected = vec![
Data::from([0]),
Data::from([1]),
Data::from([2]),
Data::from([3]),
Data::from([4]),
Data::from([5]),
];

for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}

#[test]
fn test_chunk_multi_dimension() {
let tensors: Vec<Tensor<TestBackend, 2, Int>> =
Tensor::from_data(Data::from([[0, 1, 2, 3]])).chunk(2, 1);
assert_eq!(tensors.len(), 2);

let expected = vec![Data::from([[0, 1]]), Data::from([[2, 3]])];

for (index, tensor) in tensors.iter().enumerate() {
assert_eq!(tensor.to_data(), expected[index]);
}
}

#[test]
#[should_panic]
fn test_invalid_dim() {
let tensors: Vec<Tensor<TestBackend, 1, Int>> = Tensor::arange(0..12).chunk(6, 1);
}
}
1 change: 1 addition & 0 deletions burn-tensor/src/tests/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod arange_step;
mod arg;
mod cast;
mod cat;
mod chunk;
mod clamp;
mod cos;
mod create_like;
Expand Down

0 comments on commit 929b178

Please sign in to comment.