#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import torch
from torchvision import datasets, transforms
import numpy as np
def get_dataset(args):
    """Returns the train and test datasets, plus a user-group dict that maps
    each user index to the data indices assigned to that user.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        train_transform = transforms.Compose(
            [transforms.ToTensor(),
             # transforms.RandomCrop(size=24),
             transforms.RandomApply(torch.nn.ModuleList([
                 transforms.ColorJitter(), ]), p=0.5),
             transforms.RandomAutocontrast(),
             transforms.RandomHorizontalFlip(),
             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261))])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             # transforms.RandomCrop(size=24),
             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                  (0.247, 0.243, 0.261))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=train_transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=test_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample non-IID user data from CIFAR-10
            if args.unequal:
                # Unequal splits for every user are not implemented
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
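
# Illustrative usage sketch (an assumption, not the project's actual entry
# point): `args` only needs the attributes read above; the Namespace values
# here are hypothetical.
#
#   from argparse import Namespace
#   args = Namespace(dataset='cifar', iid=1, unequal=0, num_users=100)
#   train_dataset, test_dataset, user_groups = get_dataset(args)
#   print(len(user_groups))  # one index collection per user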


def average_weights(w):
    """
    Returns the element-wise average of a list of model state_dicts.
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg
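
# Illustrative sketch (assumption): averaging the state_dicts of two clients
# that share the same architecture. `net_a`, `net_b`, and `global_model` are
# hypothetical stand-ins for whatever model the caller trains locally.
#
#   local_weights = [net_a.state_dict(), net_b.state_dict()]
#   global_weights = average_weights(local_weights)
#   global_model.load_state_dict(global_weights)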


def exp_details(args):
    print('\nExperimental details:')
    print(f'    Model         : {args.model}')
    print(f'    Optimizer     : {args.optimizer}')
    print(f'    Learning rate : {args.lr}')
    print(f'    Global Rounds : {args.epochs}\n')

    print('    Federated parameters:')
    if args.iid:
        print('    IID')
    else:
        print('    Non-IID')
    print(f'    Fraction of users : {args.frac}')
    print(f'    Local Batch size  : {args.local_bs}')
    print(f'    Local Epochs      : {args.local_ep}\n')
    return


def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from the CIFAR-10 dataset.
    :param dataset:
    :param num_users:
    :return: dict mapping each user index to a set of image indices
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items,
                                             replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users
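
# Illustrative sketch (assumption): with CIFAR-10's 50,000 training images and
# num_users=100, each client receives a disjoint random subset of 500 indices.
#
#   user_groups = cifar_iid(train_dataset, num_users=100)
#   assert len(user_groups[0]) == 50000 // 100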


def cifar_noniid(dataset, num_users):
    """
    Sample non-I.I.D. client data from the CIFAR-10 dataset.
    :param dataset:
    :param num_users:
    :return: dict mapping each user index to an array of image indices
    """
    # 200 shards of 250 images each cover the 50,000 training images
    num_shards, num_imgs = 200, 250
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    # labels = dataset.train_labels.numpy()
    labels = np.array(dataset.targets)

    # sort the image indices by label: stack indices and labels into two rows,
    # then sort the columns by the labels row
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign: each user gets 2 randomly chosen shards
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
    return dict_users
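
# Illustrative sketch (assumption): because each class has 5,000 images and
# each shard holds 250 label-sorted indices, every shard contains a single
# class; with num_users=100 all 200 shards are used once and each client ends
# up with 500 images drawn from at most two classes.
#
#   user_groups = cifar_noniid(train_dataset, num_users=100)
#   assert len(user_groups[0]) == 500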