This is a ready-to-go PyTorch template utilizing Lightning and wandb.
It uses Lightning CLI for config management, following most of the Lightning CLI docs while integrating wandb.
Since Lightning CLI instantiates classes on the go, some workarounds were needed to integrate WandbLogger into the template.
This might not be the best practice, but it works and is quite convenient.
It uses Lightning CLI, so most of its usage can be found in the official docs.
There are some added arguments related to wandb:

- `--name` or `-n`: Name of the run, displayed in wandb
- `--version` or `-v`: Version of the run, displayed in wandb as tags
Basic cmdline usage is as follows (we assume the cwd is the project root dir):

```bash
python src/main.py fit -c configs/config.yaml -n debug-fit-run -v debug-version
```
If using wandb for logging, change the `project` key in `cli_module/rich_wandb.py`.
If you want to access the log directory in your LightningModule, you can do so as follows:

```python
log_root_dir = self.logger.log_dir or self.logger.save_dir
```
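The `or` fallback is needed because some loggers report no `log_dir` and expose `save_dir` instead. A minimal sketch of that logic, using a stand-in logger class (not the real WandbLogger) purely for illustration:

```python
# Stand-in for a logger whose log_dir is unset (as assumed for the
# wandb-based logger in this template), so save_dir is used instead.
class FakeLogger:
    log_dir = None
    save_dir = "logs/my-run/v1"

logger = FakeLogger()
# Falls back to save_dir whenever log_dir is falsy (None here).
log_root_dir = logger.log_dir or logger.save_dir
print(log_root_dir)
```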
If using wandb for logging, model ckpt files are uploaded to wandb.
Since the ckpt files are large, a clean-up process is needed: it deletes all model ckpt artifacts that have no aliases (e.g. `best`, `latest`).
To toggle off the clean-up process, add the following to `config.yaml`. Then every version of the model ckpt files will be saved to wandb.

```yaml
trainer:
  logger:
    init_args:
      clean: false
```
One can save model checkpoints using Lightning Callbacks.
A checkpoint contains the model weights and other state_dicts needed to resume training.
There are several ways to save ckpt files, either locally or in the cloud:

- Leave everything at the default and ckpt files will be saved locally (at `logs/${name}/${version}/fit/checkpoints`).
- If you want to save ckpt files as wandb Artifacts, add the following config (the ckpt files will be saved locally too):

```yaml
trainer:
  logger:
    init_args:
      log_model: all
```
- If you want to save ckpt files in the cloud rather than locally, change the save path by adding the following config (the ckpt files will NOT be saved locally):

```yaml
model_ckpt:
  dirpath: gs://bucket_name/path/for/checkpoints
```
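The `dirpath` above can be combined with other standard `ModelCheckpoint` arguments. A sketch under the assumption that the template's `model_ckpt` key maps to `ModelCheckpoint` init args; the metric name `val_loss` is an illustrative placeholder for whatever your module logs:

```yaml
model_ckpt:
  dirpath: gs://bucket_name/path/for/checkpoints
  monitor: val_loss     # placeholder: use a metric your module actually logs
  mode: min             # "min" because lower val_loss is better
  save_top_k: 3         # keep only the 3 best checkpoints
```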
You can enable asynchronous checkpoint saving by providing the config as follows:

```yaml
trainer:
  plugins:
    - AsyncCheckpointIO
```
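If the short class name is not resolved by the CLI, the fully qualified class path can be used instead (a sketch assuming the `lightning.pytorch` package layout):

```yaml
trainer:
  plugins:
    - class_path: lightning.pytorch.plugins.io.AsyncCheckpointIO
```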
Just add the `BatchSizeFinder` callback in the config:

```yaml
trainer:
  callbacks:
    - class_path: BatchSizeFinder
```

Or add it on the cmdline:

```bash
python src/main.py fit -c configs/config.yaml --trainer.callbacks+=BatchSizeFinder
```
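`BatchSizeFinder` also accepts tuning options via `init_args`. A sketch using standard `BatchSizeFinder` arguments; the values shown are illustrative, not recommendations:

```yaml
trainer:
  callbacks:
    - class_path: BatchSizeFinder
      init_args:
        mode: binsearch   # "power" (default) doubles the batch size; "binsearch" refines it
        init_val: 2       # batch size to start the search from
        max_trials: 10    # upper bound on search steps
```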
```bash
python src/tune.py -c configs/config.yaml
```

NOTE: no subcommand in the cmdline.
Basically all logs are stored in `logs/${name}/${version}/${job_type}`, where `${name}` and `${version}` are configured in the yaml file or on the cmdline.
`${job_type}` can be one of `fit`, `test`, `validate`, etc.
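For example, with `-n my-run -v v1`, runs of different job types land in sibling directories (an illustrative layout, not an exhaustive listing):

```
logs/
└── my-run/
    └── v1/
        ├── fit/
        │   └── checkpoints/
        ├── validate/
        └── test/
```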
```bash
python src/main.py test -c configs/config.yaml -n debug-test-run -v debug-version --ckpt_path YOUR_CKPT_PATH
```
- Check pretrained weight loading
- Consider multiple-optimizer use cases (e.g. GANs)
- Add instructions in README (on-going)
- Clean code