forked from baal-org/baal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_utils.py
47 lines (39 loc) · 1.22 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
# Number of stochastic iterations stacked by make_3d_fake_dist (last axis).
N_ITERATIONS: int = 50
# Side length of the square fake "image" built by make_5d_fake_dist.
IMG_SIZE: int = 3
def make_3d_fake_dist(means, stds, dims=10, n_iterations=None):
    """
    Create fake discrete distributions repeated over several iterations.

    Each iteration is an independent call to ``make_fake_dist``. The results
    are stacked along a new trailing axis.

    Args:
        means: List of means, one per sample.
        stds: List of standard deviations, one per sample.
        dims: Dimensions (number of bins / classes) of each distribution.
        n_iterations: Number of iterations to stack. Defaults to the
            module-level ``N_ITERATIONS``, which is read at call time.

    Returns:
        np.ndarray of shape [n_sample, n_class, n_iter], with one row per
        (mean, std) pair and ``n_class == dims``.
    """
    if n_iterations is None:
        n_iterations = N_ITERATIONS
    # Stacking directly on the last axis yields [n_sample, n_class, n_iter].
    # This is the same layout the old stack(axis=0) + rollaxis(d, 0, 3)
    # produced, without the extra transpose.
    return np.stack(
        [make_fake_dist(means, stds, dims=dims) for _ in range(n_iterations)],
        axis=-1,
    )
def make_5d_fake_dist(means, stds, dims=10, img_size=None):
    """
    Create fake per-pixel discrete distributions over several iterations.

    One ``make_3d_fake_dist`` result is generated per pixel of a square
    ``img_size x img_size`` image. The results are then arranged spatially.

    Args:
        means: List of means, one per sample.
        stds: List of standard deviations, one per sample.
        dims: Dimensions (number of bins / classes) of each distribution.
        img_size: Side length of the square image. Defaults to the
            module-level ``IMG_SIZE``, which is read at call time.

    Returns:
        np.ndarray of shape [n_sample, n_class, H, W, n_iter], where
        ``H == W == img_size``.
    """
    if img_size is None:
        img_size = IMG_SIZE
    # [n_sample, n_class, n_iter, H * W]
    d = np.stack(
        [make_3d_fake_dist(means, stds, dims=dims) for _ in range(img_size ** 2)],
        axis=-1,
    )
    n_sample, n_class, n_iter, _ = d.shape
    d = d.reshape(n_sample, n_class, n_iter, img_size, img_size)
    # Move the iteration axis to the end: [n_sample, n_class, H, W, n_iter]
    return np.moveaxis(d, 2, -1)
def make_fake_dist(means, stds, dims=10, n_trials=100):
    """
    Create some fake discrete distributions.

    For each ``(mean, std)`` pair, ``n_trials`` values are drawn from a
    normal distribution. Each value is clipped to ``[0, dims - 1]`` and
    rounded to the nearest bin index (NumPy rounds halves to even). The
    values are then histogrammed, and the histogram is normalized by
    ``n_trials`` so that each row sums to 1.

    Args:
        means: List of means
        stds: List of standard deviations (non-negative, as required by
            ``np.random.normal``)
        dims: Dimensions of the distributions (number of bins)
        n_trials: Number of samples drawn per distribution.

    Returns:
        np.ndarray of shape [n_dist, dims], with one row per (mean, std)
        pair. As with ``zip``, extra entries in the longer of ``means`` and
        ``stds`` are ignored. Empty input yields an array of shape (0, dims).
    """
    distributions = []
    for m, std in zip(means, stds):
        # One vectorized draw instead of n_trials scalar draws.
        samples = np.random.normal(m, std, n_trials)
        # Clipping before rounding guarantees every index lies in [0, dims - 1].
        bins = np.round(np.clip(samples, 0, dims - 1)).astype(int)
        counts = np.bincount(bins, minlength=dims)
        distributions.append(counts / n_trials)
    if not distributions:
        # Keep the result 2-D so callers can still stack and transpose it.
        return np.zeros((0, dims))
    return np.array(distributions)