forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerated_fun_test.py
285 lines (229 loc) · 7.98 KB
/
generated_fun_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from functools import partial
import numpy.random as npr
from absl.testing import absltest
from absl.testing import parameterized
import itertools as it
import jax.numpy as jnp
from jax import jit, jvp, vjp
import jax.test_util as jtu
from jax.config import config
config.parse_flags_with_absl()
npr.seed(0)
from jax._src.util import unzip2, safe_zip, safe_map
# Throughout this file `map` and `zip` are jax's `safe_*` variants. The code
# relies on them returning lists (e.g. `in_vars[:]`, and `map(write, ...)`
# called purely for its side effects); they presumably also check that their
# inputs have equal lengths.
map = safe_map
zip = safe_zip

# Knobs controlling random program generation.
subfun_prob = 0.5  # chance that a step inlines a nested sub-function
thin_prob = 0.1  # after each step, each live variable is dropped with this prob
size_reduction_factor = 3  # nested sub-functions are generated at size / this

# Program representation: a `Fun` is a straight-line list of `Eqn`s over `Var`s.
Eqn = namedtuple('Eqn', 'in_vars out_vars fun')
Prim = namedtuple('Prim', 'fun')
ArrayType = namedtuple('ArrayType', 'shape dtype')
Var = namedtuple('Var', 'name vartype')
Fun = namedtuple('Fun', 'in_vars out_vars eqns')
def gen_fun_and_types(size):
  """Generate a random `Fun` whose inputs have randomly generated types.

  Despite the name, only the generated `Fun` is returned; the output types
  produced by `gen_function` are discarded. This is the generator handed to
  `jtu.cases_from_gens` by the tests below.
  """
  num_inputs = gen_nonneg_int(size)
  in_types = [gen_array_type(size) for _ in range(num_inputs)]
  generated, _unused_out_types = gen_function(size, in_types)
  return generated
def gen_function(size, in_types):
  """Randomly build a straight-line `Fun` over inputs of the given types.

  Each step either inlines a recursively generated sub-function (evaluated via
  `eval_fun` and wrapped with `maybe_jit`) or applies a primitive chosen from
  `primitive_generators`. After every step the pool of live variables is
  randomly thinned with `thin_prob`.

  Args:
    size: bounds the random number of steps; nested sub-functions are
      generated with `size / size_reduction_factor`.
    in_types: types of the function's input variables.

  Returns:
    `(fun, out_types)` where `out_types` are the types of `fun.out_vars`.
    Output variables are drawn with replacement from the final live pool.
  """
  in_vars = map(fresh_var, in_types)
  live_vars = list(in_vars)
  eqns = []
  num_steps = gen_nonneg_int(size)
  for _ in range(num_steps):
    if not live_vars:
      # No inputs were given, or every variable has been thinned away.
      break
    use_subfun = npr.rand() < subfun_prob
    if use_subfun:
      arg_vars = gen_subset(live_vars)
      arg_types = [v.vartype for v in arg_vars]
      sub_fun, out_types = gen_function(size / size_reduction_factor, arg_types)
      eqn_fun = maybe_jit(partial(eval_fun, sub_fun), len(arg_types))
    else:
      arity = choice(list(primitive_generators))
      arg_vars = gen_sized_subset(live_vars, arity)
      arg_types = [v.vartype for v in arg_vars]
      prim_gen = weighted_choice(primitive_generators[arity])
      prim_fun, out_type = prim_gen(size, *arg_types)
      # Primitives return a single value; equations expect a tuple of outputs.
      eqn_fun = wrap_singleton(prim_fun)
      out_types = [out_type]
    new_vars = map(fresh_var, out_types)
    eqns.append(Eqn(arg_vars, new_vars, eqn_fun))
    live_vars = thin(live_vars + new_vars, thin_prob)
  out_vars = gen_subset(live_vars)
  return Fun(in_vars, out_vars, eqns), [v.vartype for v in out_vars]
def eval_fun(fun, *args):
  """Interpret a generated `Fun` on concrete (or traced) argument values.

  Binds `fun.in_vars` to `args`, runs each equation in order (every equation
  function returns a sequence of outputs matched against its `out_vars`), and
  returns the values bound to `fun.out_vars` as a list.
  """
  env = {}
  for var, val in zip(fun.in_vars, args):
    env[var] = val
  for eqn_in_vars, eqn_out_vars, eqn_fun in fun.eqns:
    results = eqn_fun(*[env[v] for v in eqn_in_vars])
    for var, val in zip(eqn_out_vars, results):
      env[var] = val
  return [env[v] for v in fun.out_vars]
def maybe_jit(f, num_args):
  """Wrap `f` so each call goes through `jit` with some args closed over.

  A random subset of positional indices (each kept with probability 0.5) is
  chosen once, when the wrapper is built. On every call those arguments are
  replaced by `None` in the jitted call and instead captured from the
  enclosing call's closure, so they are not passed as jit arguments. A fresh
  `jit`-wrapped closure is built per call, so nothing is cached across calls.
  """
  static_argnums = thin(range(num_args), 0.5)

  def fun(*args):
    jit_args = list(args)
    for idx in static_argnums:
      jit_args[idx] = None  # placeholder; the real value comes from `args`

    @jit
    def jitted_fun(*call_args):
      restored = list(call_args)
      for idx in static_argnums:
        restored[idx] = args[idx]
      return f(*restored)

    return jitted_fun(*jit_args)

  return fun
# Module-wide source of variable names, shared by every generated function.
counter = it.count()


def fresh_var(ty):
  """Return a new `Var` of type `ty` whose integer name has not been used before."""
  unique_id = next(counter)
  return Var(name=unique_id, vartype=ty)
def gen_array_type(size):
  """Return the type of a generated array value.

  Currently fixed at a 2x2 float32 array; `size` is accepted but unused.
  """
  # TODO(dougalm): randomize this
  shape, dtype = (2, 2), jnp.float32
  return ArrayType(shape=shape, dtype=dtype)
def gen_array_val(array_type):
  """Draw a standard-normal numpy array with shape `array_type.shape`.

  Note: `npr.randn` produces float64 regardless of `array_type.dtype`.
  """
  # TODO(dougalm): different sizes and dtypes
  dims = array_type.shape
  return npr.randn(*dims)
def gen_neg(size, t):
  """Return `(negate, t)`: elementwise negation preserves the input type."""
  def negate(x):
    return -x
  return negate, t
def gen_trig(size, t):
  """Return `(op, t)` with `op` drawn uniformly from `jnp.sin` and `jnp.cos`."""
  return choice([jnp.sin, jnp.cos]), t
def gen_binop(size, t1, t2):
  """Generate a random binary op (`+` or `*`) for operands of types `t1`, `t2`.

  Returns `(fun, t_out)`. `fun` first passes its operands through the unifier
  from `gen_broadcasting_unifier`, then applies the chosen op.
  """
  unify, out_type = gen_broadcasting_unifier(t1, t2)

  def add(x, y):
    return x + y

  def mul(x, y):
    return x * y

  op = choice([add, mul])

  def unify_and_binop(x, y):
    return op(*unify(x, y))

  return unify_and_binop, out_type
def thin(xs, p):
  """Randomly drop elements of `xs`, each independently with probability `p`.

  Returns a new list of the survivors in their original order. Consumes one
  `npr.rand()` draw per element.
  """
  kept = []
  for x in xs:
    if npr.rand() > p:
      kept.append(x)
  return kept
def gen_broadcasting_unifier(t1, t2):
  """Return `(unifier, out_type)` making operands of types `t1`, `t2` compatible.

  Only identical shapes are supported for now (asserted). The unifier
  therefore returns its two arguments unchanged, and the output type is `t1`.
  """
  # TODO: generate slices and paddings to match shapes
  assert t1.shape == t2.shape

  def unifier(x, y):
    return x, y

  return unifier, t1
def wrap_singleton(f):
  """Adapt single-output `f` to return a 1-tuple, the multi-output equation convention."""
  def wrapped(*xs):
    return (f(*xs),)
  return wrapped
# Primitive generators grouped by arity. Each entry is a `(weight, generator)`
# pair consumed by `weighted_choice`; a generator is called as
# `gen(size, *arg_types)` and returns `(fun, out_type)`.
unary_primitive_generators = [(3, gen_trig), (1, gen_neg)]
binary_primitive_generators = [(1, gen_binop)]
primitive_generators = {
    1: unary_primitive_generators,
    2: binary_primitive_generators,
}
def gen_nonneg_int(size):
  """Return a random integer in `[0, size)`.

  `size` may be non-integral, since `gen_function` divides it by
  `size_reduction_factor` for nested sub-functions.
  """
  return npr.randint(0, size)
def choice(xs, weights=None):
  """Pick one element of the non-empty sequence `xs`.

  Picks uniformly when `weights` is None, otherwise proportionally to
  `weights` (normalized here to sum to 1).
  """
  # We sample an index rather than calling npr.choice(xs) directly: npr.choice
  # inspects its argument to see whether the elements are array-like, so it
  # does not behave like a plain "pick one of these objects".
  assert xs
  indices = range(len(xs))
  if weights is not None:
    total = float(sum(weights))
    probs = [w / total for w in weights]
    idx = npr.choice(indices, p=probs)
  else:
    idx = npr.randint(len(xs))
  return xs[idx]
def weighted_choice(weighted_choices):
  """Pick one option from `(weight, option)` pairs, proportionally to weight."""
  pairs = list(weighted_choices)
  weights = tuple(w for w, _ in pairs)
  options = tuple(c for _, c in pairs)
  return choice(options, weights)
def gen_sized_subset(xs, size):
  """Draw `size` elements from `xs` uniformly *with replacement* (duplicates possible)."""
  picked = []
  for _ in range(size):
    picked.append(choice(xs))
  return picked
def gen_subset(xs):
  """Draw a random number (0..len(xs) inclusive) of elements from `xs`, with replacement."""
  if xs:
    count = npr.randint(len(xs) + 1)
    return gen_sized_subset(xs, count)
  return []
def gen_vals(vs):
  """Generate one random value per variable in `vs`, according to its `vartype`."""
  vals = []
  for v in vs:
    vals.append(gen_array_val(v.vartype))
  return vals
def inner_prod(xs, ys):
  """Return the inner product of two lists of arrays: sum of `jnp.sum(x * y)`.

  Asserts that corresponding arrays have equal shapes. Returns 0 when both
  lists are empty.
  """
  # Materialize the pairs: they are traversed twice (shape check, then sum).
  # With a lazy `zip` the assert would exhaust the iterator and the sum would
  # silently be 0.
  xys = list(zip(xs, ys))
  assert all(x.shape == y.shape for x, y in xys)
  return sum(jnp.sum(x * y) for x, y in xys)
def jvp_fd(fun, args, tangents):
  """Estimate the JVP of `fun` at `args` along `tangents` by central differences.

  `fun` is evaluated with each argument shifted by `scale * tangent` for
  scale -1e-3, +1e-3 and 0 (in that order). Arguments whose tangent is None
  are left unshifted. `fun` must return a sequence of outputs.

  Returns:
    `(ys, deriv)`: the outputs at scale 0, and a list of estimates
    `(y(+eps) - y(-eps)) / (2 * eps)`, one per output.
  """
  step = 1e-3

  def shifted_call(scale):
    shifted_args = []
    for arg, tangent in zip(args, tangents):
      shifted_args.append(arg if tangent is None else arg + scale * tangent)
    return fun(*shifted_args)

  ys_neg = shifted_call(-step)
  ys_pos = shifted_call(step)
  ys = shifted_call(0.0)
  denom = 2 * step
  deriv = [(hi - lo) / denom for lo, hi in zip(ys_neg, ys_pos)]
  return ys, deriv
def check_all_close(xs, ys, tol=1e-3):
  """Apply `check_close` with tolerance `tol` to each corresponding pair of `xs`, `ys`."""
  for pair in zip(xs, ys):
    check_close(*pair, tol)
def check_close(x, y, tol=1e-3):
  """Assert that `x` and `y` have equal shapes and are elementwise close.

  `tol` is used as both the relative and the absolute tolerance of
  `jnp.allclose`. Raises AssertionError (with both values in the message on a
  value mismatch).
  """
  same_shape = jnp.shape(x) == jnp.shape(y)
  assert same_shape
  # TODO(dougalm): re-enable once we've tackled the less pedantic bugs
  # assert x.dtype == y.dtype
  close = jnp.allclose(x, y, rtol=tol, atol=tol)
  assert close, "Value mismatch:\n{}\n vs\n{}\n".format(x, y)
def partial_argnums(f, args, dyn_argnums):
  """Fix all positional args of `f` except those at `dyn_argnums`.

  Returns `(f_, dyn_args)`. `f_` takes only the dynamic arguments (in
  `dyn_argnums` order) and calls `f` with the fixed ones re-inserted.
  `dyn_args` are the current values of the dynamic arguments.
  """
  template = []
  for pos, value in enumerate(args):
    template.append(None if pos in dyn_argnums else value)

  def f_(*dyn_args):
    call_args = list(template)
    for pos, value in zip(dyn_argnums, dyn_args):
      call_args[pos] = value
    return f(*call_args)

  dyn_args = [args[pos] for pos in dyn_argnums]
  return f_, dyn_args
class GeneratedFunTest(jtu.JaxTestCase):
  """Tests of transformations on randomly generated functions.

  Each test is parameterized over programs produced by `gen_fun_and_types`
  via `jtu.cases_from_gens`; the `fun` parameter is a generated `Fun`
  (a program description), which each test turns into a callable with
  `partial(eval_fun, fun)`.
  """

  @parameterized.named_parameters(jtu.cases_from_gens(gen_fun_and_types))
  def testJitIsIdentity(self, fun):
    """Jitting (with a random subset of args closed over) must not change results."""
    generated = fun  # keep the `Fun` itself for the failure dump below
    vals = gen_vals(generated.in_vars)
    fun = partial(eval_fun, generated)
    ans = fun(*vals)
    ans_jitted = maybe_jit(fun, len(vals))(*vals)
    try:
      check_all_close(ans, ans_jitted)
    except Exception:
      # Show the generated program so a failing random case can be inspected,
      # then let the original failure propagate.
      print(generated)
      raise

  @parameterized.named_parameters(jtu.cases_from_gens(gen_fun_and_types))
  def testJVPMatchesFD(self, fun):
    """`jvp` must agree with a central finite-difference estimate.

    A random subset of the inputs is differentiated; the remaining inputs are
    held fixed via `partial_argnums`.
    """
    vals = gen_vals(fun.in_vars)
    tangents = gen_vals(fun.in_vars)
    fun = partial(eval_fun, fun)
    dyn_argnums = thin(range(len(vals)), 0.5)
    tangents = [tangents[i] for i in dyn_argnums]
    fun, vals = partial_argnums(fun, vals, dyn_argnums)
    ans1, deriv1 = jvp_fd(fun, vals, tangents)
    ans2, deriv2 = jvp(fun, tuple(vals), tuple(tangents))
    check_all_close(ans1, ans2)
    check_all_close(deriv1, deriv2)

  # NOTE(review): this method has no `test` prefix, so standard unittest/absl
  # test loaders do not collect it and it never runs. Presumably disabled on
  # purpose -- confirm it passes before renaming (e.g. to `testVJPMatchesFD`).
  @parameterized.named_parameters(jtu.cases_from_gens(gen_fun_and_types))
  def vjp_matches_fd(self, fun):
    """`vjp` must satisfy <J t, c> == <t, J^T c> against finite-difference J t."""
    vals = gen_vals(fun.in_vars)
    in_tangents = gen_vals(fun.in_vars)
    # Despite the name, these are cotangents for the *outputs* of `fun`.
    in_cotangents = gen_vals(fun.out_vars)
    fun = partial(eval_fun, fun)
    dyn_argnums = thin(range(len(vals)), 0.5)
    in_tangents = [in_tangents[i] for i in dyn_argnums]
    fun, vals = partial_argnums(fun, vals, dyn_argnums)
    ans1, out_tangents = jvp_fd(fun, vals, in_tangents)
    ans2, vjpfun = vjp(fun, *vals)
    # Pull the output cotangents back to cotangents for the dynamic inputs.
    out_cotangents = vjpfun(in_cotangents)
    check_all_close(ans1, ans2)
    inner_prod_fd = inner_prod(out_tangents, in_cotangents)
    inner_prod_ad = inner_prod(in_tangents, out_cotangents)
    check_close(inner_prod_fd, inner_prod_ad)
if __name__ == "__main__":
  # Run the suite under absltest using JAX's own test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)