forked from OpenMined/PySyft
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsandbox.py
168 lines (142 loc) · 5.34 KB
/
sandbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import importlib
import torch
from syft.frameworks.torch.hook.hook import TorchHook
from syft.workers.virtual import VirtualWorker
from syft.grid.private_grid import PrivateGridNetwork
from syft.exceptions import DependencyError
def create_sandbox(gbs, verbose=True, download_data=True):
    """Populate ``gbs`` with a ready-made PySyft playground.

    There's some boilerplate stuff that most people who are just playing
    around would like to have; this creates it for you:

    * hooks the torch module found in ``gbs`` with ``TorchHook``;
    * creates the virtual workers bob, theo, jason, alice, andy and jon;
    * optionally loads several scikit-learn toy datasets as tagged tensors
      and splits each of them across the workers;
    * collects the workers into a ``PrivateGridNetwork``.

    Args:
        gbs: Namespace dict to populate, typically the caller's ``globals()``.
            It must hold the torch module under ``"torch"`` or ``"th"``. On
            return it also holds ``"hook"``, one entry per worker name,
            ``"workers"`` (list of all workers) and ``"grid"``.
        verbose: Print a line for each setup step. "Setting up Sandbox..."
            and "Done!" are printed either way.
        download_data: Load, tag and distribute the scikit-learn datasets.
            Requires scikit-learn.

    Raises:
        KeyError: ``gbs`` contains neither ``"torch"`` nor ``"th"``.
        DependencyError: ``download_data`` is true but scikit-learn is not
            installed.
    """
    # The workers are still published as module globals of this file, as
    # before. ``hook`` deliberately is NOT: rebinding it would replace the
    # module-level ``hook()`` helper with the TorchHook instance, breaking
    # every later ``hook(gbs)`` call.
    global bob, theo, jason, alice, andy, jon

    # Import the submodule explicitly: ``import importlib`` alone does not
    # guarantee that ``importlib.util`` has been loaded.
    from importlib.util import find_spec

    try:
        torch = gbs["torch"]
    except KeyError:
        torch = gbs["th"]

    if download_data and find_spec("sklearn") is None:
        raise DependencyError("sklearn", "scikit-learn")

    print("Setting up Sandbox...")

    if verbose:
        print("\t- Hooking PyTorch")
    torch_hook = TorchHook(torch)

    if verbose:
        print("\t- Creating Virtual Workers:")
    worker_ids = ("bob", "theo", "jason", "alice", "andy", "jon")
    workers = []
    for worker_id in worker_ids:
        if verbose:
            print("\t\t- " + worker_id)
        workers.append(VirtualWorker(torch_hook, id=worker_id))
    bob, theo, jason, alice, andy, jon = workers

    if verbose:
        print("\tStoring hook and workers as global variables...")
    gbs["hook"] = torch_hook
    gbs.update(zip(worker_ids, workers))
    gbs["workers"] = list(workers)

    if download_data:  # pragma: no cover
        datasets = _load_sklearn_datasets(torch, verbose)
        if verbose:
            print("\tDistributing Datasets Amongst Workers...")
        for data, target in datasets:
            _distribute_dataset(data, workers)
            _distribute_dataset(target, workers)

    if verbose:
        print("\tCollecting workers into a VirtualGrid...")
    gbs["grid"] = PrivateGridNetwork(*workers)

    print("Done!")


def _load_sklearn_datasets(torch, verbose):  # pragma: no cover
    """Load the scikit-learn toy datasets as tagged ``(data, target)`` tensors.

    Returns a list of pairs in a fixed order: Boston Housing, diabetes,
    breast cancer, digits, iris, wine, Linnerud. Boston Housing is skipped
    (instead of aborting the whole sandbox) when the installed scikit-learn
    no longer provides ``load_boston``, which was removed in scikit-learn 1.2.
    """
    from sklearn.datasets import (
        load_breast_cancer,
        load_diabetes,
        load_digits,
        load_iris,
        load_linnerud,
        load_wine,
    )

    try:
        from sklearn.datasets import load_boston
    except ImportError:
        load_boston = None

    if verbose:
        print("\tLoading datasets from SciKit Learn...")

    # (progress label, loader, extra tags) -- order fixes distribution order.
    catalog = (
        ("Boston Housing Dataset", load_boston, ("#boston", "#housing", "#boston_housing")),
        ("Diabetes Dataset", load_diabetes, ("#diabetes",)),
        ("Breast Cancer Dataset", load_breast_cancer, ()),
        ("Digits Dataset", load_digits, ()),
        ("Iris Dataset", load_iris, ()),
        ("Wine Dataset", load_wine, ()),
        ("Linnerud Dataset", load_linnerud, ()),
    )

    pairs = []
    for label, loader, tags in catalog:
        if loader is None:
            if verbose:
                print("\t\t- " + label + " (skipped: not available in this scikit-learn)")
            continue
        if verbose:
            print("\t\t- " + label)
        pairs.append(_tagged_tensors(torch, loader(), tags))
    return pairs


def _tagged_tensors(torch, dataset, tags):  # pragma: no cover
    """Turn a scikit-learn dataset into tagged, described float tensors.

    Both tensors are tagged with ``tags``, a role tag (``"#data"`` or
    ``"#target"``) and the lower-cased, space-split words of the first
    ``DESCR`` line; both are described with the full ``DESCR`` text.
    """
    description = dataset["DESCR"]
    title_words = description.split("\n")[0].lower().split(" ")

    def as_tensor(key, role_tag):
        return (
            torch.tensor(dataset[key])
            .float()
            .tag(*tags, role_tag, *title_words)
            .describe(description)
        )

    return as_tensor("data", "#data"), as_tensor("target", "#target")


def _distribute_dataset(data, workers):  # pragma: no cover
    """Split ``data`` row-wise into one shard per worker and send it there.

    Every shard except the last holds ``data.shape[0] // len(workers)`` rows;
    the last shard also takes the remainder. Each shard copies the parent's
    ``tags`` and ``description`` before being sent.
    """
    n_workers = len(workers)
    shard_size = data.shape[0] // n_workers
    for i, worker in enumerate(workers):
        start = i * shard_size
        stop = None if i == n_workers - 1 else start + shard_size
        shard = data[start:stop]
        shard.tags = data.tags
        shard.description = data.description
        ptr = shard.send(worker)
        # presumably keeps the remote shard alive after ``ptr`` is dropped
        # -- confirm against PointerTensor semantics
        ptr.child.garbage_collect_data = False
def hook(gbs):
    """Set up the sandbox in ``gbs`` without step-by-step output or datasets.

    Shorthand for ``create_sandbox(gbs, verbose=False, download_data=False)``;
    returns whatever ``create_sandbox`` returns.
    """
    return create_sandbox(gbs, verbose=False, download_data=False)