JAX Automatic Differentiation Implementation

A free-time attempt at implementing automatic differentiation from scratch using JAX.

Setup

  1. Create and activate a virtual environment using conda:
conda create --name jax-grad python=3.10
conda activate jax-grad
  2. Install dependencies:
pip install -r requirements.txt

Project Structure

jax-grad/
├── jax_grad/                  # Main package
│   ├── core/                  # Core autodiff engine
│   │   ├── __init__.py
│   │   └── autodiff.py
│   ├── ops/                   # Differentiable math operations
│   │   ├── __init__.py
│   │   └── math.py
│   ├── tensor/                # Tensor operations
│   │   ├── __init__.py
│   │   └── ops.py
│   ├── utils/                 # Utilities, e.g. numerical gradient checking
│   │   ├── __init__.py
│   │   └── grad_check.py
│   └── viz/                   # Visualization
├── tests/                     # Unit, integration, and benchmark tests
│   ├── unit/
│   ├── integration/
│   └── benchmarks/
├── examples/
│   └── basic_usage.py
├── docs/
├── setup.py
├── requirements.txt
└── README.md

Features

  • Forward-mode automatic differentiation (see the sketch below)
  • Support for basic mathematical operations
  • Test suite for verification
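
Forward-mode AD carries a derivative (a "tangent") through the computation alongside each value. A minimal sketch using dual numbers, with an illustrative forward_grad helper (the real implementation lives in jax_grad/core/autodiff.py and differs in detail):

from dataclasses import dataclass

@dataclass
class Dual:
    primal: float   # function value
    tangent: float  # derivative carried alongside the value

    def _lift(self, other):
        return other if isinstance(other, Dual) else Dual(float(other), 0.0)

    def __add__(self, other):
        other = self._lift(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = self._lift(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)
    __rmul__ = __mul__

def forward_grad(f):
    # Seed the input with tangent 1.0; the output's tangent is df/dx.
    return lambda x: f(Dual(float(x), 1.0)).tangent

print(forward_grad(lambda x: x * x + 3.0 * x)(2.0))  # 2*2 + 3 = 7.0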

Usage

Basic example:

from jax_grad import grad

def f(x):
    return x ** 2

df = grad(f)
result = df(3.0)  # Should return 6.0
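
To verify a derivative numerically, compare it against a central finite difference. A minimal sketch of that check (the repo's own checker lives in jax_grad/utils/grad_check.py; its actual API may differ):

def numerical_grad(f, x, eps=1e-5):
    # Central difference: (f(x + eps) - f(x - eps)) / (2 * eps)
    return (f(x + eps) - f(x - eps)) / (2 * eps)

print(numerical_grad(lambda x: x ** 2, 3.0))  # ~6.0, matching df(3.0) above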

Roadmap

For the detailed implementation roadmap and future plans, see ROADMAP.md.
