Why? Because T2T has grown too complex. We are simplifying the main code too, but we wanted to try a more radical step: you write code as in pure NumPy and debug it directly, so you can pinpoint the exact line where something happens and understand each function. At the same time we want it to run fast on accelerators, and JAX makes that possible.
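To make the idea concrete, here is a minimal sketch that is independent of Trax itself: the math reads like plain NumPy, while `jax.jit` and `jax.grad` make it fast on accelerators and differentiable. The function and parameter names below are made up for illustration.

```python
# Illustration only (not Trax code): plain-NumPy-style math, accelerated by JAX.
import jax
import jax.numpy as np  # near drop-in replacement for numpy

def mse_loss(params, x, y):
  w, b = params
  preds = np.dot(x, w) + b           # reads and debugs like ordinary NumPy
  return np.mean((preds - y) ** 2)

# jit compiles the function with XLA so it runs fast on GPU/TPU;
# grad returns a function computing d(loss)/d(params).
grad_fn = jax.jit(jax.grad(mse_loss))
```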
Status: preview. Things work: models train, checkpoints are saved, TensorBoard has summaries, and you can decode. But we are still changing a lot every day. Please let us know what we should add, delete, keep, or change. We plan to move the best parts into core JAX.
Entrypoints:

- Script: `trainer.py`
- Main library entrypoint: `trax.train`

See our example constructing language models from scratch in a GPU-backed Colab notebook at Trax Demo.
Train an MLP on MNIST:

    python -m trax.trainer \
      --dataset=mnist \
      --model=MLP \
      --config="train.train_steps=1000"

Train ResNet-50 on ImageNet:

    python -m trax.trainer \
      --config_file=$PWD/trax/configs/resnet50_imagenet_8gb.gin

Train a Transformer language model on LM1B:

    python -m trax.trainer \
      --config_file=$PWD/trax/configs/transformer_lm1b_8gb.gin
- Configuration is done with `gin`. `trainer.py` takes `--config_file` as well as `--config` for overriding individual settings.
- Models are defined with `stax` in `models/`. They are made gin-configurable in `models/__init__.py` (see the sketch after this list).
- Datasets are simple iterators over batches (a toy example follows this list). Datasets from `tensorflow/datasets` and `tensor2tensor` are built-in and can be addressed by name.
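For illustration, here is a hedged sketch of how a model can be made gin-configurable and how a `--config` string corresponds to a gin binding. The decorator placement and the parameter names are assumptions, not the actual contents of `models/__init__.py`.

```python
# Hypothetical sketch of the gin wiring; the real registration may differ.
import gin

@gin.configurable
def MLP(num_hidden_layers=2, hidden_size=512, num_output_classes=10):
  ...  # build and return the stax model here

# Passing --config="MLP.hidden_size=128" on the command line has the same
# effect as parsing the binding directly (or putting it in a .gin file):
gin.parse_config("MLP.hidden_size = 128")
```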
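And a toy example of what "a simple iterator over batches" can look like. The `(inputs, targets)` structure and the MNIST-like shapes are assumptions chosen for illustration, not the exact interface Trax expects.

```python
# Toy batch iterator; the batch structure and shapes are illustrative assumptions.
import numpy as onp  # "original" NumPy, as opposed to jax.numpy

def random_mnist_like_batches(batch_size=32):
  while True:
    inputs = onp.random.rand(batch_size, 28, 28, 1).astype(onp.float32)
    targets = onp.random.randint(0, 10, size=(batch_size,))
    yield inputs, targets

batches = random_mnist_like_batches()
inputs, targets = next(batches)  # one batch, ready to feed into training
```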