A repo on federated learning for research purposes. The implementation is based on PyTorch and uses Ray to accelerate training across multiple GPUs.
Currently supports multiclass image classification tasks.
Including baselines:
- fed-avg
- *scaffold
- *fedprox
- *mime
- *feddyn ...
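For reference, the aggregation step at the heart of fed-avg can be sketched as an average of client parameters, weighted by each client's number of local samples. This is a minimal sketch, assuming parameters are stored as plain Python lists keyed by name (a stand-in for PyTorch `state_dict`s). `fedavg_aggregate` is a hypothetical helper, not a function from this repo.

```python
from typing import Dict, List

# Hypothetical representation: parameter name -> flat list of floats
# (a stand-in for the tensors in a PyTorch state_dict).
Params = Dict[str, List[float]]


def fedavg_aggregate(client_params: List[Params], num_samples: List[int]) -> Params:
    """Average client parameters, weighting each client by its local sample count."""
    if not client_params or len(client_params) != len(num_samples):
        raise ValueError("need at least one client and one sample count per client")
    total = sum(num_samples)
    averaged: Params = {}
    for name in client_params[0]:
        length = len(client_params[0][name])
        averaged[name] = [
            sum(p[name][i] * n for p, n in zip(client_params, num_samples)) / total
            for i in range(length)
        ]
    return averaged


clients = [{"w": [1.0, 2.0]}, {"w": [3.0, 6.0]}]
print(fedavg_aggregate(clients, [1, 3]))  # {'w': [2.5, 5.0]}
```

Clients with more data pull the global model harder toward their local solution; with equal sample counts this reduces to a plain mean.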
Including models:
- MLP
- LeNet-5
- ResNet
- *VGG
- *AlexNet ...
Including datasets:
- cifar10
- mnist
- *emnist
- *cifar100
- *imagenet
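Each dataset has to be split into per-client local datasets before training. The sketch below shows one simple option, an IID split of sample indices; it is an assumption for illustration only, since the repo may partition data differently (e.g. non-IID by label). `iid_partition` is a hypothetical helper, and the resulting index lists could be wrapped with `torch.utils.data.Subset`.

```python
import random
from typing import List


def iid_partition(num_samples: int, num_clients: int, seed: int = 0) -> List[List[int]]:
    """Shuffle sample indices and deal them out into near-equal client shards."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Striding by num_clients gives shard sizes that differ by at most one.
    return [indices[k::num_clients] for k in range(num_clients)]


shards = iid_partition(num_samples=10, num_clients=3)
print([len(s) for s in shards])  # [4, 3, 3]
```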
Structure of the code:
- load the configuration
- prepare the local datasets
- prepare the logger and the local objective function
- run FL
- save the model
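The five stages above can be sketched as a single entry point. This is a minimal sketch under stated assumptions: the configuration is assumed to be a JSON file with hypothetical keys `num_samples`, `num_clients` and `rounds`, the training loop is a placeholder, and the model is saved as JSON instead of with `torch.save`. None of these function names come from the repo.

```python
import json
import logging
import os
import tempfile
from typing import Dict, List


def load_config(path: str) -> Dict:
    """Step 1: read the experiment configuration (assumed JSON here)."""
    with open(path) as f:
        return json.load(f)


def prepare_local_datasets(config: Dict) -> List[List[int]]:
    """Step 2: give each client a round-robin shard of sample indices."""
    n, k = config["num_samples"], config["num_clients"]
    return [list(range(c, n, k)) for c in range(k)]


def run_fl(config: Dict, local_datasets: List[List[int]], logger: logging.Logger) -> Dict[str, int]:
    """Step 4: placeholder loop; a real run would train clients and aggregate."""
    model = {"rounds_completed": 0}
    for r in range(config["rounds"]):
        model["rounds_completed"] += 1
        logger.info("finished round %d on %d clients", r, len(local_datasets))
    return model


def main(config_path: str, output_path: str) -> Dict[str, int]:
    config = load_config(config_path)
    local_datasets = prepare_local_datasets(config)
    # Step 3: set up the logger (the local objective function would be built here too).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("fl")
    model = run_fl(config, local_datasets, logger)
    # Step 5: persist the final model (JSON stand-in for torch.save).
    with open(output_path, "w") as f:
        json.dump(model, f)
    return model


with tempfile.TemporaryDirectory() as tmp:
    cfg_path = os.path.join(tmp, "config.json")
    with open(cfg_path, "w") as f:
        json.dump({"num_samples": 10, "num_clients": 3, "rounds": 2}, f)
    print(main(cfg_path, os.path.join(tmp, "model.json")))  # {'rounds_completed': 2}
```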
To adapt the current template to your own algorithm, implement the following five functions:
- server_init
- client_init
- clients_step
- server_step
- clients_update
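To illustrate how the five functions might fit together in one training loop, here is a toy sketch. It assumes a single scalar model weight, client data reduced to a scalar target, and the local objective (w - target)^2; the signatures are hypothetical and may not match the ones the template expects, and fed-avg style averaging is used as the example algorithm.

```python
from typing import Dict, List


def server_init(config: Dict) -> Dict:
    """Create the global model (here a single scalar weight)."""
    return {"w": 0.0}


def client_init(config: Dict, targets: List[float]) -> List[Dict]:
    """Create one state per client; each client's 'data' is a scalar target."""
    return [{"w": 0.0, "target": t} for t in targets]


def clients_step(clients: List[Dict], config: Dict) -> None:
    """Local training: gradient descent on the local objective (w - target)^2."""
    for c in clients:
        for _ in range(config["local_steps"]):
            grad = 2.0 * (c["w"] - c["target"])
            c["w"] -= config["lr"] * grad


def server_step(server: Dict, clients: List[Dict]) -> None:
    """Aggregate: unweighted average of client weights (fed-avg with equal-sized clients)."""
    server["w"] = sum(c["w"] for c in clients) / len(clients)


def clients_update(server: Dict, clients: List[Dict]) -> None:
    """Broadcast the new global model back to every client."""
    for c in clients:
        c["w"] = server["w"]


def run(config: Dict, targets: List[float]) -> float:
    server = server_init(config)
    clients = client_init(config, targets)
    clients_update(server, clients)  # start every client from the global model
    for _ in range(config["rounds"]):
        clients_step(clients, config)
        server_step(server, clients)
        clients_update(server, clients)
    return server["w"]


# With quadratic local objectives the global weight converges to the mean target.
print(round(run({"lr": 0.1, "local_steps": 5, "rounds": 20}, [1.0, 3.0]), 4))  # 2.0
```

Keeping the per-round order fixed (local steps, aggregation, broadcast) means an algorithm such as scaffold or fedprox would mostly change the bodies of `clients_step` and `server_step`, not the loop itself.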