Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Mar 28, 2021
1 parent 8d3a952 commit 78e2c3b
Show file tree
Hide file tree
Showing 7 changed files with 2,190 additions and 0 deletions.
64 changes: 64 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Model-Contrastive Federated Learning
This is the code for paper "Model-Contrastive Federated Learning".

**Abstract**: Federated learning enables multiple parties to collaboratively train a machine learning model without communicating their local data. A key challenge in federated learning is to handle the heterogeneity of local data distribution across parties. Although many studies have been proposed to address this challenge, we find that they fail to achieve high performance in image datasets with deep learning models. In this paper, we propose MOON: model-contrastive federated learning. MOON is a simple and effective federated learning framework. The key idea of MOON is to utilize the similarity between model representations to correct the local training of individual parties, i.e., conducting contrastive learning in model-level. Our extensive experiments show that MOON significantly outperforms the other state-of-the-art federated learning algorithms on various image classification tasks.

## Dependencies
* PyTorch 1.0.0
* torchvision 0.2.1
* scikit-learn >= 0.23.1



## Parameters

| Parameter | Description |
| ----------------------------- | ---------------------------------------- |
| `model` | The model architecture. Options: `simple-cnn`, `resnet50` .|
| `alg` | The training algorithm. Options: `moon`, `fedavg`, `fedprox`, `local_training` |
| `dataset` | Dataset to use. Options: `cifar10`. `cifar100`, `tinyimagenet`|
| `lr` | Learning rate. |
| `batch-size` | Batch size. |
| `epochs` | Number of local epochs. |
| `n_parties` | Number of parties. |
| `sample_fraction` | the fraction of parties to be sampled in each round. |
| `comm_round` | Number of communication rounds. |
| `partition` | The partition approach. Options: `noniid`, `iid`. |
| `beta` | The concentration parameter of the Dirichlet distribution for non-IID partition. |
| `mu` | The parameter for MOON and FedProx. |
| `temperature` | The temperature parameter for MOON. |
| `out_dim` | The output dimension of the projection head. |
| `datadir` | The path of the dataset. |
| `logdir` | The path to store the logs. |
| `device` | Specify the device to run the program. |
| `seed` | The initial seed. |


## Usage

Here is an example to run MOON on CIFAR-10 with a simple CNN:
```
python main.py --dataset=cifar10 \
--model=simple-cnn \
--alg=moon \
--lr=0.01 \
--mu=5 \
--epochs=10 \
--comm_round=100 \
--n_parties=10 \
--partition=noniid \
--beta=0.5 \
--logdir='./logs/' \
--datadir='./data/' \
```

## Citation
Please cite our paper if you find this code useful for your research.
```
@inproceedings{li2021model,
title={Model-Contrastive Federated Learning},
author={Qinbin Li and Bingsheng He and Dawn Song},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2021},
}
```
179 changes: 179 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import torch.utils.data as data
from PIL import Image
import numpy as np
import torchvision
from torchvision.datasets import MNIST, EMNIST, CIFAR10, CIFAR100, SVHN, FashionMNIST, ImageFolder, DatasetFolder, utils

import os
import os.path
import logging

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def mkdirs(dirpath):
try:
os.makedirs(dirpath)
except Exception as _:
pass



class CIFAR10_truncated(data.Dataset):

def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

self.root = root
self.dataidxs = dataidxs
self.train = train
self.transform = transform
self.target_transform = target_transform
self.download = download

self.data, self.target = self.__build_truncated_dataset__()

def __build_truncated_dataset__(self):

cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

if torchvision.__version__ == '0.2.1':
if self.train:
data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
else:
data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
else:
data = cifar_dataobj.data
target = np.array(cifar_dataobj.targets)

if self.dataidxs is not None:
data = data[self.dataidxs]
target = target[self.dataidxs]

return data, target

def truncate_channel(self, index):
for i in range(index.shape[0]):
gs_index = index[i]
self.data[gs_index, :, :, 1] = 0.0
self.data[gs_index, :, :, 2] = 0.0

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.target[index]
# img = Image.fromarray(img)
# print("cifar10 img:", img)
# print("cifar10 target:", target)

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.data)


class CIFAR100_truncated(data.Dataset):

def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):

self.root = root
self.dataidxs = dataidxs
self.train = train
self.transform = transform
self.target_transform = target_transform
self.download = download

self.data, self.target = self.__build_truncated_dataset__()

def __build_truncated_dataset__(self):

cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download)

if torchvision.__version__ == '0.2.1':
if self.train:
data, target = cifar_dataobj.train_data, np.array(cifar_dataobj.train_labels)
else:
data, target = cifar_dataobj.test_data, np.array(cifar_dataobj.test_labels)
else:
data = cifar_dataobj.data
target = np.array(cifar_dataobj.targets)

if self.dataidxs is not None:
data = data[self.dataidxs]
target = target[self.dataidxs]

return data, target

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.target[index]
img = Image.fromarray(img)
# print("cifar10 img:", img)
# print("cifar10 target:", target)

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.data)




class ImageFolder_custom(DatasetFolder):
def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None):
self.root = root
self.dataidxs = dataidxs
self.train = train
self.transform = transform
self.target_transform = target_transform

imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform)
self.loader = imagefolder_obj.loader
if self.dataidxs is not None:
self.samples = np.array(imagefolder_obj.samples)[self.dataidxs]
else:
self.samples = np.array(imagefolder_obj.samples)

def __getitem__(self, index):
path = self.samples[index][0]
target = self.samples[index][1]
target = int(target)
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)

return sample, target

def __len__(self):
if self.dataidxs is None:
return len(self.samples)
else:
return len(self.dataidxs)
Loading

0 comments on commit 78e2c3b

Please sign in to comment.