# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,md:myst,py
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.14.4
#   kernelspec:
#     display_name: Python 3
#     name: python3
# ---
# [![Open in
# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)
# # Autodidax: JAX core from scratch
#
# Ever want to learn how JAX works, but the implementation seemed impenetrable?
# Well, you're in luck! By reading this tutorial, you'll learn every big idea in
# JAX's core system. You'll even get clued into our weird jargon!
#
# **This is a work-in-progress draft.** There are some important ingredients
# missing, still to come in parts 5 and 6 (and more?). There are also some
# simplifications here that we haven't yet applied to the main system, but we
# will.
# ## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap`
#
# We want to transform functions that look like this:
#
# ```python
# def f(x):
#   y = sin(x) * 2.
#   z = - y + x
#   return z
# ```
#
# Think of functions like `sin` and the arithmetic operations underlying the
# infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning
# atomic units of processing rather than compositions.
#
# "Transform" means "interpret differently." Instead of standard interpretation
# where we apply primitive operations to numerical inputs to produce numerical
# outputs, we want to override primitive application and let different values
# flow through our program. For example, we might want to replace the
# application of every primitive with an application of [its JVP
# rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),
# and let primal-tangent pairs flow through our program. Moreover, we want to be
# able to compose multiple transformations, leading to stacks of interpreters.
# ### JAX core machinery
#
# We can implement stacks of interpreters and even have them all discharge on
# the fly as we execute the Python function to be transformed. To start, let's
# define these primitives so that we can intercept their application:
# +
from typing import NamedTuple

class Primitive(NamedTuple):
  name: str

add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:
    axis = (axis,)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  out, = bind(prim, *args, **params)
  return out
# -
# We'll set up array data types and infix operator methods in a moment.
#
# A `Primitive` is just an object with a name, to which we attach our
# interpretation rules (one for each transformation). The `bind` function is our
# interception point: it'll figure out which transformation rule to apply, based
# on how the arguments are boxed in tracers and what interpreters are active.
#
# The functions that user code calls, like `add` and `sin`, are just wrappers
# around calls to `bind`. These wrappers let us control how arguments are passed
# to `bind`, and in particular we follow a handy internal convention: when we
# call `bind`, we pass values representing array data as positional arguments,
# and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
# calling convention simplifies some core logic (since e.g. instances of the
# `Tracer` class to be defined below can only occur in positional arguments to
# `bind`). The wrappers can also provide docstrings!
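#
# For example, the call `reduce_sum(x, axis=(0,))` above bottoms out in
# `bind1(reduce_sum_p, x, axis=(0,))`: the array `x` is positional (and so may
# be a `Tracer`), while the `axis` metadata rides along as a keyword parameter.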
#
# We represent active interpreters as a stack. The stack is just a simple
# `list`, and each element is a container with an integer level (corresponding
# to the element's height in the stack), an interpreter type (which we'll call a
# `trace_type`), and an optional field for any global data the interpreter
# needs. We call each element a `MainTrace`, though maybe "Interpreter" would be
# more descriptive.
# +
from contextlib import contextmanager
from typing import Type, List, Tuple, Sequence, Optional, Any

class MainTrace(NamedTuple):
  level: int
  trace_type: Type['Trace']
  global_data: Optional[Any]

trace_stack: List[MainTrace] = []
dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3

@contextmanager
def new_main(trace_type: Type['Trace'], global_data=None):
  level = len(trace_stack)
  main = MainTrace(level, trace_type, global_data)
  trace_stack.append(main)
  try:
    yield main
  finally:
    trace_stack.pop()
# -
# When we're about to apply a transformation, we'll push another interpreter
# onto the stack using `new_main`. Then, as we apply primitives in the function,
# we can think of the `bind` first being interpreted by the trace at the top of
# the stack (i.e. with the highest level). If that first interpreter itself
# binds other primitives in its interpretation rule for the primitive, like how
# the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind`
# calls will be handled by the interpreter at the next level down.
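#
# As a sketch of the stacking (using `JVPTrace`, defined below):
#
# ```python
# with new_main(JVPTrace) as main1:    # pushed at level 1
#   with new_main(JVPTrace) as main2:  # pushed at level 2
#     ...  # a `bind` here is interpreted at level 2 first, then level 1,
#          # then by the evaluation interpreter at level 0
# ```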
#
# What goes at the bottom of the interpreter stack? At the bottom, we know all
# the transformation interpreters are finished, and we just want to do standard
# evaluation. So at the bottom we'll put an evaluation interpreter.
#
# Let's sketch out the interface for interpreters, which is based on the `Trace`
# and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps
# carrying some extra context data used by the interpreter. A `Trace` handles
# boxing up values into `Tracer`s and also handles primitive application.
class Trace:
  main: MainTrace

  def __init__(self, main: MainTrace) -> None:
    self.main = main

  def pure(self, val): assert False  # must override
  def lift(self, val): assert False  # must override

  def process_primitive(self, primitive, tracers, params):
    assert False  # must override
# The first two methods are about boxing up values in `Tracer`s, which are the
# objects that flow through the Python programs we transform. The last method is
# the callback we'll use to interpret primitive application.
#
# The `Trace` itself doesn't contain any data, other than a reference to its
# corresponding `MainTrace` instance. In fact, multiple instances of a `Trace`
# might be created and discarded during an application of a transformation,
# whereas only a single `MainTrace` instance is created per application of a
# transformation.
#
# As for `Tracer`s themselves, each one carries an abstract value (and forwards
# infix operators to it), and the rest is up to the transformation. (The
# relationship between `Tracer`s and `AbstractValue`s is that there's one
# `Tracer` per transformation, and at least one `AbstractValue` per base type,
# like arrays.)
# +
import numpy as np

class Tracer:
  _trace: Trace

  __array_priority__ = 1000

  @property
  def aval(self):
    assert False  # must override

  def full_lower(self):
    return self  # default implementation

  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

  def __getattr__(self, name):
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")
def swap(f): return lambda x, y: f(y, x)
# +
class ShapedArray:
  array_abstraction_level = 1
  shape: Tuple[int, ...]
  dtype: np.dtype

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def ndim(self):
    return len(self.shape)

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  def str_short(self):
    return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

  def __hash__(self):
    return hash((self.shape, self.dtype))

  def __eq__(self, other):
    return (type(self) is type(other) and
            self.shape == other.shape and self.dtype == other.dtype)

  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"

class ConcreteArray(ShapedArray):
  array_abstraction_level = 2
  val: np.ndarray

  def __init__(self, val):
    self.val = val
    self.shape = val.shape
    self.dtype = val.dtype

  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)

  @staticmethod
  def _nonzero(tracer):
    return bool(tracer.aval.val)

def get_aval(x):
  if isinstance(x, Tracer):
    return x.aval
  elif type(x) in jax_types:
    return ConcreteArray(np.asarray(x))
  else:
    raise TypeError(x)

jax_types = {bool, int, float,
             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}
# -
# Notice that we actually have two `AbstractValue`s for arrays, representing
# different levels of abstraction. A `ShapedArray` represents the set of all
# possible arrays with a given shape and dtype. A `ConcreteArray` represents a
# singleton set consisting of a single array value.
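#
# For example (a quick sketch):
#
# ```python
# get_aval(np.arange(3.))  # ConcreteArray: knows the exact element values
# ShapedArray((3,), np.dtype('float64'))  # any float64 array of shape (3,)
# ```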
#
# Now that we've set up the interpreter stack, the Trace/Tracer API for
# interpreters, and abstract values, we can come back to implement `bind`:
def bind(prim, *args, **params):
  top_trace = find_top_trace(args)
  tracers = [full_raise(top_trace, arg) for arg in args]
  outs = top_trace.process_primitive(prim, tracers, params)
  return [full_lower(out) for out in outs]
# The main action is that we call `find_top_trace` to figure out which
# interpreter should handle this primitive application. We then call that top
# trace's `process_primitive` so that the trace can apply its interpretation
# rule. The calls to `full_raise` just ensure that the inputs are boxed in the
# top trace's `Tracer` instances, and the call to `full_lower` is an optional
# optimization so that we unbox values out of `Tracer`s as much as possible.
# +
import operator as op

def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main)
# -
# In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace`
# returns the highest-level interpreter associated with the `Tracer`s on its
# inputs, and otherwise returns the interpreter at the bottom of the stack
# (which is always an evaluation trace, at least for now). This is a deviation
# from the description above, where we always start by running the interpreter
# at the top of the stack and then work our way down, applying every interpreter
# in the stack. Instead, we're only applying an interpreter when the input
# arguments to a primitive bind are boxed in a `Tracer` corresponding to that
# interpreter. This optimization lets us skip irrelevant transformations, but
# bakes in an assumption that transformations mostly follow data dependence
# (except for the special bottom-of-the-stack interpreter, which interprets
# everything).
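#
# For example, if `x` is a plain array and only `y` is boxed in a `JVPTracer`,
# then `bind(mul_p, x, y)` dispatches to the JVP interpreter, which first boxes
# `x` itself (via `full_raise`, defined below); if neither argument is a
# `Tracer`, we fall through to the evaluation trace at the bottom of the stack.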
#
# An alternative would be to have every interpreter in the stack interpret every
# operation. That's worth exploring! JAX is designed around data dependence in
# large part because that's so natural for automatic differentiation, and JAX's
# roots are in autodiff. But it may be over-fit.
# +
def full_lower(val: Any):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def full_raise(trace: Trace, val: Any) -> Tracer:
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.level == level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")
# -
# The logic in `full_raise` serves to box values into `Tracer`s for a particular
# `Trace`, calling different methods on the `Trace` based on context:
# `Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called
# for values that are already `Tracer`s from a lower-level interpreter. These
# two methods could share the same implementation, but by distinguishing them in
# the core logic we can provide more information to the `Trace` subclass.
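#
# A sketch of the two paths, where `trace` is an active `JVPTrace` and
# `lower_tracer` is a hypothetical `Tracer` from a lower-level interpreter:
#
# ```python
# full_raise(trace, 2.0)           # non-Tracer constant -> trace.pure(2.0)
# full_raise(trace, lower_tracer)  # lower-level Tracer  -> trace.lift(lower_tracer)
# ```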
#
# That's it for the JAX core! Now we can start adding interpreters.
# ### Evaluation interpreter
#
# We'll start with the simplest interpreter: the evaluation interpreter that
# will sit at the bottom of the interpreter stack.
# +
class EvalTrace(Trace):
  pure = lift = lambda self, x: x  # no boxing in Tracers needed

  def process_primitive(self, primitive, tracers, params):
    return impl_rules[primitive](*tracers, **params)

trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}

impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl
# -
# With this interpreter, we can evaluate user functions:
# +
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
print(f(3.0))
# -
# Woo! Like going around in a big circle. But the point of this indirection is
# that now we can add some real transformations.
# ### Forward-mode autodiff with `jvp`
#
# First, a few helper functions:
# +
def zeros_like(val):
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)

def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2

map_ = map
def map(f, *xs):
  return list(map_(f, *xs))

zip_ = zip
def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(zip_(*args))
# -
# The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The
# `Trace` applies JVP rules.
# +
class JVPTracer(Tracer):
  def __init__(self, trace, primal, tangent):
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    return get_aval(self.primal)

class JVPTrace(Trace):
  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp_rule = jvp_rules[primitive]
    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}
# -
# Notice both `pure` and `lift` package a value into a `JVPTracer` with the
# minimal amount of context, which is a zero tangent value.
#
# Let's add some JVP rules for primitives:
# +
def add_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp

def greater_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp
# -
# Finally, we add a transformation API to kick off the trace:
def jvp_v1(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
  return primal_out, tangent_out
# And with that, we can differentiate!
x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
# +
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z
x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
# +
def deriv(f):
  return lambda x: jvp_v1(f, (x,), (1.,))[1]
print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
# +
def f(x):
  if x > 0.:  # Python control flow
    return 2. * x
  else:
    return x
print(deriv(f)(3.))
print(deriv(f)(-3.))
# -
# ### Pytrees and flattening user functions' inputs and outputs
#
# A limitation with `jvp_v1` is that it assumes the user function accepts arrays
# as positional arguments and produces a single array as output. What if it
# produced a list as output? Or accepted nested containers as inputs? It would
# be a pain to deal with all the possible containers in inputs and outputs at
# every layer of the stack. Instead, we can wrap the user function so that the
# wrapped version accepts arrays as inputs and returns a flat list of arrays as
# output. The wrapper just needs to unflatten its input, call the user function,
# and flatten the output.
#
# Here's how we'd like to write `jvp`, assuming the user always gives us
# functions that take arrays as inputs and produce a flat list of arrays as
# outputs:
def jvp_flat(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
  return primals_out, tangents_out
# To support user functions that have arbitrary containers in the inputs and
# outputs, here's how we'd write the user-facing `jvp` wrapper:
def jvp(f, primals, tangents):
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2: raise TypeError
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, tangents_out
# Notice that we had to plumb the tree structure of the user function output
# back to the caller of `flatten_fun`. That information isn't available until we
# actually run the user function, so `flatten_fun` just returns a reference to a
# mutable cell, represented as a thunk. These side-effects are safe because we
# always run the user function exactly once. (This safe regime is the reason for
# the "linear" name in `linear_util.py`, in the sense of [linear
# types](https://en.wikipedia.org/wiki/Substructural_type_system).)
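#
# The usage pattern, as in `jvp` above, looks like this sketch:
#
# ```python
# f_flat, out_tree = flatten_fun(f, in_tree)  # out_tree is an empty Store (a thunk)
# outs_flat = f_flat(*args_flat)              # running f fills the Store
# treedef = out_tree()                        # only now is it safe to read
# ```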
#
# All that remains is to write `tree_flatten`, `tree_unflatten`, and
# `flatten_fun`.
# + tags=["hide-input"]
def flatten_fun(f, in_tree):
  store = Store()

  def flat_fun(*args_flat):
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.set_value(out_tree)
    return out_flat

  return flat_fun, store

class Empty: pass
empty = Empty()

class Store:
  val = empty

  def set_value(self, val):
    assert self.val is empty
    self.val = val

  def __call__(self):
    return self.val
# + tags=["hide-input"]
import itertools as it
from typing import Callable, Type, Hashable, Dict, Iterable, Iterator
class NodeType(NamedTuple):
  name: str
  to_iterable: Callable
  from_iterable: Callable

def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: Dict[Type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list, lambda l: (None, l), lambda _, xs: list(xs))
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: Tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]:
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef

def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]:
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf

def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any:
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children)
# -
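# As a quick sketch of the pytree helpers themselves (note `dict` entries are
# flattened in sorted-key order):
#
# ```python
# xs, treedef = tree_flatten({'a': (1, 2), 'b': [3]})
# # xs == [1, 2, 3]
# tree_unflatten(treedef, xs)  # {'a': (1, 2), 'b': [3]}
# ```
#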
# With this pytree-handling `jvp` implementation, we can now handle arbitrary
# input and output containers. That'll come in handy with future transformations
# too!
# +
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}
x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
# -
# ### Vectorized batching with `vmap`
#
# First, a couple helper functions, one for producing mapped abstract values
# from unmapped ones (by removing an axis), and one for moving batch dimensions
# around:
# +
def mapped_aval(batch_dim, aval):
  shape = list(aval.shape)
  del shape[batch_dim]
  return ShapedArray(tuple(shape), aval.dtype)

def move_batch_axis(axis_size, src, dst, x):
  if src is not_mapped:
    target_shape = list(np.shape(x))
    target_shape.insert(dst, axis_size)
    return broadcast(x, target_shape, [dst])
  elif src == dst:
    return x
  else:
    return moveaxis(x, src, dst)

def moveaxis(x, src: int, dst: int):
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return transpose(x, perm)
# -
# The `Tracer` for vectorized batching carries a batched value and an optional
# integer indicating which axis (if any) is the batch axis.
# +
from typing import Union
class NotMapped: pass
not_mapped = NotMapped()

BatchAxis = Union[NotMapped, int]

class BatchTracer(Tracer):
  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    if self.batch_dim is not_mapped:
      return get_aval(self.val)
    else:
      return mapped_aval(self.batch_dim, get_aval(self.val))

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

  def process_primitive(self, primitive, tracers, params):
    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    vmap_rule = vmap_rules[primitive]
    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

  @property
  def axis_size(self):
    return self.main.global_data

vmap_rules = {}
# -
# Here we've implemented the optional `Tracer.full_lower` method, which lets us
# peel off a batching tracer if it's not needed because it doesn't represent a
# batched value.
#
# For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just
# box a value in a `BatchTracer` with the minimal amount of context, which in
# this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we
# use the `MainTrace`'s interpreter-global data field to store the batch axis
# size.
#
# Next we can define batching interpreter rules for each primitive:
# +
from functools import partial
def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  (x,), (x_bdim,) = vals_in, dims_in
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule
# -
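# To see the index arithmetic in `reduce_sum_batching_rule`, suppose `x` is
# logically `f64[3,5]` and vmap inserts a batch axis of size 7 at `x_bdim=0`,
# making the physical value `f64[7,3,5]`. Reducing logical axis 0 then means
# reducing physical axis 1 (`new_axis = (1,)`), and the batch axis stays at
# position 0 in the output (`out_bdim = 0`).
#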
# Finally, we add a transformation API to kick off the trace:
# +
def vmap_flat(f, in_axes, *args):
  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed

def vmap(f, in_axes):
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2: raise TypeError
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f
# +
def add_one_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 1 + scalar
vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)
print(vector_in)
print(vector_out)
# +
def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)

def f(x):
  return sin(x)
jacfwd(f, np.arange(3.))
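# The Jacobian of elementwise `sin` is diagonal, with `cos(x)` on the
# diagonal; a quick sanity check:
assert np.allclose(jacfwd(f, np.arange(3.)), np.diag(np.cos(np.arange(3.))))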
# -
# That's it for `jvp` and `vmap`!
# ## Part 2: Jaxprs
#
# The next transformations on the horizon are `jit` for just-in-time
# compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small
# wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to
# carry a little bit of extra context, for both `jit` and `vjp` we need much
# richer context: we need to represent _programs_. That is, we need jaxprs!
#
# Jaxprs are JAX's internal intermediate representation of programs. They are
# explicitly typed, functional, first-order, and in A-normal form (ANF). We need a
# program representation for `jit` because the purpose of `jit` is to stage
# computation out of Python. For any computation we want to stage out, we need
# to be able to represent it as data, and build it up as we trace a Python
# function. Similarly, `vjp` needs a way to represent the computation for the
# backward pass of reverse-mode autodiff. We use the same jaxpr program
# representation for both needs.
#
# (Building a program representation is the most
# [free](https://en.wikipedia.org/wiki/Free_object) kind of
# trace-transformation, and so except for issues around handling native Python
# control flow, any transformation could be implemented by first tracing to a
# jaxpr and then interpreting the jaxpr.)
# ### Jaxpr data structures
#
# The jaxpr term syntax is roughly:
#
# ```
# jaxpr ::=
# { lambda <binder> , ... .
# let <eqn>
# ...
# in ( <atom> , ... ) }
#
# binder ::= <var>:<array_type>
# var ::= a | b | c | ...
# atom ::= <var> | <literal>
# literal ::= <int32> | <int64> | <float32> | <float64>
#
# eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...
# ```
#
# The syntax of types is:
#
# ```
# jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
# array_type ::= <dtype>[<shape>]
# dtype ::= f32 | f64 | i32 | i64
# shape ::= <int> , ...
# ```
#
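# For example, tracing `lambda x: 2. * sin(x)` at a scalar input could print
# as a jaxpr like this (a sketch in the grammar above):
#
# ```
# { lambda a:f64[] .
#   let b:f64[] = sin a
#       c:f64[] = mul 2.0 b
#   in ( c ) }
# ```
#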
# How do we represent these as Python data structures? We reuse ShapedArrays to
# represent types, and we can represent the term syntax with a few Python
# structs:
# +
from typing import Set