Skip to content

Commit

Permalink
First Commit :D
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jul 18, 2022
1 parent ebea760 commit 902f431
Show file tree
Hide file tree
Showing 49 changed files with 2,869 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk


# Added by cargo

/target
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "burn"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
15 changes: 15 additions & 0 deletions burn-tensor/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Generated by Cargo
# will have compiled files and executables
/target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk


# Added by cargo

/target
26 changes: 26 additions & 0 deletions burn-tensor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
name = "burn-tensor"
version = "0.1.0"
edition = "2021"

[features]
default = ["tch"]
full = ["arrayfire", "tch"]
arrayfire = ["dep:arrayfire"]
tch = ["dep:tch"]

[dependencies]
derive-new = "0.5"
rand = "0.8"
num-traits = "0.2"
once_cell = "1.1"

# Backends
arrayfire = { version = "3.8", optional = true }
tch = { version = "0.8", optional = true }

# Autodiff
nanoid = "0.4"

[dev-dependencies]
float-cmp = "0.9.0"
7 changes: 7 additions & 0 deletions burn-tensor/env.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
echo "Setup arrayfire backend"
export AF_PATH=$HOME/.local/share/arrayfire
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$AF_PATH/lib64

echo "Setup tch backend"
export LIBTORCH=${HOME}/.local/lib/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
5 changes: 5 additions & 0 deletions burn-tensor/src/graph/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod node;
pub mod ops;
pub mod tape;

mod node_old;
95 changes: 95 additions & 0 deletions burn-tensor/src/graph/node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::{
cell::RefCell,
ops::{Add, Mul},
rc::Rc,
};

#[derive(PartialEq, Eq, Debug, Clone, Hash)]
pub struct NodeId {
value: String,
}

impl NodeId {
pub fn new() -> Self {
Self {
value: nanoid::nanoid!(),
}
}
}

pub trait Node<Out>: std::fmt::Debug {
fn id(&self) -> NodeId;
fn grad(&mut self) -> Out;
fn value(&self) -> Out;
fn update_grad(&mut self, grad: Out);
}
pub type NodeRef<T> = Rc<RefCell<dyn Node<T>>>;

pub trait Zeros<T> {
fn zeros(&self) -> T;
}
pub trait Ones<T> {
fn ones(&self) -> T;
}

#[derive(Debug)]
pub struct RootNode<Out> {
pub id: NodeId,
pub value: Out,
pub grad: Option<Out>,
}

impl<Out> RootNode<Out> {
pub fn new(value: Out) -> Self {
Self {
id: NodeId::new(),
value,
grad: None,
}
}
}

impl<Out> Node<Out> for RootNode<Out>
where
Out: Zeros<Out> + Clone + Mul<Output = Out> + Add<Output = Out>,
Out: std::fmt::Debug,
{
fn id(&self) -> NodeId {
self.id.clone()
}
fn grad(&mut self) -> Out {
let grad_self: Out = match &self.grad {
Some(val) => val.clone(),
None => self.value.zeros(),
};
self.grad = Some(grad_self.clone());
grad_self
}

fn value(&self) -> Out {
self.value.clone()
}

fn update_grad(&mut self, grad: Out) {
self.grad = Some(self.grad() + grad);
}
}

#[macro_export]
macro_rules! node_init {
( lhs $lhs:expr, rhs $rhs:expr, out $out:expr, ) => {{
use $crate::graph::ops::BinaryOpsNode;
let node = BinaryOpsNode::new($lhs, $rhs, $out);
std::rc::Rc::new(std::cell::RefCell::new(node))
}};
( input $input:expr, out $out:expr, ) => {{
use $crate::graph::ops::SingleOpsNode;
let node = SingleOpsNode::new($input, $out);
std::rc::Rc::new(std::cell::RefCell::new(node))
}};
( root $out:expr ) => {{
use $crate::graph::node::RootNode;
let node = RootNode::new($out);
std::rc::Rc::new(std::cell::RefCell::new(node))
}};
}
Loading

0 comments on commit 902f431

Please sign in to comment.