forked from google/trax

Trax — Deep Learning with Clear Code and Speed

trax: Train Neural Nets with JAX


trax: T2T Radically Simpler with JAX

Why? Because T2T has gotten too complex. We are simplifying the main code too, but we wanted to try a more radical step: code that you can write as in pure NumPy and debug directly, so you can easily pinpoint each line where things happen and understand each function. We also want it to run fast on accelerators, and that's possible with JAX.

Status: preview. Things work: models train, checkpoints are saved, TensorBoard has summaries, and you can decode. But we are changing a lot every day for now. Please let us know what we should add, delete, keep, or change. We plan to move the best parts into core JAX.

Entrypoints:

  • Script: trainer.py
  • Main library entrypoint: trax.train

Examples

Example Colab

See our example of constructing language models from scratch in a GPU-backed Colab notebook at Trax Demo.

MLP on MNIST

python -m trax.trainer \
  --dataset=mnist \
  --model=MLP \
  --config="train.train_steps=1000"

ResNet50 on ImageNet

python -m trax.trainer \
  --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin

TransformerDecoder on LM1B

python -m trax.trainer \
  --config_file=$PWD/trax/configs/transformer_lm1b_8gb.gin

How trax differs from T2T

  • Configuration is done with gin. trainer.py takes --config_file as well as --config for file overrides.
  • Models are defined with stax in models/. They are made gin-configurable in models/__init__.py.
  • Datasets are simple iterators over batches. Datasets from tensorflow/datasets and tensor2tensor are built-in and can be addressed by name.
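To make the "simple iterators over batches" idea concrete, here is a minimal sketch in plain Python. It is not trax code: `batch_stream`, its parameters, and the toy data are hypothetical names chosen for illustration, and the real input pipeline may shuffle, pad, or batch differently.

```python
import random


def batch_stream(examples, batch_size, shuffle=True, seed=0):
    """Yield (inputs, targets) batches forever from a list of (x, y) pairs.

    Hypothetical helper for illustration only; not part of trax.
    """
    if batch_size <= 0 or batch_size > len(examples):
        raise ValueError("batch_size must be between 1 and len(examples)")
    rng = random.Random(seed)
    order = list(range(len(examples)))
    while True:  # each pass over `order` is one epoch
        if shuffle:
            rng.shuffle(order)
        # Drop the final partial batch so every batch has the same size.
        for start in range(0, len(order) - batch_size + 1, batch_size):
            chunk = [examples[i] for i in order[start:start + batch_size]]
            inputs = [x for x, _ in chunk]
            targets = [y for _, y in chunk]
            yield inputs, targets


# Toy data (an assumption for the example): input [i, i + 1], target i % 2.
toy = [([i, i + 1], i % 2) for i in range(10)]
stream = batch_stream(toy, batch_size=4, shuffle=False)
first_inputs, first_targets = next(stream)
```

A training loop that consumes such a stream only needs `next(stream)` once per step, so swapping one dataset for another amounts to swapping the iterator.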
