import random

import networkx as nx
import numpy as np
import torch
from dtaidistance import dtw


class Server:
    def __init__(self, model, device):
        self.model = model.to(device)
        # Keep live references to the global model's parameters, keyed by name.
        self.W = {key: value for key, value in self.model.named_parameters()}
        self.model_cache = []

    def randomSample_clients(self, all_clients, frac):
        # Uniformly sample a fraction `frac` of all clients without replacement.
        return random.sample(all_clients, int(len(all_clients) * frac))

    def aggregate_weights(self, selected_clients):
        # FedAvg-style aggregation: average the clients' weights, weighted by
        # each client's training-set size.
        total_size = sum(client.train_size for client in selected_clients)
        for k in self.W.keys():
            self.W[k].data = torch.div(
                torch.sum(torch.stack([torch.mul(client.W[k].data, client.train_size)
                                       for client in selected_clients]), dim=0),
                total_size).clone()

    def compute_pairwise_similarities(self, clients):
        # Collect every client's weight update dW and return the pairwise
        # cosine-similarity matrix of the flattened updates.
        client_dWs = []
        for client in clients:
            dW = {k: client.dW[k] for k in self.W.keys()}
            client_dWs.append(dW)
        return pairwise_angles(client_dWs)

    def compute_pairwise_distances(self, seqs, standardize=False):
        """Compute the pairwise DTW distance matrix between sequences."""
        if standardize:
            # Divide each sequence by its own standard deviation so DTW
            # compares trends rather than absolute scale.
            seqs = np.array(seqs)
            seqs = seqs / seqs.std(axis=1).reshape(-1, 1)
        return dtw.distance_matrix(seqs)
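
    # Hypothetical usage sketch (not in the original file): given per-round
    # loss curves for three clients, the standardized DTW matrix groups the
    # two clients whose losses follow the same trend at different scales.
    #
    #   seqs = [[0.9, 0.5, 0.3], [1.8, 1.0, 0.6], [0.2, 0.7, 1.1]]
    #   D = server.compute_pairwise_distances(seqs, standardize=True)
    #   # D is a symmetric 3x3 matrix: seqs[1] is 2x seqs[0], so D[0, 1] is
    #   # ~0 after standardization, while D[0, 2] and D[1, 2] stay large.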

    def min_cut(self, similarity, idc):
        # Build a complete weighted graph over the clients (no self-loops) and
        # bisect it with the Stoer-Wagner minimum cut; map the two partitions
        # back to the original client indices `idc`.
        g = nx.Graph()
        for i in range(len(similarity)):
            for j in range(i + 1, len(similarity)):
                g.add_edge(i, j, weight=similarity[i][j])
        cut, partition = nx.stoer_wagner(g)
        c1 = np.array([idc[x] for x in partition[0]])
        c2 = np.array([idc[x] for x in partition[1]])
        return c1, c2
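
    # Hypothetical illustration (values invented): a block-structured
    # similarity matrix over clients [0, 1, 2, 3] splits along the blocks.
    #
    #   sim = np.array([[2.0, 1.9, 0.1, 0.1],
    #                   [1.9, 2.0, 0.1, 0.1],
    #                   [0.1, 0.1, 2.0, 1.9],
    #                   [0.1, 0.1, 1.9, 2.0]])
    #   c1, c2 = server.min_cut(sim, idc=[0, 1, 2, 3])
    #   # -> one cluster holds clients {0, 1}, the other {2, 3}.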

    def aggregate_clusterwise(self, client_clusters):
        # Within each cluster, add the train-size-weighted average of the
        # clients' updates dW to every client's weights W.
        for cluster in client_clusters:
            targets = []
            sources = []
            total_size = 0
            for client in cluster:
                W = {k: client.W[k] for k in self.W.keys()}
                dW = {k: client.dW[k] for k in self.W.keys()}
                targets.append(W)
                sources.append((dW, client.train_size))
                total_size += client.train_size
            reduce_add_average(targets=targets, sources=sources, total_size=total_size)

    def compute_max_update_norm(self, cluster):
        # Largest L2 norm of any single client's update within the cluster.
        max_dW = -np.inf
        for client in cluster:
            dW = {k: client.dW[k] for k in self.W.keys()}
            update_norm = torch.norm(flatten(dW)).item()
            max_dW = max(max_dW, update_norm)
        return max_dW

    def compute_mean_update_norm(self, cluster):
        # L2 norm of the cluster's mean update (averaged over its clients).
        cluster_dWs = []
        for client in cluster:
            dW = {k: client.dW[k] for k in self.W.keys()}
            cluster_dWs.append(flatten(dW))
        return torch.norm(torch.mean(torch.stack(cluster_dWs), dim=0)).item()
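
    # In clustered federated learning these two norms are typically compared
    # against thresholds to decide whether a cluster should be bisected;
    # EPS_1 and EPS_2 below are hypothetical hyperparameters:
    #
    #   if (server.compute_mean_update_norm(cluster) < EPS_1
    #           and server.compute_max_update_norm(cluster) > EPS_2):
    #       similarities = server.compute_pairwise_similarities(cluster)
    #       c1, c2 = server.min_cut(similarities, idc)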

    def cache_model(self, idcs, params, accuracies):
        # Snapshot the participating client indices, a copy of the parameters,
        # and the corresponding accuracies.
        self.model_cache += [(idcs,
                              {name: params[name].data.clone() for name in params},
                              [accuracies[i] for i in idcs])]


def flatten(source):
    # Concatenate all tensors of a name -> tensor dict into one 1-D tensor.
    return torch.cat([value.flatten() for value in source.values()])


def pairwise_angles(sources):
    # Pairwise cosine similarity between flattened updates, shifted by +1 into
    # [0, 2] so the resulting min-cut edge weights are non-negative.
    angles = torch.zeros([len(sources), len(sources)])
    for i, source1 in enumerate(sources):
        for j, source2 in enumerate(sources):
            s1 = flatten(source1)
            s2 = flatten(source2)
            angles[i, j] = torch.true_divide(
                torch.sum(s1 * s2),
                max(torch.norm(s1) * torch.norm(s2), 1e-12)) + 1
    return angles.numpy()


def reduce_add_average(targets, sources, total_size):
    # Add the train-size-weighted average of the updates (`sources`) to each
    # target's parameters in place.
    for target in targets:
        for name in target:
            tmp = torch.div(
                torch.sum(torch.stack([torch.mul(source[0][name].data, source[1])
                                       for source in sources]), dim=0),
                total_size).clone()
            target[name].data += tmp
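

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes a
# hypothetical client object exposing the attributes this Server relies on
# (W, dW, train_size); DummyClient and every name below it are invented for
# illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyClient:
        def __init__(self, model, train_size):
            # Copy of the current global weights.
            self.W = {k: v.data.clone() for k, v in model.named_parameters()}
            # Stand-in for a locally computed weight update.
            self.dW = {k: torch.randn_like(v) for k, v in self.W.items()}
            self.train_size = train_size

    model = torch.nn.Linear(4, 2)
    server = Server(model, device="cpu")
    clients = [DummyClient(model, train_size=10 * (i + 1)) for i in range(4)]

    # One communication round: sample clients and aggregate their weights.
    selected = server.randomSample_clients(clients, frac=1.0)
    server.aggregate_weights(selected)

    # Bi-partition the clients by the similarity of their updates ...
    similarities = server.compute_pairwise_similarities(selected)
    c1, c2 = server.min_cut(similarities, idc=list(range(len(selected))))
    print("cluster 1:", c1, "cluster 2:", c2)

    # ... and aggregate the updates within each cluster.
    clusters = [[selected[i] for i in c1], [selected[i] for i in c2]]
    server.aggregate_clusterwise(clusters)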