forked from shenzebang/Federated-Learning-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi.py
162 lines (128 loc) · 7.03 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
from tqdm import trange
import ray
from utils.general_utils import _evaluate_ray, _evaluate, _acc_ray
class FedAlgorithm(object):
    """Base class for federated-learning algorithms.

    Orchestrates the round structure (client local steps -> server
    aggregation -> broadcast) and provides ray-based evaluation helpers.
    Subclasses implement ``server_init``, ``client_init``, ``clients_step``,
    ``server_step`` and ``clients_update``.
    """

    def __init__(self,
                 init_model,
                 client_dataloaders,
                 client_dataloaders_test,
                 loss,
                 loggers,
                 config,
                 device
                 ):
        """
        Args:
            init_model: initial global model passed to ``server_init``.
            client_dataloaders: per-client training dataloaders.
            client_dataloaders_test: per-client test dataloaders.
            loss: loss function used for training/evaluation.
            loggers: iterable of loggers exposing ``log(round, model)``,
                or None to disable logging.
            config: run configuration; must provide ``n_workers``,
                ``n_workers_per_round``, ``n_global_rounds``, ``eval_freq``
                and ``use_ray``.
            device: device used by the evaluation helpers.
        """
        self.client_dataloaders = client_dataloaders
        self.client_dataloaders_test = client_dataloaders_test
        self.loss = loss
        self.loggers = loggers
        self.config = config
        self.device = device
        self.server_state = self.server_init(init_model)
        # One state per client, built from its train AND test dataloader.
        self.client_states = [
            self.client_init(self.server_state, train_loader, test_loader)
            for train_loader, test_loader in zip(self.client_dataloaders,
                                                 self.client_dataloaders_test)
        ]

    def step(self, server_state, client_states, weights):
        """Run one communication round and return the updated states.

        server_state contains the (global) model and (global) auxiliary
        variables; client_states contain the (local) auxiliary variables.
        """
        # Sample this round's active clients uniformly at random.
        active_ids = torch.randperm(self.config.n_workers)[:self.config.n_workers_per_round].tolist()
        # Local updates on the active clients.
        client_states = self.clients_step(client_states, weights, active_ids)
        # Aggregate local results into the server state.
        server_state = self.server_step(server_state, client_states, weights, active_ids)
        # Broadcast the new server state back to the active clients.
        client_states = self.clients_update(server_state, client_states, active_ids)
        return server_state, client_states

    def fit(self, weights, n_rounds=None):
        """Run ``n_rounds`` rounds (default ``config.n_global_rounds``).

        A tqdm progress bar is shown only for the default full-length run.
        """
        if n_rounds is None:
            n_rounds = self.config.n_global_rounds
            _range = trange
        else:
            _range = range
        # ``round_idx`` (not ``round``) to avoid shadowing the builtin.
        for round_idx in _range(n_rounds):
            self.server_state, self.client_states = self.step(self.server_state, self.client_states, weights)
            if round_idx % self.config.eval_freq == 0 and self.loggers is not None:
                for logger in self.loggers:
                    logger.log(round_idx, self.server_state.model)

    def server_init(self, init_model):
        """Build and return the initial server state. Must be overridden."""
        raise NotImplementedError

    def client_init(self, server_state, client_dataloader, client_dataloader_test):
        """Build and return one client's initial state. Must be overridden.

        Signature fixed to match the call in ``__init__``, which passes both
        the train and the test dataloader.
        """
        raise NotImplementedError

    def clients_step(self, clients_state, weights, active_ids):
        """Run local updates for the active clients. Must be overridden."""
        raise NotImplementedError

    def server_step(self, server_state, client_states, weights, active_ids):
        """Aggregate client states into a new server state. Must be overridden."""
        raise NotImplementedError

    def clients_update(self, server_state, clients_state, active_ids):
        """Broadcast the server state to clients. Must be overridden."""
        raise NotImplementedError

    def _evaluate_loaders(self, loss_loaders, acc_loaders):
        """Evaluate the global model: losses on ``loss_loaders`` and
        accuracies on ``acc_loaders``, via parallel ray tasks.

        Raises NotImplementedError when ``config.use_ray`` is False — the
        serial path is not implemented. (The original code had
        ``raise not NotImplementedError``, which raises the bool ``False``
        and therefore a TypeError; fixed to raise the exception class.)
        """
        if not self.config.use_ray:
            raise NotImplementedError
        losses = ray.get([_evaluate_ray.remote(self.loss, self.device, self.server_state.model, loader)
                          for loader in loss_loaders])
        accs = ray.get([_acc_ray.remote(self.device, self.server_state.model, loader)
                        for loader in acc_loaders])
        return losses, accs

    def clients_evaluate(self, active_ids=None):
        """Return (train losses, test accuracies) for the given clients.

        Defaults to all clients when ``active_ids`` is None.
        """
        if active_ids is None:
            active_ids = list(range(len(self.client_states)))
        return self._evaluate_loaders(
            [self.client_dataloaders[i] for i in active_ids],
            [self.client_dataloaders_test[i] for i in active_ids])

    def clients_evaluate_train(self, active_ids=None):
        """Return (losses, accuracies) on the TRAIN dataloaders."""
        if active_ids is None:
            active_ids = list(range(len(self.client_states)))
        train_loaders = [self.client_dataloaders[i] for i in active_ids]
        return self._evaluate_loaders(train_loaders, train_loaders)

    def clients_evaluate_test(self, active_ids=None):
        """Return (losses, accuracies) on the TEST dataloaders."""
        if active_ids is None:
            active_ids = list(range(len(self.client_states)))
        test_loaders = [self.client_dataloaders_test[i] for i in active_ids]
        return self._evaluate_loaders(test_loaders, test_loaders)
class PrimalDualFedAlgorithm(object):
    """Wraps a ``FedAlgorithm`` and augments it with dual-variable updates.

    The wrapped primal algorithm updates the primal variable; subclasses
    implement ``step`` (one primal-dual round) and ``server_init``.
    """

    def __init__(self, primal_fed_algorithm: FedAlgorithm, config, loggers=None, auxiliary_data=None):
        """
        Args:
            primal_fed_algorithm: inner algorithm used for the primal updates.
            config: run configuration; must provide ``n_pd_rounds`` and
                ``n_p_steps``.
            loggers: loggers recording testing metrics of the current model,
                or None to disable logging.
            auxiliary_data: optional extra data kept for subclasses.
        """
        self.config = config
        self.loggers = loggers
        self.primal_fed_algorithm = primal_fed_algorithm
        # server_state holds both the primal and the dual variables.
        self.server_state = self.server_init()
        self.auxiliary_data = auxiliary_data

    def fit(self):
        """Run the primal-dual outer loop for ``config.n_pd_rounds`` rounds."""
        for pd_round in trange(self.config.n_pd_rounds):
            self.step()
            if self.loggers is None:
                continue
            # Scale the logged round index by the number of primal steps
            # taken per outer round.
            for logger in self.loggers:
                logger.log(pd_round * self.config.n_p_steps, self.server_state.model)

    def step(self):
        """Update ``self.server_state`` by one primal-dual round. Must be overridden."""
        raise NotImplementedError

    def server_init(self):
        """Build the initial server state via ``self.primal_fed_algorithm``. Must be overridden."""
        raise NotImplementedError