diff --git a/.gitignore b/.gitignore index 088ba6ba7d..22d351634a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,8 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + + +# Added by cargo + +/target diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000000..804dbdf117 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "burn" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/burn-tensor/.gitignore b/burn-tensor/.gitignore new file mode 100644 index 0000000000..22d351634a --- /dev/null +++ b/burn-tensor/.gitignore @@ -0,0 +1,15 @@ +# Generated by Cargo +# will have compiled files and executables +/target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + + +# Added by cargo + +/target diff --git a/burn-tensor/Cargo.toml b/burn-tensor/Cargo.toml new file mode 100644 index 0000000000..4f858540cd --- /dev/null +++ b/burn-tensor/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "burn-tensor" +version = "0.1.0" +edition = "2021" + +[features] +default = ["tch"] +full = ["arrayfire", "tch"] +arrayfire = ["dep:arrayfire"] +tch = ["dep:tch"] + +[dependencies] +derive-new = "0.5" +rand = "0.8" +num-traits = "0.2" +once_cell = "1.1" + +# Backends +arrayfire = { version = "3.8", optional = true } +tch = { version = "0.8", optional = true } + +# Autodiff +nanoid = "0.4" + +[dev-dependencies] +float-cmp = "0.9.0" diff --git a/burn-tensor/env.bash b/burn-tensor/env.bash new file mode 100644 index 0000000000..caff49fe37 --- /dev/null +++ b/burn-tensor/env.bash @@ -0,0 +1,7 @@ +echo "Setup arrayfire backend" +export AF_PATH=$HOME/.local/share/arrayfire +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$AF_PATH/lib64 + +echo "Setup tch backend" +export LIBTORCH=${HOME}/.local/lib/libtorch +export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH diff --git a/burn-tensor/src/graph/mod.rs b/burn-tensor/src/graph/mod.rs new file mode 100644 index 0000000000..81d8025cce --- /dev/null +++ b/burn-tensor/src/graph/mod.rs @@ -0,0 +1,5 @@ +pub mod node; +pub mod ops; +pub mod tape; + +mod node_old; diff --git a/burn-tensor/src/graph/node.rs b/burn-tensor/src/graph/node.rs new file mode 100644 index 0000000000..70d22dec16 --- /dev/null +++ b/burn-tensor/src/graph/node.rs @@ -0,0 +1,95 @@ +use std::{ + cell::RefCell, + ops::{Add, Mul}, + rc::Rc, +}; + +#[derive(PartialEq, Eq, Debug, Clone, Hash)] +pub struct NodeId { + value: String, +} + +impl NodeId { + pub fn new() -> Self { + Self { + value: nanoid::nanoid!(), + } + } +} + +pub trait Node: std::fmt::Debug { + fn id(&self) -> NodeId; + fn grad(&mut self) -> Out; + fn value(&self) -> Out; + fn update_grad(&mut self, grad: Out); +} +pub type NodeRef = Rc>>; + +pub trait Zeros { + fn zeros(&self) -> T; +} +pub trait Ones { + fn ones(&self) -> T; +} + +#[derive(Debug)] +pub struct RootNode { + pub id: NodeId, + pub value: Out, + pub grad: Option, +} + +impl RootNode { + pub fn new(value: Out) -> Self { + Self { + id: NodeId::new(), + value, + grad: None, + } + } +} + +impl Node for RootNode +where + Out: Zeros + Clone + Mul + Add, + Out: std::fmt::Debug, +{ + fn id(&self) -> NodeId { + self.id.clone() + } + fn grad(&mut self) -> Out { + let grad_self: Out = match &self.grad { + Some(val) => val.clone(), + 
None => self.value.zeros(), + }; + self.grad = Some(grad_self.clone()); + grad_self + } + + fn value(&self) -> Out { + self.value.clone() + } + + fn update_grad(&mut self, grad: Out) { + self.grad = Some(self.grad() + grad); + } +} + +#[macro_export] +macro_rules! node_init { + ( lhs $lhs:expr, rhs $rhs:expr, out $out:expr, ) => {{ + use $crate::graph::ops::BinaryOpsNode; + let node = BinaryOpsNode::new($lhs, $rhs, $out); + std::rc::Rc::new(std::cell::RefCell::new(node)) + }}; + ( input $input:expr, out $out:expr, ) => {{ + use $crate::graph::ops::SingleOpsNode; + let node = SingleOpsNode::new($input, $out); + std::rc::Rc::new(std::cell::RefCell::new(node)) + }}; + ( root $out:expr ) => {{ + use $crate::graph::node::RootNode; + let node = RootNode::new($out); + std::rc::Rc::new(std::cell::RefCell::new(node)) + }}; +} diff --git a/burn-tensor/src/graph/node_old.rs b/burn-tensor/src/graph/node_old.rs new file mode 100644 index 0000000000..237f6600cc --- /dev/null +++ b/burn-tensor/src/graph/node_old.rs @@ -0,0 +1,345 @@ +use crate::node::{Ones, Zeros}; +use std::ops::{Add, Div, Mul, Neg, Sub}; +use std::sync::RwLock; + +#[derive(new, Clone, Copy)] +pub struct Partial { + pub parent_position: usize, + pub partial: T, +} + +pub trait Node: Send + Sync { + fn parent_left(&self) -> Option>; + fn parent_right(&self) -> Option>; + fn position(&self) -> usize; +} + +#[derive(new)] +pub struct RootNode { + index: usize, +} + +impl Node for RootNode { + fn parent_left(&self) -> Option> { + None + } + fn parent_right(&self) -> Option> { + None + } + + fn position(&self) -> usize { + self.index + } +} + +#[derive(new)] +pub struct BinaryOperationNode { + left: Partial, + right: Partial, + index: usize, +} + +impl Node for BinaryOperationNode { + fn parent_left(&self) -> Option> { + Some(self.left.clone()) + } + fn parent_right(&self) -> Option> { + Some(self.right.clone()) + } + fn position(&self) -> usize { + self.index + } +} + +#[derive(new)] +pub struct UnaryOperationNode { + left: Partial, + index: usize, +} + +impl Node for UnaryOperationNode { + fn parent_left(&self) -> Option> { + Some(self.left.clone()) + } + + fn parent_right(&self) -> Option> { + None + } + + fn position(&self) -> usize { + self.index + } +} + +/// Tape holding the computation graph +pub struct Tape { + pub nodes: RwLock>>>, + pub grads: RwLock>, +} + +impl Tape +where + T: Zeros + Ones + Clone + Send + Sync + 'static, +{ + pub fn new() -> Tape { + Tape { + nodes: RwLock::new(Vec::new()), + grads: RwLock::new(Vec::new()), + } + } + + pub fn register(&self, value: T) -> Var { + let mut nodes = self.nodes.write().unwrap(); + let mut grads = self.grads.write().unwrap(); + + let position = nodes.len(); + + nodes.push(Box::new(RootNode::new(position))); + grads.push(value.zeros()); + + Var::new(self, position, value) + } + + pub fn register_unary(&self, partial: T, index: usize, value: T) -> Var { + let mut nodes = self.nodes.write().unwrap(); + let mut grads = self.grads.write().unwrap(); + + let len = nodes.len(); + + nodes.push(Box::new(UnaryOperationNode::new( + Partial::new(index, partial), + len, + ))); + grads.push(value.zeros()); + + Var::new(self, len, value) + } + + pub fn register_binary( + &self, + partial_left: T, + partial_right: T, + parent_left_index: usize, + parent_right_index: usize, + value: T, + ) -> Var { + let mut nodes = self.nodes.write().unwrap(); + let mut grads = self.grads.write().unwrap(); + + let len = nodes.len(); + + nodes.push(Box::new(BinaryOperationNode::new( + Partial::new(parent_left_index, 
partial_left), + Partial::new(parent_right_index, partial_right), + len, + ))); + grads.push(value.zeros()); + + Var::new(self, len, value) + } +} + +/// Variable for computations +#[derive(new, Clone, Copy)] +pub struct Var<'t, T> { + /// Pointer to the tape holding the corresponding node + pub tape: &'t Tape, + /// Index of the node in the tape + pub index: usize, + /// Value + pub v: T, +} + +impl Var<'_, T> +where + T: Mul + Clone + Zeros + Ones + Add + Copy, +{ + /// Perform back propagation + pub fn backprop(self) -> Grad { + // vector storing the gradients + let nodes = self.tape.nodes.read().unwrap(); + let tape_len = nodes.len(); + let mut grad: Vec = self.tape.grads.read().unwrap().to_owned(); + grad[self.index] = grad[self.index].ones(); + + // iterate through the tape from back to front + // because during forward pass, we always store new nodes at the end + // of the tape, when we do the backward pass we can + // just incrementally add partial * adjoint + for i in (0..tape_len).rev() { + let node = &nodes[i]; + // increment gradient contribution to the left parent + if let Some(parent) = node.parent_left() { + let grad_value = grad[parent.parent_position]; + grad[parent.parent_position] = grad_value + (parent.partial * grad[i].clone()); + } + + // increment gradient contribution to the right parent + if let Some(parent) = node.parent_right() { + let grad_value = grad[parent.parent_position]; + grad[parent.parent_position] = grad_value + (parent.partial * grad[i].clone()); + } + } + + // TODO: reset tape; + Grad { grad } + } +} + +impl<'t, T> Add for Var<'t, T> +where + T: Ones + Zeros + Add + Clone + Send + Sync + 'static, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + self.tape.register_binary( + self.v.ones(), + self.v.ones(), + self.index, + rhs.index, + self.v + rhs.v, + ) + } +} + +impl<'t, T> Sub for Var<'t, T> +where + T: Ones + Zeros + Neg + Sub + Clone + Send + Sync + 'static, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + self.tape.register_binary( + self.v.ones(), + -self.v.ones(), + self.index, + rhs.index, + self.v - rhs.v, + ) + } +} + +impl<'t, T> Neg for Var<'t, T> +where + T: Ones + Zeros + Neg + Sub + Clone + Send + Sync + 'static, +{ + type Output = Self; + + fn neg(self) -> Self::Output { + self.tape + .register_unary(-self.v.ones(), self.index, -self.v) + } +} + +impl<'t, T> Mul for Var<'t, T> +where + T: Ones + Zeros + Mul + Clone + Send + Sync + 'static, +{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + self.tape.register_binary( + rhs.v.clone(), + self.v.clone(), + self.index, + rhs.index, + self.v * rhs.v, + ) + } +} +impl<'t, T> Div for Var<'t, T> +where + T: Ones + + Zeros + + Neg + + Div + + Mul + + Clone + + Send + + Sync + + 'static, +{ + type Output = Self; + + fn div(self, rhs: Self) -> Self::Output { + self.tape.register_binary( + self.v.ones() / rhs.v.clone(), + -self.v.clone() / (rhs.v.clone() * rhs.v.clone()), + self.index, + rhs.index, + self.v.clone() / rhs.v.clone(), + ) + } +} + +impl<'t> Mul> for f64 { + type Output = Var<'t, f64>; + + fn mul(self, rhs: Var<'t, f64>) -> Self::Output { + rhs.tape.register_unary(self, rhs.index, self * rhs.v) + } +} + +/// Struct holding gradients +#[derive(Debug)] +pub struct Grad { + pub grad: Vec, +} + +impl Grad +where + T: Clone, +{ + /// Get the gradient with respect to a variable + pub fn wrt(&self, var: Var) -> T { + self.grad[var.index].clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use float_cmp::*; + + 
#[test] + fn addition_test() { + let tape = Tape::new(); + let x = tape.register(1.0); + let y = tape.register(4.0); + let z = x + y; + let grad = z.backprop(); + assert!(approx_eq!(f64, grad.wrt(x), 1.0, ulps = 5)); + assert!(approx_eq!(f64, grad.wrt(y), 1.0, ulps = 5)); + } + + #[test] + fn mul_test() { + let tape = Tape::new(); + let x = tape.register(1.0); + let y = tape.register(4.0); + let z = x * y; + let grad = z.backprop(); + assert!(approx_eq!(f64, grad.wrt(x), y.v, ulps = 5)); + assert!(approx_eq!(f64, grad.wrt(y), x.v, ulps = 5)); + } + + #[test] + fn neg_test() { + let tape = Tape::new(); + let x = tape.register(1.0); + let z = -x; + let grad = z.backprop(); + assert!(approx_eq!(f64, grad.wrt(x), -1.0, ulps = 5)); + } + + #[test] + fn multiple_multiplications_test() { + let tape = Tape::new(); + let x = tape.register(1.0); + let y = tape.register(1.0); + let z = -2.0 * x + x * x * x * y; + let grad = z.backprop(); + assert!(approx_eq!(f64, grad.wrt(x), 1.0, ulps = 5)); + } +} diff --git a/burn-tensor/src/graph/ops/binary.rs b/burn-tensor/src/graph/ops/binary.rs new file mode 100644 index 0000000000..3ec06fd3e2 --- /dev/null +++ b/burn-tensor/src/graph/ops/binary.rs @@ -0,0 +1,102 @@ +use super::{BinaryRecordedState, RecordedOps}; +use crate::node::{Node, NodeId, NodeRef, Ones, Zeros}; +use std::ops::{Add, Mul}; + +pub trait BinaryOps: std::fmt::Debug { + fn forward(&self, left: Lhs, right: Rhs) -> Out; + fn partial_left(&self, state: &BinaryRecordedState) -> Lhs; + fn partial_right(&self, state: &BinaryRecordedState) -> Rhs; +} + +#[derive(Debug)] +pub struct BinaryOpsNode { + pub id: NodeId, + pub parent_left: NodeRef, + pub parent_right: NodeRef, + pub value: Out, + pub grad: Option, +} + +#[derive(new, Debug)] +pub struct BinaryRecordedOps { + lhs: NodeRef, + rhs: NodeRef, + out: NodeRef, + ops: Ops, +} + +impl BinaryOpsNode { + pub fn new(parent_left: NodeRef, parent_right: NodeRef, value: Out) -> Self { + Self { + id: NodeId::new(), + parent_left, + parent_right, + value, + grad: None, + } + } +} + +impl Node for BinaryOpsNode +where + Out: Zeros + Clone + Mul + Add, + Lhs: std::fmt::Debug, + Rhs: std::fmt::Debug, + Out: std::fmt::Debug, +{ + fn id(&self) -> NodeId { + self.id.clone() + } + fn value(&self) -> Out { + self.value.clone() + } + + fn grad(&mut self) -> Out { + let grad_self = match &self.grad { + Some(val) => val.clone(), + None => self.value.zeros(), + }; + self.grad = Some(grad_self.clone()); + grad_self + } + + fn update_grad(&mut self, grad: Out) { + self.grad = Some(self.grad() + grad); + } +} + +impl RecordedOps for BinaryRecordedOps +where + Lhs: Clone + Zeros + Mul, + Rhs: Clone + Zeros + Mul, + Out: Clone + Zeros + Ones + 'static, + Lhs: std::fmt::Debug, + Rhs: std::fmt::Debug, + Out: std::fmt::Debug, + Ops: BinaryOps, +{ + fn id(&self) -> NodeId { + self.out.borrow().id() + } + + fn backward(&mut self) { + let left = self.lhs.borrow().value(); + let right = self.rhs.borrow().value(); + let output = self.out.borrow().value(); + let state = BinaryRecordedState::new(&left, &right, &output); + + let partial_left = self.ops.partial_left(&state); + let partial_right: Rhs = self.ops.partial_right(&state); + let grad_mine = self.out.borrow_mut().grad(); + + self.lhs + .borrow_mut() + .update_grad(partial_left * grad_mine.clone()); + self.rhs.borrow_mut().update_grad(partial_right * grad_mine); + } + + fn set_last_ops(&mut self) { + let value = self.out.borrow().value(); + self.out.borrow_mut().update_grad(value.ones()); + } +} diff --git 
a/burn-tensor/src/graph/ops/mod.rs b/burn-tensor/src/graph/ops/mod.rs new file mode 100644 index 0000000000..80f0a1560c --- /dev/null +++ b/burn-tensor/src/graph/ops/mod.rs @@ -0,0 +1,9 @@ +mod binary; +mod ops; +mod root; +mod single; + +pub use binary::*; +pub use ops::*; +pub use root::*; +pub use single::*; diff --git a/burn-tensor/src/graph/ops/ops.rs b/burn-tensor/src/graph/ops/ops.rs new file mode 100644 index 0000000000..7b1a5c0477 --- /dev/null +++ b/burn-tensor/src/graph/ops/ops.rs @@ -0,0 +1,21 @@ +use crate::node::NodeId; + +#[derive(new)] +pub struct BinaryRecordedState<'a, Lhs, Rhs, Out> { + pub left: &'a Lhs, + pub right: &'a Rhs, + pub output: &'a Out, +} + +#[derive(new)] +pub struct SingleRecordedState<'a, In, Out> { + pub input: &'a In, + pub output: &'a Out, +} + +pub trait RecordedOps: std::fmt::Debug { + fn id(&self) -> NodeId; + fn backward(&mut self); + fn set_last_ops(&mut self); +} +pub type RecordedOpsRef = Box; diff --git a/burn-tensor/src/graph/ops/root.rs b/burn-tensor/src/graph/ops/root.rs new file mode 100644 index 0000000000..c874ae59f4 --- /dev/null +++ b/burn-tensor/src/graph/ops/root.rs @@ -0,0 +1,23 @@ +use super::RecordedOps; +use crate::node::{NodeId, NodeRef, Ones, Zeros}; + +#[derive(new, Debug)] +pub struct InitRecordedOps { + root: NodeRef, +} + +impl RecordedOps for InitRecordedOps +where + Out: Clone + Zeros + Ones + 'static, + Out: std::fmt::Debug, +{ + fn id(&self) -> NodeId { + self.root.borrow().id() + } + + fn backward(&mut self) {} + fn set_last_ops(&mut self) { + let value = self.root.borrow().value(); + self.root.borrow_mut().update_grad(value.ones()); + } +} diff --git a/burn-tensor/src/graph/ops/single.rs b/burn-tensor/src/graph/ops/single.rs new file mode 100644 index 0000000000..1855b6e74f --- /dev/null +++ b/burn-tensor/src/graph/ops/single.rs @@ -0,0 +1,92 @@ +use super::{RecordedOps, SingleRecordedState}; +use crate::node::{Node, NodeId, NodeRef, Ones, Zeros}; +use std::ops::{Add, Mul}; + +pub trait SingleOps: std::fmt::Debug { + fn forward(&self, input: In) -> Out; + fn partial(&self, state: &SingleRecordedState) -> In; +} + +#[derive(Debug)] +pub struct SingleOpsNode { + pub id: NodeId, + pub parent: NodeRef, + pub value: Out, + pub grad: Option, +} + +#[derive(new, Debug)] +pub struct SingleRecordedOps { + input: NodeRef, + out: NodeRef, + ops: Ops, +} + +impl SingleOpsNode { + pub fn new(parent: NodeRef, value: Out) -> Self { + Self { + id: NodeId::new(), + parent, + value, + grad: None, + } + } +} + +impl Node for SingleOpsNode +where + Out: Zeros + Clone + Mul + Add, + In: std::fmt::Debug, + Out: std::fmt::Debug, +{ + fn id(&self) -> NodeId { + self.id.clone() + } + fn value(&self) -> Out { + self.value.clone() + } + + fn grad(&mut self) -> Out { + let grad_self = match &self.grad { + Some(val) => val.clone(), + None => self.value.zeros(), + }; + self.grad = Some(grad_self.clone()); + grad_self + } + + fn update_grad(&mut self, grad: Out) { + self.grad = Some(self.grad() + grad); + } +} + +impl RecordedOps for SingleRecordedOps +where + In: Clone + Zeros + Mul, + Out: Clone + Zeros + Ones + 'static, + In: std::fmt::Debug, + Out: std::fmt::Debug, + Ops: SingleOps, +{ + fn id(&self) -> NodeId { + self.input.borrow().id() + } + + fn backward(&mut self) { + let input = self.input.borrow().value(); + let output = self.out.borrow().value(); + let state = SingleRecordedState::new(&input, &output); + + let partial = self.ops.partial(&state); + let grad_mine = self.out.borrow_mut().grad(); + + self.input + .borrow_mut() + 
.update_grad(partial * grad_mine.clone()); + } + + fn set_last_ops(&mut self) { + let value = self.out.borrow().value(); + self.out.borrow_mut().update_grad(value.ones()); + } +} diff --git a/burn-tensor/src/graph/tape.rs b/burn-tensor/src/graph/tape.rs new file mode 100644 index 0000000000..c544f80452 --- /dev/null +++ b/burn-tensor/src/graph/tape.rs @@ -0,0 +1,37 @@ +use crate::{node::NodeId, ops::RecordedOpsRef}; +use std::{cell::RefCell, rc::Rc}; + +#[derive(Debug)] +pub struct Tape { + pub operations: Vec, +} +pub type TapeRef = Rc>; + +impl Tape { + pub fn new() -> Self { + Self { + operations: Vec::new(), + } + } + + pub fn backward(&mut self, from: NodeId) { + let mut init = false; + + for ops in self.operations.iter_mut().rev() { + if init { + ops.backward(); + } else if ops.id() == from { + ops.set_last_ops(); + init = true; + ops.backward(); + } + } + } + + pub fn add(&mut self, ops: RecordedOpsRef) { + println!("---"); + println!("Adding ops {:?}", ops); + println!("---"); + self.operations.push(ops) + } +} diff --git a/burn-tensor/src/lib.rs b/burn-tensor/src/lib.rs new file mode 100644 index 0000000000..54921261d2 --- /dev/null +++ b/burn-tensor/src/lib.rs @@ -0,0 +1,8 @@ +#[macro_use] +extern crate derive_new; + +mod graph; +mod tensor; + +pub use graph::*; +pub use tensor::*; diff --git a/burn-tensor/src/tensor/backend/arrayfire/device.rs b/burn-tensor/src/tensor/backend/arrayfire/device.rs new file mode 100644 index 0000000000..729be9de56 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/device.rs @@ -0,0 +1,12 @@ +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum GPUBackend { + CUDA, + OPENCL, +} +pub type DeviceNumber = usize; + +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum Device { + CPU, + GPU(DeviceNumber, GPUBackend), +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/mod.rs b/burn-tensor/src/tensor/backend/arrayfire/mod.rs new file mode 100644 index 0000000000..dab25396ff --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/mod.rs @@ -0,0 +1,35 @@ +pub mod device; + +mod ops; +mod shape; +mod tensor; + +use self::device::Device; +pub use tensor::*; + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::backend::arrayfire::device::GPUBackend; + use crate::{Data, Shape}; + use std::thread; + + #[test] + fn should_support_multiple_devices_on_different_thread() { + let function = |device| { + for _ in 0..10 { + let data_1 = Data::random(Shape::new([1000])); + let data_2 = Data::random(Shape::new([1000])); + let tensor_1 = ArrayfireTensor::::from_data(data_1, device); + let tensor_2 = ArrayfireTensor::::from_data(data_2, device); + let _data = tensor_1 + tensor_2; + } + }; + + let handler_1 = thread::spawn(move || function(Device::CPU)); + let handler_2 = thread::spawn(move || function(Device::GPU(0, GPUBackend::OPENCL))); + + handler_1.join().unwrap(); + handler_2.join().unwrap(); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/add.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/add.rs new file mode 100644 index 0000000000..ab7461fba2 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/add.rs @@ -0,0 +1,77 @@ +use crate::{backend::arrayfire::ArrayfireTensor, TensorOpsAdd}; +use arrayfire::{ConstGenerator, HasAfEnum}; +use std::ops::Add; + +impl TensorOpsAdd for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + fn add(&self, other: &Self) -> Self { + self.set_backend_binary_ops(other); + + let array = (&self.array).add(&other.array); + let shape = self.shape.clone(); + let device = 
self.device; + + Self { + array, + shape, + device, + } + } + fn add_scalar(&self, other: &P) -> Self { + self.set_backend_single_ops(); + + let array = arrayfire::add(&self.array, other, false); + let shape = self.shape.clone(); + let device = self.device; + + Self { + array, + shape, + device, + } + } +} + +impl std::ops::Add for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + TensorOpsAdd::add(&self, &rhs) + } +} + +impl std::ops::Add
<P>
for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + type Output = Self; + + fn add(self, rhs: P) -> Self::Output { + TensorOpsAdd::add_scalar(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use crate::{backend::arrayfire::Device, Data, TensorBase}; + + use super::*; + + #[test] + fn should_support_add_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); + let tensor_1 = ArrayfireTensor::::from_data(data_1, Device::CPU); + let tensor_2 = ArrayfireTensor::::from_data(data_2, Device::CPU); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/index.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/index.rs new file mode 100644 index 0000000000..29475096c0 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/index.rs @@ -0,0 +1,87 @@ +use crate::{backend::arrayfire::ArrayfireTensor, TensorOpsIndex}; +use arrayfire::{HasAfEnum, Seq}; +use std::ops::Range; + +impl + TensorOpsIndex for ArrayfireTensor +{ + fn index(&self, index: [Range; D2]) -> Self { + self.set_backend_single_ops(); + let shape = self.shape.index(index.clone()); + + let mut seqs = Vec::new(); + for i in 0..D2 { + let range = index[i].clone(); + let start = range.start; + let end = range.end - 1; + seqs.push(Seq::new(start as f64, end as f64, 1.0)); + } + + for i in D2..D1 { + let dim = self.shape.dims[i]; + let start = 0; + let end = dim - 1; + seqs.push(Seq::new(start as f64, end as f64, 1.0)); + } + + let array = arrayfire::index(&self.array, &seqs); + let device = self.device; + + Self { + array, + shape, + device, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::arrayfire::device::Device; + use crate::{Data, TensorBase}; + + #[test] + fn should_support_full_indexing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = ArrayfireTensor::::from_data(data.clone(), Device::CPU); + + let data_actual = tensor.index([0..3]).into_data(); + + assert_eq!(data, data_actual); + } + + #[test] + fn should_support_partial_indexing_1d() { + let data = Data::from([0.0, 1.0, 2.0]); + let tensor = ArrayfireTensor::::from_data(data, Device::CPU); + + let data_actual = tensor.index([1..3]).into_data(); + + let data_expected = Data::from([1.0, 2.0]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_full_indexing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = ArrayfireTensor::::from_data(data.clone(), Device::CPU); + + let data_actual_1 = tensor.index([0..2]).into_data(); + let data_actual_2 = tensor.index([0..2, 0..3]).into_data(); + + assert_eq!(data, data_actual_1); + assert_eq!(data, data_actual_2); + } + + #[test] + fn should_support_partial_indexing_2d() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = ArrayfireTensor::::from_data(data, Device::CPU); + + let data_actual = tensor.index([0..2, 0..2]).into_data(); + + let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/matmul.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/matmul.rs new file mode 100644 index 0000000000..729e7c13bd --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/matmul.rs @@ -0,0 +1,61 @@ +use 
crate::{backend::arrayfire::ArrayfireTensor, Shape, TensorOpsMatmul}; +use arrayfire::{FloatingPoint, HasAfEnum}; + +impl TensorOpsMatmul for ArrayfireTensor { + fn matmul(&self, other: &Self) -> Self { + self.set_backend_binary_ops(other); + + let array = arrayfire::matmul( + &self.array, + &other.array, + arrayfire::MatProp::NONE, + arrayfire::MatProp::NONE, + ); + let device = self.device; + let shape = Shape::from(array.dims()); + + Self { + array, + shape, + device, + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use crate::{backend::arrayfire::Device, Data, TensorBase}; + + #[test] + fn should_support_matmul_2_dims() { + let data_1 = Data::from([[4.0, 3.0], [8.0, 7.0]]); + let data_2 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = ArrayfireTensor::::from_data(data_1, Device::CPU); + let tensor_2 = ArrayfireTensor::::from_data(data_2, Device::CPU); + + let data_actual = tensor_1.matmul(&tensor_2).into_data(); + + let data_expected = Data::from([[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + #[ignore = "batch operation not supported yet..."] + fn should_support_matmul_3_dims() { + let data_1 = Data::from([[[4.0, 3.0], [8.0, 7.0]], [[4.0, 3.0], [8.0, 7.0]]]); + let data_2 = Data::from([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + ]); + let tensor_1 = ArrayfireTensor::::from_data(data_1, Device::CPU); + let tensor_2 = ArrayfireTensor::::from_data(data_2, Device::CPU); + + let data_actual = tensor_1.matmul(&tensor_2).into_data(); + + let data_expected = Data::from([ + [[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]], + [[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]], + ]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/mod.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/mod.rs new file mode 100644 index 0000000000..9d4e98467b --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/mod.rs @@ -0,0 +1,6 @@ +mod add; +mod index; +mod matmul; +mod mul; +mod neg; +mod reshape; diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/mul.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/mul.rs new file mode 100644 index 0000000000..0d234606b7 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/mul.rs @@ -0,0 +1,85 @@ +use crate::{backend::arrayfire::ArrayfireTensor, TensorOpsMul}; +use arrayfire::{ConstGenerator, HasAfEnum}; + +impl TensorOpsMul for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + fn mul(&self, other: &Self) -> Self { + self.set_backend_binary_ops(other); + let array = arrayfire::mul(&self.array, &other.array, false); + let shape = self.shape.clone(); + let device = self.device; + + Self { + array, + shape, + device, + } + } + fn mul_scalar(&self, other: &P) -> Self { + self.set_backend_single_ops(); + let array = arrayfire::mul(&self.array, other, false); + let shape = self.shape.clone(); + let device = self.device; + + Self { + array, + shape, + device, + } + } +} + +impl std::ops::Mul
<P>
for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + type Output = ArrayfireTensor; + + fn mul(self, rhs: P) -> Self::Output { + TensorOpsMul::mul_scalar(&self, &rhs) + } +} + +impl std::ops::Mul> for ArrayfireTensor +where + P: ConstGenerator + Clone + Copy, +{ + type Output = ArrayfireTensor; + + fn mul(self, rhs: Self) -> Self::Output { + TensorOpsMul::mul(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{backend::arrayfire::Device, Data, TensorBase}; + + #[test] + fn should_support_mul_ops() { + let data_1 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = ArrayfireTensor::::from_data(data_1, Device::CPU); + let tensor_2 = ArrayfireTensor::::from_data(data_2, Device::CPU); + + let data_actual = tensor_1.mul(&tensor_2).into_data(); + + let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = ArrayfireTensor::::from_data(data, Device::CPU); + + let data_actual = tensor.mul_scalar(&scalar).into_data(); + + let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/neg.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/neg.rs new file mode 100644 index 0000000000..e365cbb846 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/neg.rs @@ -0,0 +1,44 @@ +use arrayfire::{ConstGenerator, HasAfEnum}; +use num_traits::One; +use std::ops::Neg; + +use crate::{backend::arrayfire::ArrayfireTensor, TensorOpsMul, TensorOpsNeg}; + +impl TensorOpsNeg for ArrayfireTensor +where + P: ConstGenerator + Neg + One + Neg + Clone + Copy, +{ + fn neg(&self) -> Self { + self.set_backend_single_ops(); + let minus_one = Neg::neg(P::one()); + self.mul_scalar(&minus_one) + } +} + +impl std::ops::Neg for ArrayfireTensor +where + P: ConstGenerator + Neg + One + Neg + Clone + Copy, +{ + type Output = ArrayfireTensor; + + fn neg(self) -> Self::Output { + TensorOpsNeg::neg(&self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{backend::arrayfire::device::Device, Data, TensorBase}; + + #[test] + fn should_support_neg_ops() { + let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = ArrayfireTensor::::from_data(data, Device::CPU); + + let data_actual = tensor.neg().into_data(); + + let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/ops/reshape.rs b/burn-tensor/src/tensor/backend/arrayfire/ops/reshape.rs new file mode 100644 index 0000000000..464fd22d1d --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/ops/reshape.rs @@ -0,0 +1,19 @@ +use crate::{backend::arrayfire::ArrayfireTensor, Shape, TensorOpsReshape}; +use arrayfire::HasAfEnum; + +impl + TensorOpsReshape> for ArrayfireTensor +{ + fn reshape(&self, shape: Shape) -> ArrayfireTensor { + self.set_backend_single_ops(); + + let array = arrayfire::moddims(&self.array, shape.clone().into()); + let device = self.device; + + ArrayfireTensor { + array, + shape, + device, + } + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/shape.rs b/burn-tensor/src/tensor/backend/arrayfire/shape.rs new file mode 100644 index 0000000000..2345cab699 --- /dev/null +++ 
b/burn-tensor/src/tensor/backend/arrayfire/shape.rs @@ -0,0 +1,28 @@ +use crate::Shape; +use arrayfire::Dim4; + +impl Into for Shape { + fn into(self) -> Dim4 { + if D > 4 { + panic!( + "Can't create arrayfire Tensor with more than 4 dimensions, got {}", + D + ); + } + let mut dims = [1; 4]; + for i in 0..D { + dims[i] = self.dims[i] as u64; + } + Dim4::new(&dims) + } +} + +impl From for Shape { + fn from(dim: Dim4) -> Self { + let mut values = [0; D]; + for i in 0..D { + values[i] = dim[i] as usize; + } + Shape::new(values) + } +} diff --git a/burn-tensor/src/tensor/backend/arrayfire/tensor.rs b/burn-tensor/src/tensor/backend/arrayfire/tensor.rs new file mode 100644 index 0000000000..26166a0815 --- /dev/null +++ b/burn-tensor/src/tensor/backend/arrayfire/tensor.rs @@ -0,0 +1,134 @@ +use super::{device::GPUBackend, Device}; +use crate::{ + backend::conversion::{Convertor, Order}, + Data, FloatTensor, Shape, TensorBase, +}; +use arrayfire::{Array, ConstGenerator, Dim4, FloatingPoint, HasAfEnum}; +use num_traits::Float; + +pub struct ArrayfireTensor { + pub device: Device, + pub array: Array
<P>
, + pub shape: Shape, +} + +impl ArrayfireTensor { + pub(crate) fn set_backend_binary_ops(&self, other: &Self) { + if self.device != other.device { + panic!("Not on same device"); + } + set_backend(&self.device); + } + + pub(crate) fn set_backend_single_ops(&self) { + set_backend(&self.device); + } +} + +pub(crate) fn set_backend(device: &Device) { + match device { + Device::CPU => arrayfire::set_backend(arrayfire::Backend::CPU), + &Device::GPU(device_num, backend) => { + match backend { + GPUBackend::CUDA => arrayfire::set_backend(arrayfire::Backend::CUDA), + GPUBackend::OPENCL => arrayfire::set_backend(arrayfire::Backend::OPENCL), + }; + arrayfire::set_device(device_num as i32); + } + } +} +impl FloatTensor + for ArrayfireTensor +where + P: ConstGenerator + FloatingPoint + Clone + Copy, +{ +} + +impl TensorBase + for ArrayfireTensor +{ + fn empty(shape: Shape) -> Self { + let device = Device::CPU; + set_backend(&device); + + let mut dims_arrayfire = [1; 4]; + + for i in 0..D { + dims_arrayfire[i] = shape.dims[i] as u64; + } + + let array = Array::new_empty(Dim4::new(&dims_arrayfire)); + + Self { + array, + shape, + device, + } + } + + fn from>(other: O) -> Self { + Self::from_data(other.into_data(), Device::CPU) + } + + fn shape(&self) -> &Shape { + &self.shape + } + fn into_data(self) -> Data { + let mut data = vec![P::default(); self.array.elements()]; + self.array.host(&mut data); + let convertor = Convertor::new(&self.shape, Order::Right, Order::Left); + let values = convertor.convert(&data); + Data::new(values, self.shape) + } +} + +impl ArrayfireTensor { + pub fn from_data(data: Data, device: Device) -> Self { + set_backend(&device); + + let shape = data.shape.clone(); + let dims: Dim4 = data.shape.into(); + + let convertor = Convertor::new(&shape, Order::Left, Order::Right); + let values = convertor.convert(&data.value); + let array = Array::new(&values, dims); + + Self { + array, + shape, + device, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic] + fn should_not_create_tensor_with_more_than_4_dims() { + let data_expected = Data::random(Shape::new([2, 3, 1, 4, 5])); + let _tensor = ArrayfireTensor::::from_data(data_expected.clone(), Device::CPU); + } + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::random(Shape::new([3])); + let tensor = ArrayfireTensor::::from_data(data_expected.clone(), Device::CPU); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + let data_expected = Data::random(Shape::new([2, 3])); + let tensor = ArrayfireTensor::::from_data(data_expected.clone(), Device::CPU); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/items.rs b/burn-tensor/src/tensor/backend/autodiff/items.rs new file mode 100644 index 0000000000..d3ba715c1e --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/items.rs @@ -0,0 +1,68 @@ +use crate::{ + node::{Ones, Zeros}, + FloatTensor, +}; + +pub trait ADFloat: + num_traits::Float + Zeros + Ones + std::fmt::Debug + Default + 'static +{ +} +pub trait ADFloatTensor: + FloatTensor + Clone + Zeros + Ones + std::fmt::Debug + 'static +{ +} + +macro_rules! 
ad_items { + ( + float $float:ident + ) => { + impl ADFloat for $float {} + impl Zeros<$float> for $float { + fn zeros(&self) -> $float { + 0.0 + } + } + + impl Ones<$float> for $float { + fn ones(&self) -> $float { + 1.0 + } + } + }; +} + +ad_items!(float f64); +ad_items!(float f32); + +#[cfg(feature = "tch")] +mod tch { + use super::{ADFloat, ADFloatTensor}; + use crate::{ + backend::tch::TchTensor, + node::{Ones, Zeros}, + TensorOpsAdd, TensorOpsMul, + }; + use num_traits::Float; + + impl + ADFloat, const D: usize> ADFloatTensor + for TchTensor + { + } + + impl, const D: usize> Zeros> for TchTensor { + fn zeros(&self) -> TchTensor { + TensorOpsMul::mul_scalar(&self, &P::ZERO) + } + } + + impl Ones> for TchTensor + where + P: tch::kind::Element + Into + Float + Default, + P: std::fmt::Debug, + { + fn ones(&self) -> TchTensor { + let zeros = self.zeros(); + TensorOpsAdd::add_scalar(&zeros, &P::one()) + } + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/mod.rs b/burn-tensor/src/tensor/backend/autodiff/mod.rs new file mode 100644 index 0000000000..b3b0ff72ef --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/mod.rs @@ -0,0 +1,7 @@ +mod items; +mod ops; +mod tensor; + +pub use items::*; +pub use ops::*; +pub use tensor::*; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/add.rs b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs new file mode 100644 index 0000000000..2bd97f2962 --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/add.rs @@ -0,0 +1,76 @@ +use crate::{ + backend::autodiff::{ADFloat, ADFloatTensor, ADTensor}, + define_ops, execute_ops, + ops::{ + BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps, + SingleRecordedState, + }, + register_ops, TensorOpsAdd, +}; +use num_traits::Float; + +register_ops!( + ops BinaryOps, + name ADTensorAddOps, + forward |left, right| left * right, + partial_left |state: &BinaryRecordedState| state.right.clone(), + partial_right |state: &BinaryRecordedState| state.left.clone(), +); + +register_ops!( + ops SingleOps, + name ADTensorAddScalarOps state P, + forward |state, input| input * state, + partial |state, state_recorded: &SingleRecordedState| state_recorded.input.ones() * state, +); + +impl TensorOpsAdd for ADTensor +where + T: ADFloatTensor, + P: ADFloat, +{ + fn add(&self, other: &Self) -> Self { + let node = execute_ops!( + lhs self.node.clone(), + rhs other.node.clone(), + out TensorOpsAdd::add(&self.tensor(), &other.tensor()), + tape self.tape.clone(), + ops ADTensorAddOps::new(), + ); + self.from_existing(node) + } + + fn add_scalar(&self, other: &P) -> Self { + let node = execute_ops!( + input self.node.clone(), + out TensorOpsAdd::add_scalar(&self.tensor(), &other), + tape self.tape.clone(), + ops ADTensorAddScalarOps::new(other.clone()), + ); + self.from_existing(node) + } +} + +impl std::ops::Add
<P>
for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn add(self, rhs: P) -> Self::Output { + TensorOpsAdd::add_scalar(&self, &rhs) + } +} + +impl std::ops::Add> for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn add(self, rhs: Self) -> Self::Output { + TensorOpsAdd::add(&self, &rhs) + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs new file mode 100644 index 0000000000..7cddd365dd --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/macros.rs @@ -0,0 +1,175 @@ +#[macro_export] +macro_rules! define_ops { + ( + name $name:ident + ) => { + #[derive(Debug)] + struct $name { + _kind: $crate::tensor::backend::autodiff::ADKind
<P>
, + } + + impl $name { + pub fn new() -> Self { + Self { + _kind: $crate::tensor::backend::autodiff::ADKind::new(), + } + } + } + }; + ( + name $name:ident, + state $state_ident:ident, + ) => { + #[derive(Debug)] + struct $name { + pub state: $state_ident, + _kind: $crate::tensor::backend::autodiff::ADKind
<P>
, + } + + impl $name { + pub fn new(value: $state_ident) -> Self { + Self { + state: value, + _kind: $crate::tensor::backend::autodiff::ADKind::new(), + } + } + } + }; +} + +#[macro_export] +macro_rules! register_ops { + ( + ops $ops:ty, + name $name:ident, + forward $forward:expr, + partial_left $partial_left:expr, + partial_right $partial_right:expr, + ) => { + $crate::define_ops!( + name $name + ); + + impl $ops for $name + where + P: $crate::tensor::backend::autodiff::ADFloat, + T: $crate::tensor::backend::autodiff::ADFloatTensor, + { + fn forward(&self, left: T, right: T) -> T { + $forward(left, right) + } + + fn partial_left(&self, state: &$crate::graph::ops::BinaryRecordedState) -> T { + $partial_left(state) + } + + fn partial_right(&self, state: &$crate::graph::ops::BinaryRecordedState) -> T { + $partial_right(state) + } + } + }; + ( + ops $ops:ty, + name $name:ident state $ops_tensor_state:ident, + forward $forward:expr, + partial $partial:expr, + ) => { + define_ops!( + name $name, + state $ops_tensor_state, + ); + + impl $ops for $name + where + P: $crate::tensor::backend::autodiff::ADFloat, + T: $crate::tensor::backend::autodiff::ADFloatTensor, + { + fn forward(&self, input: T) -> T { + $forward(self.state, input) + } + + fn partial(&self, state: &$crate::graph::ops::SingleRecordedState) -> T { + $partial(self.state, state) + } + } + }; + ( + ops $ops:ty, + name $name:ident, + forward $forward:expr, + partial $partial:expr, + ) => { + define_ops!( + name $name, + ); + + impl $ops for $name + where + P: $crate::tensor::backend::autodiff::ADFloat, + T: $crate::tensor::backend::autodiff::ADFloatTensor, + { + fn forward(&self, input: T) -> T { + $forward(input) + } + + fn partial(&self, state: &$crate::graph::ops::SingleRecordedState) -> T { + $partial(state) + } + } + } + +} + +#[macro_export] +macro_rules! 
execute_ops { + ( + lhs $lhs:expr, + rhs $rhs:expr, + out $out:expr, + tape $tape:expr, + ops $ops:expr, + ) => + { + { + let callback = || { + let node = $crate::node_init!( + lhs $lhs, + rhs $rhs, + out $out, + ); + + let ops = $ops; + let ops = BinaryRecordedOps::new($lhs, $rhs, node.clone(), ops); + let ops = Box::new(ops); + + $tape.borrow_mut().add(ops); + node + }; + callback() + } + }; + ( + input $input:expr, + out $out:expr, + tape $tape:expr, + ops $ops:expr, + ) => + { + { + let callback = || { + let node = $crate::node_init!( + input $input, + out $out, + ); + + let ops = $ops; + let ops = SingleRecordedOps::new($input, node.clone(), ops); + let ops = Box::new(ops); + + $tape.borrow_mut().add(ops); + node + }; + callback() + } + }; +} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs new file mode 100644 index 0000000000..1134220ebd --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -0,0 +1,5 @@ +mod add; +mod mul; + +mod macros; +pub use macros::*; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs new file mode 100644 index 0000000000..df87d6884e --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mul.rs @@ -0,0 +1,112 @@ +use crate::{ + backend::autodiff::{ADFloat, ADFloatTensor, ADTensor}, + define_ops, execute_ops, + ops::{ + BinaryOps, BinaryRecordedOps, BinaryRecordedState, SingleOps, SingleRecordedOps, + SingleRecordedState, + }, + register_ops, TensorOpsMul, +}; +use num_traits::Float; + +register_ops!( + ops BinaryOps, + name ADTensorMulOps, + forward |left, right| left * right, + partial_left |state: &BinaryRecordedState| state.right.clone(), + partial_right |state: &BinaryRecordedState| state.left.clone(), +); + +register_ops!( + ops SingleOps, + name ADTensorMulScalarOps state P, + forward |state, input| input * state, + partial |state, state_recorded: &SingleRecordedState| state_recorded.input.ones() * state, +); + +impl TensorOpsMul for ADTensor +where + T: ADFloatTensor, + P: ADFloat, +{ + fn mul(&self, other: &Self) -> Self { + let node = execute_ops!( + lhs self.node.clone(), + rhs other.node.clone(), + out TensorOpsMul::mul(&self.tensor(), &other.tensor()), + tape self.tape.clone(), + ops ADTensorMulOps::new(), + ); + self.from_existing(node) + } + + fn mul_scalar(&self, other: &P) -> Self { + let node = execute_ops!( + input self.node.clone(), + out TensorOpsMul::mul_scalar(&self.tensor(), &other), + tape self.tape.clone(), + ops ADTensorMulScalarOps::new(other.clone()), + ); + self.from_existing(node) + } +} + +impl std::ops::Mul
<P>
for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn mul(self, rhs: P) -> Self::Output { + TensorOpsMul::mul_scalar(&self, &rhs) + } +} + +impl std::ops::Mul> for ADTensor +where + T: ADFloatTensor + 'static, + P: ADFloat + 'static, +{ + type Output = ADTensor; + + fn mul(self, rhs: Self) -> Self::Output { + TensorOpsMul::mul(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + backend::tch::TchTensor, + node_init, + tape::{Tape, TapeRef}, + Data, TensorBase, + }; + use std::cell::RefCell; + + #[test] + fn should_diff_mul() { + let tape = TapeRef::new(RefCell::new(Tape::new())); + let data_1 = Data::from([1.0]); + let data_2 = Data::from([4.0]); + + let tensor_1 = TchTensor::from_data(data_1.clone(), tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2.clone(), tch::Device::Cpu); + + let tensor_ad_1 = ADTensor::new(node_init!(root tensor_1), tape.clone()); + let tensor_ad_2 = ADTensor::new(node_init!(root tensor_2), tape.clone()); + + let tensor_ad_3 = tensor_ad_1.mul(&tensor_ad_2); + let data_ad_3 = tensor_ad_3.tensor().into_data(); + assert_eq!(data_ad_3, Data::from([4.0])); + + tensor_ad_3.backprob(); + let grad_1 = tensor_ad_1.grad(); + let grad_2 = tensor_ad_2.grad(); + + assert_eq!(grad_1.into_data(), data_2); + assert_eq!(grad_2.into_data(), data_1); + } +} diff --git a/burn-tensor/src/tensor/backend/autodiff/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/tensor.rs new file mode 100644 index 0000000000..01794388ec --- /dev/null +++ b/burn-tensor/src/tensor/backend/autodiff/tensor.rs @@ -0,0 +1,81 @@ +use crate::{ + node::{NodeRef, Ones, Zeros}, + ops::InitRecordedOps, + tape::TapeRef, + FloatTensor, Shape, +}; +use num_traits::Float; + +#[derive(Debug)] +pub struct ADTensor { + pub node: NodeRef, + pub shape: Shape, + pub kind: ADKind
<P>
, + pub tape: TapeRef, +} + +impl ADTensor +where + P: Float + Zeros
<P>
+ Default + 'static, + T: FloatTensor + Clone + Zeros + Ones + 'static, +{ + pub fn new(node: NodeRef, tape: TapeRef) -> Self { + let tensor = node.borrow().value(); + let shape = tensor.shape().clone(); + let kind = ADKind::new(); + + let ops = InitRecordedOps::new(node.clone()); + let ops = Box::new(ops); + tape.borrow_mut().add(ops); + + Self { + node, + shape, + kind, + tape, + } + } + + pub fn from_existing(&self, node: NodeRef) -> Self { + let tape = self.tape.clone(); + let shape = self.shape.clone(); + let kind = self.kind.clone(); + + Self { + node, + shape, + kind, + tape, + } + } +} + +impl ADTensor { + pub fn tensor(&self) -> T { + self.node.borrow().value() + } +} + +impl ADTensor { + pub fn backprob(&self) { + let id = self.node.borrow().id(); + self.tape.borrow_mut().backward(id); + } +} + +impl ADTensor { + pub fn grad(&self) -> T { + self.node.borrow_mut().grad() + } +} + +#[derive(Clone, Debug)] +pub struct ADKind
<P>
{ + _p: P, +} + +impl ADKind
<P>
{ + pub fn new() -> Self { + Self { _p: P::default() } + } +} diff --git a/burn-tensor/src/tensor/backend/conversion.rs b/burn-tensor/src/tensor/backend/conversion.rs new file mode 100644 index 0000000000..523c0b480f --- /dev/null +++ b/burn-tensor/src/tensor/backend/conversion.rs @@ -0,0 +1,266 @@ +use std::collections::HashMap; + +use crate::Shape; + +#[derive(Debug)] +pub struct Convertor { + indices_from: HashMap>, + indices_to: HashMap, usize>, +} + +impl Convertor { + pub fn new(shape: &Shape, order_from: Order, order_to: Order) -> Self { + Self { + indices_from: revert_indices_map(build_indices(shape, order_from)), + indices_to: build_indices(shape, order_to), + } + } + + pub fn convert(&self, data: &Vec
<P>
) -> Vec
<P>
{ + let mut data_converted = data.clone(); + + for (i, val) in data.iter().enumerate() { + let indices = self.indices_from.get(&i).unwrap(); + let i_converted = self.indices_to.get(indices).unwrap(); + data_converted[*i_converted] = *val; + } + data_converted + } +} + +pub enum Order { + Left, + Right, +} +pub fn build_indices(shape: &Shape, order: Order) -> HashMap, usize> { + let num_elements = shape.num_elements(); + let mut indices = init_indices::(); + let mut num_repeat_next = 1; + + for i in 0..D { + let index = match order { + Order::Left => D - i - 1, + Order::Right => i, + }; + let size = shape.dims[index]; + let num_repeat = num_repeat_next; + let times = num_elements / (num_repeat * size); + num_repeat_next *= size; + + let dim = IndicesDimGenerator::new(size, num_repeat, times); + indices[index] = dim.generate(); + } + + build_map_from(shape, indices) +} + +fn init_indices() -> Vec> { + let mut indices: Vec> = Vec::with_capacity(D); + for _ in 0..D { + indices.push(Vec::new()); + } + indices +} + +fn build_map_from( + shape: &Shape, + indices: Vec>, +) -> HashMap, usize> { + let num_elements = shape.num_elements(); + let mut map = HashMap::new(); + + for e in 0..num_elements { + let mut index = Vec::with_capacity(D); + for d in 0..D { + let arr = &indices[d]; + let num = arr[e]; + index.push(num); + } + map.insert(index, e); + } + map +} + +fn revert_indices_map(map: HashMap, usize>) -> HashMap> { + let mut map_revert = HashMap::with_capacity(map.len()); + for (key, value) in map.into_iter() { + map_revert.insert(value, key); + } + map_revert +} + +#[derive(new, Debug)] +struct IndicesDimGenerator { + size: usize, + num_repeat: usize, + times: usize, +} + +impl IndicesDimGenerator { + fn generate(self) -> Vec { + let mut vec = Vec::new(); + + for _ in 0..self.times { + for i in 0..self.size { + for _ in 0..self.num_repeat { + vec.push(i); + } + } + } + vec + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_convert_data_2d_from_left_to_right() { + let shape = Shape::new([2, 3]); + let data_ij = data_ij(&shape); + let data_ji = data_ji(&shape); + let convertor = Convertor::new(&shape, Order::Left, Order::Right); + + let data_ij_converted = convertor.convert(&data_ij); + + assert_eq!(data_ji, data_ij_converted); + } + + #[test] + fn should_convert_data_2d_from_right_to_left() { + let shape = Shape::new([2, 3]); + let data_ij = data_ij(&shape); + let data_ji = data_ji(&shape); + let convertor = Convertor::new(&shape, Order::Right, Order::Left); + + let data_ji_converted = convertor.convert(&data_ji); + + assert_eq!(data_ij, data_ji_converted); + } + + fn data_ij(shape: &Shape<2>) -> Vec { + let mut data = Vec::new(); + + for i in 0..shape.dims[0] { + for j in 0..shape.dims[1] { + data.push(i + j); + } + } + data + } + + fn data_ji(shape: &Shape<2>) -> Vec { + let mut data = Vec::new(); + + for i in 0..shape.dims[1] { + for j in 0..shape.dims[0] { + data.push(i + j); + } + } + data + } + + #[test] + fn should_build_indices_1d_simple() { + let shape = Shape::new([2]); + + let indices = build_indices(&shape, Order::Left); + + let expected = HashMap::from([(vec![0], 0), (vec![1], 1)]); + + assert_eq!(expected, indices); + } + + #[test] + fn should_build_indices_2d_simple() { + let shape = Shape::new([2, 2]); + + let indices = build_indices(&shape, Order::Left); + + let expected = HashMap::from([ + (vec![0, 0], 0), + (vec![0, 1], 1), + (vec![1, 0], 2), + (vec![1, 1], 3), + ]); + + assert_eq!(expected, indices); + } + + #[test] + fn 
should_build_indices_2d_complexe() { + let shape = Shape::new([2, 3]); + + let indices = build_indices(&shape, Order::Left); + + let expected = HashMap::from([ + (vec![0, 0], 0), + (vec![0, 1], 1), + (vec![0, 2], 2), + (vec![1, 0], 3), + (vec![1, 1], 4), + (vec![1, 2], 5), + ]); + + assert_eq!(expected, indices); + } + + #[test] + fn should_build_indices_3d_complexe() { + let shape = Shape::new([2, 5, 3]); + + let indices = build_indices(&shape, Order::Left); + + let expected = HashMap::from([ + (vec![0, 0, 0], 0), + (vec![0, 0, 1], 1), + (vec![0, 0, 2], 2), + (vec![0, 1, 0], 3), + (vec![0, 1, 1], 4), + (vec![0, 1, 2], 5), + (vec![0, 2, 0], 6), + (vec![0, 2, 1], 7), + (vec![0, 2, 2], 8), + (vec![0, 3, 0], 9), + (vec![0, 3, 1], 10), + (vec![0, 3, 2], 11), + (vec![0, 4, 0], 12), + (vec![0, 4, 1], 13), + (vec![0, 4, 2], 14), + (vec![1, 0, 0], 15), + (vec![1, 0, 1], 16), + (vec![1, 0, 2], 17), + (vec![1, 1, 0], 18), + (vec![1, 1, 1], 19), + (vec![1, 1, 2], 20), + (vec![1, 2, 0], 21), + (vec![1, 2, 1], 22), + (vec![1, 2, 2], 23), + (vec![1, 3, 0], 24), + (vec![1, 3, 1], 25), + (vec![1, 3, 2], 26), + (vec![1, 4, 0], 27), + (vec![1, 4, 1], 28), + (vec![1, 4, 2], 29), + ]); + + assert_eq!(expected, indices); + } + + #[test] + fn should_build_indices_4d_weird() { + let shape = Shape::new([2, 1, 2, 1]); + + let indices = build_indices(&shape, Order::Left); + + let expected = HashMap::from([ + (vec![0, 0, 0, 0], 0), + (vec![0, 0, 1, 0], 1), + (vec![1, 0, 0, 0], 2), + (vec![1, 0, 1, 0], 3), + ]); + + assert_eq!(expected, indices); + } +} diff --git a/burn-tensor/src/tensor/backend/mod.rs b/burn-tensor/src/tensor/backend/mod.rs new file mode 100644 index 0000000000..fe67a30b81 --- /dev/null +++ b/burn-tensor/src/tensor/backend/mod.rs @@ -0,0 +1,6 @@ +#[cfg(feature = "arrayfire")] +pub mod arrayfire; +pub mod autodiff; +pub mod conversion; +#[cfg(feature = "tch")] +pub mod tch; diff --git a/burn-tensor/src/tensor/backend/tch/mod.rs b/burn-tensor/src/tensor/backend/tch/mod.rs new file mode 100644 index 0000000000..995ae6e610 --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/mod.rs @@ -0,0 +1,4 @@ +mod ops; +mod tensor; + +pub use tensor::*; diff --git a/burn-tensor/src/tensor/backend/tch/ops/add.rs b/burn-tensor/src/tensor/backend/tch/ops/add.rs new file mode 100644 index 0000000000..7443d8ac5a --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/add.rs @@ -0,0 +1,71 @@ +use crate::{backend::tch::TchTensor, Data, TensorOpsAdd}; +use std::ops::Add; + +impl TensorOpsAdd + for TchTensor +{ + fn add(&self, other: &Self) -> Self { + let tensor = (&self.tensor).add(&other.tensor); + let kind = self.kind.clone(); + let shape = self.shape.clone(); + + Self { + tensor, + shape, + kind, + } + } + fn add_scalar(&self, other: &P) -> Self { + let elems: [P; D] = [*other; D]; + let data = Data::from(elems); + let other = TchTensor::from_data(data, self.tensor.device()); + let tensor = (&self.tensor).add(&other.tensor); + let kind = self.kind.clone(); + let shape = self.shape.clone(); + + Self { + tensor, + shape, + kind, + } + } +} + +impl std::ops::Add + for TchTensor +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + TensorOpsAdd::add(&self, &rhs) + } +} + +impl std::ops::Add
<P>
+ for TchTensor +{ + type Output = Self; + + fn add(self, rhs: P) -> Self::Output { + TensorOpsAdd::add_scalar(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TensorBase; + + #[test] + fn should_support_add_ops() { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::::from([[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]); + let data_expected = Data::from([[6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]); + let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu); + + let data_actual = (tensor_1 + tensor_2).into_data(); + + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/index.rs b/burn-tensor/src/tensor/backend/tch/ops/index.rs new file mode 100644 index 0000000000..e89bd443d6 --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/index.rs @@ -0,0 +1,78 @@ +use crate::{backend::tch::TchTensor, TensorOpsIndex}; +use std::ops::Range; + +impl< + P: tch::kind::Element + std::fmt::Debug + Copy + Default, + const D1: usize, + const D2: usize, + > TensorOpsIndex for TchTensor +{ + fn index(&self, indices: [Range; D2]) -> Self { + let mut tensor = self.tensor.shallow_clone(); + + for i in 0..D2 { + let index = indices[i].clone(); + let start = index.start as i64; + let length = (index.end - index.start) as i64; + tensor = tensor.narrow(i as i64, start, length) + } + let shape = self.shape.index(indices); + let kind = self.kind.clone(); + + Self { + kind, + tensor, + shape, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Data, TensorBase}; + + #[test] + fn should_support_full_indexing_1d() { + let data = Data::::from([0.0, 1.0, 2.0]); + let tensor = TchTensor::from_data(data.clone(), tch::Device::Cpu); + + let data_actual = tensor.index([0..3]).into_data(); + + assert_eq!(data, data_actual); + } + + #[test] + fn should_support_partial_indexing_1d() { + let data = Data::::from([0.0, 1.0, 2.0]); + let tensor = TchTensor::from_data(data.clone(), tch::Device::Cpu); + + let data_actual = tensor.index([1..3]).into_data(); + + let data_expected = Data::from([1.0, 2.0]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_full_indexing_2d() { + let data = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TchTensor::from_data(data.clone(), tch::Device::Cpu); + + let data_actual_1 = tensor.index([0..2]).into_data(); + let data_actual_2 = tensor.index([0..2, 0..3]).into_data(); + + assert_eq!(data, data_actual_1); + assert_eq!(data, data_actual_2); + } + + #[test] + fn should_support_partial_indexing_2d() { + let data = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TchTensor::from_data(data.clone(), tch::Device::Cpu); + + let data_actual = tensor.index([0..2, 0..2]).into_data(); + + let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/matmul.rs b/burn-tensor/src/tensor/backend/tch/ops/matmul.rs new file mode 100644 index 0000000000..887165157a --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/matmul.rs @@ -0,0 +1,53 @@ +use crate::{backend::tch::TchTensor, Shape, TensorOpsMatmul}; + +impl TensorOpsMatmul for TchTensor { + fn matmul(&self, other: &Self) -> Self { + let tensor = self.tensor.matmul(&other.tensor); + let kind = self.kind.clone(); + let shape = Shape::from(tensor.size()); + + Self { + kind, + tensor, + shape, + } + } +} + +#[cfg(test)] 
+mod tests { + use super::*; + use crate::{Data, TensorBase}; + + #[test] + fn should_support_matmul_2_dims() { + let data_1 = Data::::from([[4.0, 3.0], [8.0, 7.0]]); + let data_2 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu); + + let data_actual = tensor_1.matmul(&tensor_2).into_data(); + + let data_expected = Data::from([[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_matmul_3_dims() { + let data_1 = Data::::from([[[4.0, 3.0], [8.0, 7.0]], [[4.0, 3.0], [8.0, 7.0]]]); + let data_2 = Data::::from([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + ]); + let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu); + + let data_actual = tensor_1.matmul(&tensor_2).into_data(); + + let data_expected = Data::from([ + [[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]], + [[9.0, 16.0, 23.0], [21.0, 36.0, 51.0]], + ]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs new file mode 100644 index 0000000000..9d4e98467b --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -0,0 +1,6 @@ +mod add; +mod index; +mod matmul; +mod mul; +mod neg; +mod reshape; diff --git a/burn-tensor/src/tensor/backend/tch/ops/mul.rs b/burn-tensor/src/tensor/backend/tch/ops/mul.rs new file mode 100644 index 0000000000..3b7acef166 --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/mul.rs @@ -0,0 +1,79 @@ +use crate::{backend::tch::TchTensor, TensorOpsMul}; +use std::ops::Mul; + +impl, const D: usize> TensorOpsMul for TchTensor { + fn mul(&self, other: &Self) -> Self { + let tensor = (&self.tensor) * &other.tensor; + let shape = self.shape.clone(); + let kind = self.kind.clone(); + + Self { + tensor, + shape, + kind, + } + } + fn mul_scalar(&self, other: &P) -> Self { + let other: f64 = (other.clone()).into(); + let tensor = (&self.tensor).mul(other); + let shape = self.shape.clone(); + let kind = self.kind.clone(); + + Self { + tensor, + shape, + kind, + } + } +} + +impl, const D: usize> std::ops::Mul
<P>
for TchTensor { + type Output = TchTensor; + + fn mul(self, rhs: P) -> Self::Output { + TensorOpsMul::mul_scalar(&self, &rhs) + } +} + +impl, const D: usize> std::ops::Mul> + for TchTensor +{ + type Output = TchTensor; + + fn mul(self, rhs: Self) -> Self::Output { + TensorOpsMul::mul(&self, &rhs) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Data, TensorBase}; + + #[test] + fn should_support_mul_ops() { + let data_1 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let data_2 = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor_1 = TchTensor::from_data(data_1, tch::Device::Cpu); + let tensor_2 = TchTensor::from_data(data_2, tch::Device::Cpu); + + let output = tensor_1 * tensor_2; + let data_actual = output.into_data(); + + let data_expected = Data::from([[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]); + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_mul_scalar_ops() { + let data = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let scalar = 2.0; + let tensor = TchTensor::from_data(data, tch::Device::Cpu); + + let output = tensor * scalar; + let data_actual = output.into_data(); + + let data_expected = Data::from([[0.0, 2.0, 4.0], [6.0, 8.0, 10.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/neg.rs b/burn-tensor/src/tensor/backend/tch/ops/neg.rs new file mode 100644 index 0000000000..2d40d9983d --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/neg.rs @@ -0,0 +1,44 @@ +use crate::{backend::tch::TchTensor, TensorOpsNeg}; + +impl TensorOpsNeg + for TchTensor +{ + fn neg(&self) -> Self { + let tensor = -(&self.tensor); + let kind = self.kind.clone(); + let shape = self.shape.clone(); + + Self { + tensor, + shape, + kind, + } + } +} + +impl std::ops::Neg + for TchTensor +{ + type Output = Self; + + fn neg(self) -> Self::Output { + TensorOpsNeg::neg(&self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Data, TensorBase}; + + #[test] + fn should_support_neg_ops() { + let data = Data::::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TchTensor::from_data(data, tch::Device::Cpu); + + let data_actual = tensor.neg().into_data(); + + let data_expected = Data::from([[-0.0, -1.0, -2.0], [-3.0, -4.0, -5.0]]); + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/backend/tch/ops/reshape.rs b/burn-tensor/src/tensor/backend/tch/ops/reshape.rs new file mode 100644 index 0000000000..c0ba71fa80 --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/ops/reshape.rs @@ -0,0 +1,23 @@ +use crate::{ + backend::tch::{TchShape, TchTensor}, + Shape, TensorOpsReshape, +}; + +impl< + P: tch::kind::Element + std::fmt::Debug + Copy + Default, + const D1: usize, + const D2: usize, + > TensorOpsReshape> for TchTensor +{ + fn reshape(&self, shape: Shape) -> TchTensor { + let shape_tch: TchShape = shape.clone().into(); + let tensor = self.tensor.reshape(&shape_tch.dims); + let kind = self.kind.clone(); + + TchTensor { + tensor, + kind, + shape, + } + } +} diff --git a/burn-tensor/src/tensor/backend/tch/tensor.rs b/burn-tensor/src/tensor/backend/tch/tensor.rs new file mode 100644 index 0000000000..c8bb00d2cc --- /dev/null +++ b/burn-tensor/src/tensor/backend/tch/tensor.rs @@ -0,0 +1,136 @@ +use crate::{Data, FloatTensor, Shape, TensorBase}; +use num_traits::Float; + +#[derive(Debug, PartialEq)] +pub struct TchTensor { + pub kind: TchKind
<P>
, + pub tensor: tch::Tensor, + pub shape: Shape, +} + +impl Clone for TchTensor { + fn clone(&self) -> Self { + Self { + kind: self.kind.clone(), + tensor: self.tensor.shallow_clone(), + shape: self.shape.clone(), + } + } +} + +pub struct TchShape { + pub dims: [i64; D], +} + +impl< + P: Float + tch::kind::Element + Default + Copy + std::fmt::Debug + Into, + const D: usize, + > FloatTensor for TchTensor +{ +} + +impl From> for TchShape { + fn from(shape: Shape) -> Self { + let mut dims = [0; D]; + for i in 0..D { + dims[i] = shape.dims[i] as i64; + } + TchShape { dims } + } +} + +impl From> for Shape { + fn from(shape: Vec) -> Self { + let mut dims = [0; D]; + for i in 0..D { + dims[i] = *shape.get(i).unwrap() as usize; + } + Self::new(dims) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct TchKind { + _p: P, +} + +impl TchKind
<P>
{ + pub fn new() -> Self { + Self { _p: P::default() } + } + pub fn kind(&self) -> tch::Kind { + P::KIND + } +} + +impl TchTensor { + pub fn from_data(data: Data, device: tch::Device) -> Self { + let tensor = tch::Tensor::of_slice(data.value.as_slice()).to(device); + let shape = data.shape.clone(); + let shape_tch = TchShape::from(data.shape); + let kind = TchKind::new(); + let tensor = tensor.reshape(&shape_tch.dims).to_kind(kind.kind()); + let tensor = tensor.set_requires_grad(false); + + Self { + kind, + tensor, + shape, + } + } +} + +impl TensorBase + for TchTensor +{ + fn empty(shape: Shape) -> Self { + let shape_tch = TchShape::from(shape.clone()); + let device = tch::Device::Cpu; + let kind = TchKind::new(); + let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.clone())); + let tensor = tensor.set_requires_grad(false); + + Self { + kind, + tensor, + shape, + } + } + + fn from>(other: O) -> Self { + Self::from_data(other.into_data(), tch::Device::Cpu) + } + + fn shape(&self) -> &Shape { + &self.shape + } + fn into_data(self) -> Data { + let values = self.tensor.into(); + Data::new(values, self.shape) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_support_into_and_from_data_1d() { + let data_expected = Data::::random(Shape::new([3])); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } + + #[test] + fn should_support_into_and_from_data_2d() { + let data_expected = Data::::random(Shape::new([2, 3])); + let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu); + + let data_actual = tensor.into_data(); + + assert_eq!(data_expected, data_actual); + } +} diff --git a/burn-tensor/src/tensor/data.rs b/burn-tensor/src/tensor/data.rs new file mode 100644 index 0000000000..e62048fdea --- /dev/null +++ b/burn-tensor/src/tensor/data.rs @@ -0,0 +1,89 @@ +use crate::Shape; +use rand::{distributions::Standard, prelude::Distribution}; + +#[derive(new, Debug, Clone, PartialEq)] +pub struct Data { + pub value: Vec
<P>
, + pub shape: Shape, +} + +impl Data +where + Standard: Distribution
<P>
, +{ + pub fn random(shape: Shape) -> Data { + let num_elements = shape.num_elements(); + let mut data = Vec::with_capacity(num_elements); + + for _ in 0..num_elements { + data.push(rand::random()); + } + + Data::new(data, shape) + } +} + +impl From<[P; A]> for Data { + fn from(elems: [P; A]) -> Self { + let mut data = Vec::with_capacity(2 * A); + for i in 0..A { + data.push(elems[i]); + } + + Data::new(data, Shape::new([A])) + } +} + +impl From<[[P; B]; A]> for Data { + fn from(elems: [[P; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B); + for i in 0..A { + for j in 0..B { + data.push(elems[i][j]); + } + } + + Data::new(data, Shape::new([A, B])) + } +} + +impl + From<[[[P; C]; B]; A]> for Data +{ + fn from(elems: [[[P; C]; B]; A]) -> Self { + let mut data = Vec::with_capacity(A * B * C); + for i in 0..A { + for j in 0..B { + for k in 0..C { + data.push(elems[i][j][k]); + } + } + } + + Data::new(data, Shape::new([A, B, C])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn should_have_right_num_elements() { + let shape = Shape::new([3, 5, 6]); + let data = Data::::random(shape.clone()); + assert_eq!(shape.num_elements(), data.value.len()); + } + + #[test] + fn should_have_right_shape() { + let data = Data::from([[3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([1, 3])); + + let data = Data::from([[4.0, 5.0, 8.0], [3.0, 5.0, 6.0]]); + assert_eq!(data.shape, Shape::new([2, 3])); + + let data = Data::from([3.0, 5.0, 6.0]); + assert_eq!(data.shape, Shape::new([3])); + } +} diff --git a/burn-tensor/src/tensor/mod.rs b/burn-tensor/src/tensor/mod.rs new file mode 100644 index 0000000000..67d414250a --- /dev/null +++ b/burn-tensor/src/tensor/mod.rs @@ -0,0 +1,9 @@ +pub mod backend; + +mod data; +mod shape; +mod tensor; + +pub use data::*; +pub use shape::*; +pub use tensor::*; diff --git a/burn-tensor/src/tensor/shape.rs b/burn-tensor/src/tensor/shape.rs new file mode 100644 index 0000000000..8edc814bad --- /dev/null +++ b/burn-tensor/src/tensor/shape.rs @@ -0,0 +1,37 @@ +use std::ops::Range; + +#[derive(new, Debug, Clone, PartialEq)] +pub struct Shape { + pub dims: [usize; D], +} + +impl Shape { + pub fn num_elements(&self) -> usize { + let mut num_elements = 1; + for i in 0..D { + num_elements *= self.dims[i]; + } + + num_elements + } +} + +impl Shape { + pub fn index(&self, indexes: [Range; D2]) -> Self { + if D2 > D1 { + panic!("Cant index that"); + } + + let mut dims = [0; D1]; + + for i in 0..D2 { + dims[i] = indexes[i].clone().count(); + } + + for i in D2..D1 { + dims[i] = self.dims[i]; + } + + Self::new(dims) + } +} diff --git a/burn-tensor/src/tensor/tensor.rs b/burn-tensor/src/tensor/tensor.rs new file mode 100644 index 0000000000..fdda66e4cf --- /dev/null +++ b/burn-tensor/src/tensor/tensor.rs @@ -0,0 +1,58 @@ +use std::ops::Range; + +use crate::{Data, Shape}; + +pub enum TensorError { + ReshapeError(String), +} + +pub trait FloatTensor: + TensorBase + + TensorOpsMul + + TensorOpsNeg + + TensorOpsAdd + + TensorOpsMatmul + + std::fmt::Debug +{ +} + +pub trait TensorBase { + fn shape(&self) -> &Shape; + fn into_data(self) -> Data; + fn from>(other: O) -> Self; + fn empty(shape: Shape) -> Self; +} + +pub trait TensorOpsAdd: + std::ops::Add + std::ops::Add +where + Self: Sized, +{ + fn add(&self, other: &Self) -> Self; + fn add_scalar(&self, other: &P) -> Self; +} + +pub trait TensorOpsMatmul { + fn matmul(&self, other: &Self) -> Self; +} + +pub trait TensorOpsNeg: std::ops::Neg { + fn neg(&self) -> Self; +} + +pub trait TensorOpsMul: + std::ops::Mul + 
std::ops::Mul +where + Self: Sized, +{ + fn mul(&self, other: &Self) -> Self; + fn mul_scalar(&self, other: &P) -> Self; +} + +pub trait TensorOpsReshape> { + fn reshape(&self, shape: Shape) -> T; +} + +pub trait TensorOpsIndex { + fn index(&self, indexes: [Range; D2]) -> Self; +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000000..e69de29bb2
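Note on the indexing tests earlier in this patch: `build_indices(&shape, Order::Left)` is defined elsewhere in the crate, but the expected maps all follow a plain row-major stride rule (the left-most dimension has the largest stride). The sketch below is an illustrative, self-contained reimplementation of that rule — `Order` and `build_indices` here are stand-ins that only mirror the behaviour the tests pin down, not the crate's actual code, and it takes a plain `&[usize]` instead of `Shape`.

use std::collections::HashMap;

// Hypothetical stand-in for the crate's `Order` enum; only the variant the tests use.
enum Order {
    Left,
}

// Minimal sketch: map every multi-index of `dims` to its flat offset, assuming
// `Order::Left` means the left-most dimension has the largest stride (row-major).
fn build_indices(dims: &[usize], _order: Order) -> HashMap<Vec<usize>, usize> {
    // strides[i] = product of all dimensions to the right of i
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }

    let num_elements: usize = dims.iter().product();
    let mut indices = HashMap::with_capacity(num_elements);

    for flat in 0..num_elements {
        // Decompose the flat offset back into a multi-index using the strides.
        let mut remainder = flat;
        let mut index = Vec::with_capacity(dims.len());
        for &stride in &strides {
            index.push(remainder / stride);
            remainder %= stride;
        }
        indices.insert(index, flat);
    }

    indices
}

fn main() {
    // Matches the `[2, 5, 3]` expectations above, e.g. [1, 2, 1] -> 1*15 + 2*3 + 1 = 22.
    let indices = build_indices(&[2, 5, 3], Order::Left);
    assert_eq!(indices[&vec![1, 2, 1]], 22);
    assert_eq!(indices[&vec![0, 4, 2]], 14);
    println!("left-order indexing checks out");
}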
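The flattening of this patch dropped most `<...>` generic parameter lists, which is why several `struct`, `trait` and `impl` headers above look truncated. The compile-only skeleton below is a best-guess reconstruction of the core types and traits from `shape.rs`, `data.rs` and `tensor.rs`, inferred from the surviving method signatures and the `TchTensor` impls; the exact parameter order and trait bounds in the original commit may differ. `FloatTensor` is omitted because it additionally requires `num_traits::Float`.

use std::ops::Range;

// Inferred: `Shape` is generic over a const dimension count.
#[derive(Debug, Clone, PartialEq)]
pub struct Shape<const D: usize> {
    pub dims: [usize; D],
}

// Inferred: `Data` carries a flat buffer plus the shape it was created with.
#[derive(Debug, Clone, PartialEq)]
pub struct Data<P, const D: usize> {
    pub value: Vec<P>,
    pub shape: Shape<D>,
}

pub trait TensorBase<P, const D: usize> {
    fn shape(&self) -> &Shape<D>;
    fn into_data(self) -> Data<P, D>;
    fn from<O: TensorBase<P, D>>(other: O) -> Self;
    fn empty(shape: Shape<D>) -> Self;
}

// Inferred: element-wise add plus a scalar variant, mirrored by `std::ops` bounds.
pub trait TensorOpsAdd<P, const D: usize>:
    std::ops::Add<Self, Output = Self> + std::ops::Add<P, Output = Self>
where
    Self: Sized,
{
    fn add(&self, other: &Self) -> Self;
    fn add_scalar(&self, other: &P) -> Self;
}

pub trait TensorOpsMul<P, const D: usize>:
    std::ops::Mul<P, Output = Self> + std::ops::Mul<Self, Output = Self>
where
    Self: Sized,
{
    fn mul(&self, other: &Self) -> Self;
    fn mul_scalar(&self, other: &P) -> Self;
}

pub trait TensorOpsMatmul<P, const D: usize> {
    fn matmul(&self, other: &Self) -> Self;
}

pub trait TensorOpsNeg<P, const D: usize>: std::ops::Neg<Output = Self> {
    fn neg(&self) -> Self;
}

// Inferred: reshape maps a D1-dimensional tensor into a D2-dimensional one.
pub trait TensorOpsReshape<P, const D1: usize, const D2: usize, T: TensorBase<P, D2>> {
    fn reshape(&self, shape: Shape<D2>) -> T;
}

// Inferred: indexing narrows the first D2 dimensions of a D1-dimensional tensor.
pub trait TensorOpsIndex<P, const D1: usize, const D2: usize> {
    fn index(&self, indexes: [Range<usize>; D2]) -> Self;
}

// Compile check only; the skeleton has no runtime behaviour of its own.
fn main() {}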
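For reference, the `TchTensor` tests above reduce to a small round trip through the `tch` crate itself. The sketch below performs the same flow with the plain `tch` calls this diff already relies on (`of_slice`, `reshape`, `matmul`, the `Vec` conversion used by `into_data`); it assumes the pinned `tch` 0.8 API and a local libtorch as configured in `env.bash`, and is an illustration rather than part of the crate.

// Rough usage sketch of the underlying tch calls that `TchTensor` wraps.
fn main() {
    // What `TchTensor::from_data` does: flat slice -> tensor -> reshape to the Shape dims.
    let a = tch::Tensor::of_slice(&[4.0f64, 3.0, 8.0, 7.0])
        .reshape(&[2, 2])
        .to(tch::Device::Cpu);
    let b = tch::Tensor::of_slice(&[0.0f64, 1.0, 2.0, 3.0, 4.0, 5.0]).reshape(&[2, 3]);

    // What `TensorOpsMatmul::matmul` delegates to.
    let c = a.matmul(&b);
    assert_eq!(c.size(), vec![2, 3]);

    // What `into_data` relies on: converting the tensor back into a flat Vec.
    // Flattening first keeps the conversion unambiguous.
    let values: Vec<f64> = c.reshape(&[-1]).into();
    assert_eq!(values, vec![9.0, 16.0, 23.0, 21.0, 36.0, 51.0]);
    println!("{:?}", values);
}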