-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
distributedRandom.py
69 lines (59 loc) · 1.98 KB
/
distributedRandom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# contains helper functions for generating random numbers with a given distribution
import tensorflow as tf
import tensorflow_probability as tfp
def uniform(range):
    """Build a sampler that draws elements of ``range`` uniformly at random.

    ``range`` is flattened to 1-D first, so any tensor-like collection of
    candidate values is accepted.

    Args:
        range: Tensor-like of candidate values (flattened before use).

    Returns:
        A function ``F(shape)`` that returns a tensor of the requested
        ``shape`` whose entries are picked from ``range`` using independent
        uniformly random indices.
    """
    flat_values = tf.reshape(range, (-1,))
    count = tf.size(range)

    def F(shape):
        # One random index per output element, drawn from [0, count).
        picks = tf.random.uniform(shape, minval=0, maxval=count, dtype=tf.int32)
        sampled = tf.gather(flat_values, picks)
        tf.assert_equal(tf.shape(sampled), shape)
        return sampled

    return F
def weighted(weights, values):
    """Build a sampler that draws ``values`` with probability proportional to ``weights``.

    Args:
        weights: Tensor-like of relative weights, one per entry of ``values``;
            cast to float32. A zero weight becomes a -inf logit, i.e. that
            value is never drawn. Negative weights produce NaN logits and
            should not be used.
        values: Tensor-like with exactly as many elements as ``weights``;
            reshaped to 1-D.

    Returns:
        A function ``F(shape)`` that returns a tensor of the requested
        ``shape`` filled with independent draws from ``values``.
    """
    count = tf.size(weights)
    # tf.random.categorical wants unnormalised log-probabilities laid out as a
    # [batch, num_classes] matrix; use a single batch row.
    logits = tf.math.log(tf.cast(tf.reshape(weights, (1, count)), tf.float32))
    flat_values = tf.reshape(values, (count,))

    def F(shape):
        total = tf.reduce_prod(shape)
        # Indices come back as [1, total]: one batch row of `total` draws.
        picks = tf.random.categorical(logits, total)
        sampled = tf.reshape(tf.gather(flat_values, picks), shape)
        tf.assert_equal(tf.shape(sampled), shape)
        return sampled

    return F
def PowerDistribution(power=2.0):
    """Return a distribution factory that favours the start of a range.

    For a range of N elements, element i (0-based) gets weight
    ``(N - i) ** power``. With ``power > 0`` the first element is the most
    likely and the last the least likely.

    Args:
        power: Exponent applied to the linear ramp ``N, N-1, ..., 1``.

    Returns:
        A function ``F(range)`` that returns a ``weighted`` sampler over
        ``range``.
    """
    def F(range):
        n = tf.size(range)
        # Descending ramp N, N-1, ..., 1, built directly instead of reversing.
        ramp = tf.cast(tf.range(n, 0, -1), tf.float32)
        return weighted(tf.pow(ramp, power), range)

    return F
def BetaDistribution(alpha=1.0, beta=1.0):
    """Return a distribution factory whose weights follow a Beta(alpha, beta) density.

    The N elements of a range are mapped to the midpoints of N equal-width
    bins on [0, 1]. Each element is weighted by the Beta density at its
    midpoint, so the first element sits near 0 and the last near 1.

    Args:
        alpha: Beta concentration1; larger values shift mass towards the end
            of the range.
        beta: Beta concentration0; larger values shift mass towards the start
            of the range.

    Returns:
        A function ``F(range)`` that returns a ``weighted`` sampler over
        ``range``.
    """
    def F(range):
        N = tf.size(range)
        d = tfp.distributions.Beta(alpha, beta)
        # Use bin midpoints (i + 0.5) / N instead of tf.linspace(0, 1, N).
        # The endpoints 0 and 1 have infinite density when alpha < 1 or
        # beta < 1, and the resulting +inf logits break
        # tf.random.categorical. They also have zero density when
        # alpha/beta > 1, which left N == 1 with no positively weighted
        # element (its only point was 0).
        points = (tf.range(N, dtype=tf.float32) + 0.5) / tf.cast(N, tf.float32)
        weights = d.prob(points)
        return weighted(weights, range)

    return F
def UPowerDistribution(power=2.0):
    """Return a distribution factory with a symmetric, U-shaped weighting.

    First the descending power ramp ``N**power, (N-1)**power, ...`` is built.
    Its first ``ceil(N/2)`` entries weight the left part of the range, and
    its first ``floor(N/2)`` entries, mirrored, weight the right part. With
    ``power > 0`` both ends are the most likely and the middle the least
    likely.

    Args:
        power: Exponent applied to the linear ramp ``N, N-1, ..., 1``.

    Returns:
        A function ``F(range)`` that returns a ``weighted`` sampler over
        ``range``. Every element of ``range`` gets exactly one weight.
    """
    def F(range):
        N = tf.size(range)
        weights = tf.range(N, dtype=tf.float32)[::-1] + 1
        weights = tf.pow(weights, power)
        # The left part takes ceil(N/2) entries so the total stays N for odd N.
        # Taking N//2 on both sides produced N-1 weights, so weighted()
        # failed to reshape `range` for odd N, and N == 1 gave no weights.
        left = weights[:(N + 1) // 2]
        right = weights[:N // 2][::-1]
        weights = tf.concat([left, right], 0)
        return weighted(weights, range)

    return F
def config_to_distribution(config):
    """Translate a config mapping into a distribution factory.

    Recognised ``config['name']`` values:
      * ``'uniform'`` -- returns ``uniform`` itself (no extra keys).
      * ``'power'``   -- ``PowerDistribution(config['power'])``.
      * ``'beta'``    -- ``BetaDistribution(alpha=config['alpha'], beta=config['beta'])``.
      * ``'upower'``  -- ``UPowerDistribution(config['power'])``.

    Each returned factory is called as ``F(range)`` and yields a sampler.

    Raises:
        KeyError: if ``'name'`` or a key required by the chosen distribution
            is missing.
        ValueError: if ``config['name']`` is not one of the names above.
    """
    name = config['name']
    # Factories are lazy so only the selected entry reads its config keys.
    factories = (
        ('uniform', lambda: uniform),
        ('power', lambda: PowerDistribution(config['power'])),
        ('beta', lambda: BetaDistribution(alpha=config['alpha'], beta=config['beta'])),
        ('upower', lambda: UPowerDistribution(config['power'])),
    )
    for key, make in factories:
        if key == name:
            return make()
    raise ValueError('Invalid distribution: %s' % name)