An opinionated, general-purpose model trainer for PyTorch with a simple code base.
From GitHub:
git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install
From PyPI:
pip install trainer
Prefer installing from GitHub as it is more stable.
Subclass TrainerModel and override its functions.
See the test script that trains a basic MNIST model for an example.
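The subclass-and-override pattern can be sketched without the library. What follows is a minimal illustration, not the Trainer API: `SketchModel`, `LinearModel`, `fit`, and the hook names `train_step` and `get_criterion` are hypothetical stand-ins chosen for this example, and the real TrainerModel hooks may differ. Consult the MNIST test script for the actual interface.

```python
from abc import ABC, abstractmethod


class SketchModel(ABC):
    """Hypothetical stand-in base class; NOT trainer.TrainerModel."""

    @abstractmethod
    def train_step(self, batch, criterion):
        """Update the model on one batch and return the loss."""

    @abstractmethod
    def get_criterion(self):
        """Return a callable loss(prediction, target)."""


class LinearModel(SketchModel):
    """Fits y = w * x with plain gradient descent on a single weight."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr

    def get_criterion(self):
        return lambda pred, target: (pred - target) ** 2

    def train_step(self, batch, criterion):
        x, y = batch
        pred = self.w * x
        loss = criterion(pred, y)  # loss before this step's update
        # Gradient of (w*x - y)^2 with respect to w is 2*x*(w*x - y).
        self.w -= self.lr * 2 * x * (pred - y)
        return loss


def fit(model, data, epochs):
    """Hypothetical generic loop: it only calls the overridden hooks."""
    criterion = model.get_criterion()
    loss = None
    for _ in range(epochs):
        for batch in data:
            loss = model.train_step(batch, criterion)
    return loss


model = LinearModel()
final_loss = fit(model, data=[(1.0, 2.0), (2.0, 4.0)], epochs=50)
```

The point of the design is that the loop stays generic while all model-specific behavior lives in the overridden methods.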
$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"
We don't use .spawn() to initiate multi-GPU training since it imposes certain limitations.
- Everything must be picklable.
- .spawn() trains the model in subprocesses, and the model in the main process is not updated.
- DataLoader with N processes gets really slow when N is large.
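The picklability limitation can be checked with the standard library alone. Arguments handed to spawned subprocesses are serialized with pickle, so anything that fails `pickle.dumps` cannot be passed along. The helper name `is_picklable` below is hypothetical and exists only for this sketch.

```python
import math
import pickle
import threading


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


config = {"lr": 0.001, "batch_size": 32}   # plain data pickles fine
named_fn = math.sqrt                        # importable functions pickle by reference
anonymous_fn = lambda batch: batch          # lambdas have no importable name
lock = threading.Lock()                     # OS-level handles cannot be pickled

results = {
    "config": is_picklable(config),
    "named_fn": is_picklable(named_fn),
    "anonymous_fn": is_picklable(anonymous_fn),
    "lock": is_picklable(lock),
}
print(results)
```

Lambdas used as collate functions and objects holding locks or open handles are typical sources of this failure.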
- Create the torch profiler as you like and pass it to the trainer.
import torch

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)
prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
- Run TensorBoard.
tensorboard --logdir="./profiler/"
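To make the schedule arguments used above concrete, here is a small pure-Python sketch of the wait/warmup/active cycle. It is an illustrative re-implementation, not PyTorch's code: the function name `profiler_action` and the string labels are hypothetical, and it ignores details such as `skip_first` and the distinction between recording and saving a trace.

```python
def profiler_action(step, wait, warmup, active, repeat):
    """Return "idle", "warmup", or "record" for a 0-indexed profiler step."""
    cycle = wait + warmup + active
    if repeat > 0 and step >= cycle * repeat:
        return "idle"  # all requested cycles are finished
    pos = step % cycle
    if pos < wait:
        return "idle"
    if pos < wait + warmup:
        return "warmup"
    return "record"


# Same values as the schedule in the profiler configuration.
actions = [
    profiler_action(s, wait=1, warmup=1, active=3, repeat=2) for s in range(12)
]
recorded_steps = [s for s, a in enumerate(actions) if a == "record"]
print(recorded_steps)  # [2, 3, 4, 7, 8, 9]
```

With wait=1, warmup=1, active=3, repeat=2, each cycle spans five steps, and only the last three steps of each of the two cycles are recorded.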
- TensorBoard - actively maintained
- ClearML - actively maintained
- MLFlow
- Aim
- WandB
To add a new logger, subclass BaseDashboardLogger and override its functions.