forked from QinbinLi/MOON
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
2,190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Model-Contrastive Federated Learning | ||
This is the code for paper "Model-Contrastive Federated Learning". | ||
|
||
**Abstract**: Federated learning enables multiple parties to collaboratively train a machine learning model without communicating their local data. A key challenge in federated learning is to handle the heterogeneity of local data distribution across parties. Although many studies have been proposed to address this challenge, we find that they fail to achieve high performance in image datasets with deep learning models. In this paper, we propose MOON: model-contrastive federated learning. MOON is a simple and effective federated learning framework. The key idea of MOON is to utilize the similarity between model representations to correct the local training of individual parties, i.e., conducting contrastive learning in model-level. Our extensive experiments show that MOON significantly outperforms the other state-of-the-art federated learning algorithms on various image classification tasks. | ||
|
||
## Dependencies | ||
* PyTorch 1.0.0 | ||
* torchvision 0.2.1 | ||
* scikit-learn >= 0.23.1 | ||
|
||
|
||
|
||
## Parameters | ||
|
||
| Parameter | Description | | ||
| ----------------------------- | ---------------------------------------- | | ||
| `model` | The model architecture. Options: `simple-cnn`, `resnet50` .| | ||
| `alg` | The training algorithm. Options: `moon`, `fedavg`, `fedprox`, `local_training` | | ||
| `dataset` | Dataset to use. Options: `cifar10`. `cifar100`, `tinyimagenet`| | ||
| `lr` | Learning rate. | | ||
| `batch-size` | Batch size. | | ||
| `epochs` | Number of local epochs. | | ||
| `n_parties` | Number of parties. | | ||
| `sample_fraction` | the fraction of parties to be sampled in each round. | | ||
| `comm_round` | Number of communication rounds. | | ||
| `partition` | The partition approach. Options: `noniid`, `iid`. | | ||
| `beta` | The concentration parameter of the Dirichlet distribution for non-IID partition. | | ||
| `mu` | The parameter for MOON and FedProx. | | ||
| `temperature` | The temperature parameter for MOON. | | ||
| `out_dim` | The output dimension of the projection head. | | ||
| `datadir` | The path of the dataset. | | ||
| `logdir` | The path to store the logs. | | ||
| `device` | Specify the device to run the program. | | ||
| `seed` | The initial seed. | | ||
|
||
|
||
## Usage | ||
|
||
Here is an example to run MOON on CIFAR-10 with a simple CNN: | ||
``` | ||
python main.py --dataset=cifar10 \ | ||
--model=simple-cnn \ | ||
--alg=moon \ | ||
--lr=0.01 \ | ||
--mu=5 \ | ||
--epochs=10 \ | ||
--comm_round=100 \ | ||
--n_parties=10 \ | ||
--partition=noniid \ | ||
--beta=0.5 \ | ||
--logdir='./logs/' \ | ||
--datadir='./data/' \ | ||
``` | ||
|
||
## Citation | ||
Please cite our paper if you find this code useful for your research. | ||
``` | ||
@inproceedings{li2021model, | ||
title={Model-Contrastive Federated Learning}, | ||
author={Qinbin Li and Bingsheng He and Dawn Song}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
year={2021}, | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import torch.utils.data as data | ||
from PIL import Image | ||
import numpy as np | ||
import torchvision | ||
from torchvision.datasets import MNIST, EMNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST, ImageFolder, DatasetFolder, utils | ||
|
||
import os | ||
import os.path | ||
import logging | ||
|
||
logging.basicConfig() | ||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') | ||
|
||
|
||
def mkdirs(dirpath): | ||
try: | ||
os.makedirs(dirpath) | ||
except Exception as _: | ||
pass | ||
|
||
|
||
|
||
class CIFAR10_truncated(data.Dataset): | ||
|
||
def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): | ||
|
||
self.root = root | ||
self.dataidxs = dataidxs | ||
self.train = train | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
self.download = download | ||
|
||
self.data, self.target = self.__build_truncated_dataset__() | ||
|
||
def __build_truncated_dataset__(self): | ||
|
||
cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) | ||
|
||
if torchvision.__version__ == '0.2.1': | ||
if self.train: | ||
data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels) | ||
else: | ||
data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels) | ||
else: | ||
data = cifar_dataobj.data | ||
target = np.array(cifar_dataobj.targets) | ||
|
||
if self.dataidxs is not None: | ||
data = data[self.dataidxs] | ||
target = target[self.dataidxs] | ||
|
||
return data, target | ||
|
||
def truncate_channel(self, index): | ||
for i in range(index.shape[0]): | ||
gs_index = index[i] | ||
self.data[gs_index, :, :, 1] = 0.0 | ||
self.data[gs_index, :, :, 2] = 0.0 | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (image, target) where target is index of the target class. | ||
""" | ||
img, target = self.data[index], self.target[index] | ||
# img = Image.fromarray(img) | ||
# print("cifar10 img:", img) | ||
# print("cifar10 target:", target) | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
|
||
return img, target | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
class CIFAR100_truncated(data.Dataset): | ||
|
||
def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False): | ||
|
||
self.root = root | ||
self.dataidxs = dataidxs | ||
self.train = train | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
self.download = download | ||
|
||
self.data, self.target = self.__build_truncated_dataset__() | ||
|
||
def __build_truncated_dataset__(self): | ||
|
||
cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download) | ||
|
||
if torchvision.__version__ == '0.2.1': | ||
if self.train: | ||
data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels) | ||
else: | ||
data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels) | ||
else: | ||
data = cifar_dataobj.data | ||
target = np.array(cifar_dataobj.targets) | ||
|
||
if self.dataidxs is not None: | ||
data = data[self.dataidxs] | ||
target = target[self.dataidxs] | ||
|
||
return data, target | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (image, target) where target is index of the target class. | ||
""" | ||
img, target = self.data[index], self.target[index] | ||
img = Image.fromarray(img) | ||
# print("cifar10 img:", img) | ||
# print("cifar10 target:", target) | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
|
||
return img, target | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
|
||
|
||
class ImageFolder_custom(DatasetFolder): | ||
def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None): | ||
self.root = root | ||
self.dataidxs = dataidxs | ||
self.train = train | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
|
||
imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform) | ||
self.loader = imagefolder_obj.loader | ||
if self.dataidxs is not None: | ||
self.samples = np.array(imagefolder_obj.samples)[self.dataidxs] | ||
else: | ||
self.samples = np.array(imagefolder_obj.samples) | ||
|
||
def __getitem__(self, index): | ||
path = self.samples[index][0] | ||
target = self.samples[index][1] | ||
target = int(target) | ||
sample = self.loader(path) | ||
if self.transform is not None: | ||
sample = self.transform(sample) | ||
if self.target_transform is not None: | ||
target = self.target_transform(target) | ||
|
||
return sample, target | ||
|
||
def __len__(self): | ||
if self.dataidxs is None: | ||
return len(self.samples) | ||
else: | ||
return len(self.dataidxs) |
Oops, something went wrong.