This repository was archived by the owner on Jun 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest_simulator.py
107 lines (82 loc) · 3.11 KB
/
test_simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# @Author : FederalLab
# @Date : 2021-09-25 16:57:22
# @Last Modified by : Chen Dengsheng
# @Last Modified time: 2021-09-25 16:57:22
# Copyright (c) FederalLab. All rights reserved.
import random
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import openfed
from openfed.data import IIDPartitioner, PartitionerDataset
def main_function(props):
    """Run one simulated federated node described by ``props``.

    ``props`` is passed to ``openfed.federated.FederatedProperties.load``; the
    tests in this module pass the path of a JSON file such as
    ``'/tmp/aggregator.json'``. The loaded result must contain exactly one
    entry, and its ``aggregator`` flag selects the role this process plays.

    Both roles use the same model: a single ``nn.Linear(784, 10)`` layer with
    cross-entropy loss.

    * Aggregator: runs ``rounds`` rounds through ``openfed.API`` using
      ``openfed.functional.average_aggregation``.
    * Collaborator: for each round, trains one pass over a randomly chosen
      IID partition of the MNIST training set, then packages the federated
      optimizer state and steps the maintainer with ``download=False``.

    NOTE(review): the aggregator counts ``world_size - 1`` participants, so
    presumably every role's process must run concurrently or the aggregator
    will wait indefinitely. Confirm against the openfed docs.

    Args:
        props: Anything ``FederatedProperties.load`` accepts; in this module,
            a path to a topology JSON file.
    """
    props = openfed.federated.FederatedProperties.load(props)
    assert len(props) == 1
    props = props[0]

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    network = nn.Linear(784, 10).to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # The aggregator uses lr=1.0 and collaborators lr=0.1. Presumably this
    # lets the aggregator apply the averaged update unscaled; confirm with
    # the openfed FederatedOptimizer docs.
    sgd = torch.optim.SGD(
        network.parameters(), lr=1.0 if props.aggregator else 0.1)
    fed_sgd = openfed.optim.FederatedOptimizer(sgd, props.role)

    # keep_vars=True hands the maintainer the live parameters rather than
    # detached copies.
    maintainer = openfed.core.Maintainer(props,
                                         network.state_dict(keep_vars=True))

    with maintainer:
        openfed.functional.device_alignment()
        if props.aggregator:
            # Every node except the aggregator itself is a participant.
            openfed.functional.count_step(props.address.world_size - 1)

    rounds = 1
    if maintainer.aggregator:
        api = openfed.API(maintainer, fed_sgd, rounds,
                          openfed.functional.average_aggregation)
        api.run()
        return

    # Collaborator path: local training on MNIST split into 100 IID parts.
    mnist = MNIST(r'/tmp/', True, ToTensor(), download=True)
    fed_mnist = PartitionerDataset(
        mnist, total_parts=100, partitioner=IIDPartitioner())
    dataloader = DataLoader(
        fed_mnist,
        batch_size=10,
        shuffle=True,
        num_workers=0,
        drop_last=False)

    for _ in range(rounds):
        # Presumably fetches the current global model (no upload yet).
        maintainer.step(upload=False)

        # NOTE(review): randint is inclusive, so only parts 0-9 of the 100
        # are ever sampled. Widen to 99 if all parts should be eligible.
        fed_mnist.set_part_id(random.randint(0, 9))

        network.train()
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # Flatten each sample to the 784 input features of the layer.
            output = network(x.view(-1, 784))
            loss = loss_fn(output, y)
            fed_sgd.zero_grad()
            loss.backward()
            fed_sgd.step()

        fed_sgd.round()
        maintainer.update_version()
        maintainer.package(fed_sgd)
        # Presumably uploads the packaged state without downloading again.
        maintainer.step(download=False)
        fed_sgd.clear_state_dict()
@pytest.mark.run(order=0)
def test_build_centralized_topology():
    """Generate a three-node centralized topology over TCP.

    Presumably this writes the per-role JSON files under ``/tmp`` that the
    simulator tests in this module load (aggregator, collaborator-1,
    collaborator-2). Confirm against ``openfed.tools.simulator``.
    """
    from openfed.tools.simulator import build_centralized_topology as build

    num_nodes = 3
    build(num_nodes, tcp=True)
@pytest.mark.run(order=3)
def test_simulator_aggregator():
    """Run the aggregator node of the simulated centralized topology.

    NOTE(review): this blocks until the collaborator nodes participate, so it
    presumably has to run concurrently with the collaborator tests (e.g. in
    separate pytest workers) rather than sequentially.
    """
    # Rebuild the topology files first so this test does not rely on the
    # order=0 test having run in the same process.
    test_build_centralized_topology()
    config_path = '/tmp/aggregator.json'
    main_function(props=config_path)
@pytest.mark.run(order=3)
def test_simulator_collaborator_alpha():
    """Run the first collaborator node of the simulated topology.

    NOTE(review): presumably must run concurrently with the aggregator test,
    which waits for every collaborator.
    """
    # Rebuild the topology files first so this test does not rely on the
    # order=0 test having run in the same process.
    test_build_centralized_topology()
    config_path = '/tmp/collaborator-1.json'
    main_function(props=config_path)
@pytest.mark.run(order=3)
def test_simulator_collaborator_beta():
    """Run the second collaborator node of the simulated topology.

    NOTE(review): presumably must run concurrently with the aggregator test,
    which waits for every collaborator.
    """
    # Rebuild the topology files first so this test does not rely on the
    # order=0 test having run in the same process.
    test_build_centralized_topology()
    config_path = '/tmp/collaborator-2.json'
    main_function(props=config_path)