A shot at implementing automatic differentiation from scratch using JAX in my free time.
- Create a virtual environment using conda:

  ```bash
  conda create --name jax-grad python=3.10
  conda activate jax-grad
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```
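The contents of `requirements.txt` aren't shown here; a plausible minimal version, assuming the project depends on `jax`, `numpy`, and `pytest` (these entries and pins are assumptions, not the actual file):

```text
# Assumed dependencies; the real requirements.txt may differ.
jax>=0.4
numpy>=1.24
pytest>=7.0
```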
Project structure:

```
jax-grad/
├── jax_grad/
│   ├── core/
│   │   ├── __init__.py
│   │   └── autodiff.py
│   ├── ops/
│   │   ├── __init__.py
│   │   └── math.py
│   ├── tensor/
│   │   ├── __init__.py
│   │   └── ops.py
│   ├── utils/
│   │   ├── __init__.py
│   │   └── grad_check.py
│   └── viz/
├── tests/
│   ├── unit/
│   ├── integration/
│   └── benchmarks/
├── examples/
│   └── basic_usage.py
├── docs/
├── setup.py
├── requirements.txt
└── README.md
```
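The `utils/grad_check.py` module points at numerical gradient verification. Below is a minimal sketch of central-difference checking; it assumes nothing about the module's real API (`numerical_grad` and `check_gradient` are hypothetical names):

```python
def numerical_grad(f, x, eps=1e-6):
    """Approximate df/dx at x with a central difference."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def check_gradient(f, df, x, tol=1e-4):
    """Compare an analytic derivative df against the numerical estimate."""
    return abs(numerical_grad(f, x) - df(x)) < tol

# Example: verify that d/dx(x**2) = 2x at x = 3.0
assert check_gradient(lambda x: x ** 2, lambda x: 2 * x, 3.0)
```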
Features:

- Forward-mode automatic differentiation (illustrated with a sketch below)
- Support for basic mathematical operations
- Test suite for verification
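Forward-mode AD is commonly implemented with dual numbers, which propagate a value and its derivative together through every operation. The sketch below illustrates the technique in miniature; it is not necessarily how `core/autodiff.py` implements it:

```python
class Dual:
    """A value paired with its derivative (a dual number)."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__

    def __pow__(self, n):
        # Power rule for a constant exponent: (u**n)' = n * u**(n-1) * u'
        return Dual(self.value ** n,
                    n * self.value ** (n - 1) * self.deriv)

def grad(f):
    """Return a function computing df/dx with one forward pass."""
    def df(x):
        # Seed the derivative with 1.0 and read it back after evaluating f.
        return f(Dual(x, 1.0)).deriv
    return df

print(grad(lambda x: x ** 2)(3.0))  # 6.0
```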
Basic example:

```python
from jax_grad import grad

def f(x):
    return x ** 2

df = grad(f)
result = df(3.0)  # Should return 6.0
```
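The example above only shows `grad` on a single squaring function. Assuming the basic operations listed earlier include addition and multiplication, polynomials should differentiate the same way; this usage is a guess at the API's reach, not verified behavior:

```python
from jax_grad import grad

def g(x):
    return 3 * x ** 2 + 2 * x + 1

dg = grad(g)
print(dg(1.0))  # Expected: 6 * 1 + 2 = 8.0
```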
For the detailed implementation roadmap and future plans, see ROADMAP.md.