-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.py
124 lines (86 loc) · 2.84 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class Client:
    r"""Represents a client participating in the learning process.

    Attributes
    ----------
    client_id : int
        Unique identifier of the client.
    trainer : Trainer or None
        Object responsible for fitting/evaluating the local model.
    device : str or torch.device or None
        Device taken from ``trainer``; ``None`` when no trainer is given.
    train_loader : torch.utils.data.DataLoader
        Loader over the local training set (only set when provided).
    val_loader : torch.utils.data.DataLoader
    test_loader : torch.utils.data.DataLoader
    train_iterator : iterator
        Iterator over ``train_loader``; re-created when exhausted.
    num_samples : int
        Size of the local training dataset (only set when a loader is given).
    is_ready : bool
        True iff a ``train_loader`` was supplied at construction.
    local_steps : int
        Number of local epochs (or one batch, see ``step``) per round.
    metadata : dict
        Free-form per-client metadata.
    verbose : int
        Logs are written only when ``verbose > 0``.
    logger : torch.utils.tensorboard.SummaryWriter or None
        Destination for scalar logs; presumably a TensorBoard writer
        (``add_scalar``/``flush`` are the only methods used).
    counter : int
        Number of local update steps performed so far.
    """
    def __init__(
            self,
            client_id,
            local_steps,
            verbose,
            trainer=None,
            train_loader=None,
            val_loader=None,
            test_loader=None,
            logger=None
    ):
        self.client_id = client_id
        self.trainer = trainer
        # BUG FIX: the original read ``self.trainer.device`` unconditionally,
        # which raised AttributeError whenever the default ``trainer=None``
        # was used. Fall back to ``None`` when no trainer is supplied.
        self.device = trainer.device if trainer is not None else None

        if train_loader is not None:
            self.train_loader = train_loader
            self.val_loader = val_loader
            self.test_loader = test_loader
            self.num_samples = len(self.train_loader.dataset)
            self.train_iterator = iter(self.train_loader)
            self.is_ready = True
        else:
            # No data was attached; the client cannot train yet.
            self.is_ready = False

        self.local_steps = local_steps
        self.verbose = verbose
        self.logger = logger
        self.metadata = dict()
        self.counter = 0

    def get_next_batch(self):
        """Yield the next training batch, cycling over the loader.

        Returns
        -------
        Tuple(torch.tensor, torch.tensor)
            Batch given as a tuple of tensors.
        """
        try:
            batch = next(self.train_iterator)
        except StopIteration:
            # Loader exhausted: restart a fresh pass over the training data
            # so repeated calls cycle indefinitely.
            self.train_iterator = iter(self.train_loader)
            batch = next(self.train_iterator)
        return batch

    def step(self, by_batch=False):
        """Perform one local update step.

        Parameters
        ----------
        by_batch : bool
            If True, fit a single batch; otherwise fit ``local_steps``
            full epochs over ``train_loader``.

        Returns
        -------
        None
            (The docstring of the original claimed a loss was returned,
            but neither branch returns a value — callers must not rely
            on a return.)
        """
        self.counter += 1
        if by_batch:
            batch = self.get_next_batch()
            self.trainer.fit_batch(batch=batch)
        else:
            self.trainer.fit_epochs(
                loader=self.train_loader,
                n_epochs=self.local_steps
            )

    def write_logs(self, counter=None):
        """Evaluate on train/test loaders and optionally log the results.

        Parameters
        ----------
        counter : int or None
            x-axis value for the logged scalars; defaults to the number
            of steps performed so far (``self.counter``).

        Returns
        -------
        tuple of float
            ``(train_loss, train_metric, test_loss, test_metric)``.
        """
        if counter is None:
            counter = self.counter

        train_loss, train_metric = self.trainer.evaluate_loader(self.train_loader)
        test_loss, test_metric = self.trainer.evaluate_loader(self.test_loader)

        # Only emit to the logger when verbosity is enabled; evaluation
        # results are returned either way.
        if self.verbose > 0:
            self.logger.add_scalar("Train/Loss", train_loss, counter)
            self.logger.add_scalar("Train/Metric", train_metric, counter)
            self.logger.add_scalar("Test/Loss", test_loss, counter)
            self.logger.add_scalar("Test/Metric", test_metric, counter)
            self.logger.flush()

        return train_loss, train_metric, test_loss, test_metric