forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_test.py
1917 lines (1658 loc) · 76.2 KB
/
export_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import functools
import logging
import json
import math
import re
import unittest
from absl.testing import absltest
import jax
from jax import lax
from jax import numpy as jnp
from jax import export
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax import tree_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
import numpy as np
# ruff: noqa: F401
# `flatbuffers` is only probed for here; CAN_SERIALIZE records whether the
# import succeeded.
try:
  import flatbuffers
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
  CAN_SERIALIZE = False
else:
  CAN_SERIALIZE = True
config.parse_flags_with_absl()

# Holds the contexts entered by setUpModule until tearDownModule unwinds them.
_exit_stack = contextlib.ExitStack()


def setUpModule():
  """Enter `jtu.set_host_platform_device_count(2)` for the whole module."""
  two_device_ctx = jtu.set_host_platform_device_count(2)
  _exit_stack.enter_context(two_device_ctx)


def tearDownModule():
  """Exit every context that setUpModule entered."""
  _exit_stack.close()
### Setup for testing lowering with effects
@dataclasses.dataclass(eq=True, frozen=True)
class ForTestingOrderedEffect1(effects.Effect):
  """A field-less, frozen test effect; all instances compare (and hash) equal."""
@dataclasses.dataclass(eq=True, frozen=True)
class ForTestingOrderedEffect2(effects.Effect):
  """A second field-less, frozen test effect, distinct from ForTestingOrderedEffect1."""
@dataclasses.dataclass(eq=True, frozen=True)
class ForTestingUnorderedEffect1(effects.Effect):
  """A field-less, frozen test effect whose name does not contain "Ordered"."""
class ForTestingOrderedEffect4NoNullary(effects.Effect):
  """A test effect whose constructor requires exactly one (ignored) argument.

  It defines no __eq__/__hash__ of its own, so instances compare by identity.
  """

  def __init__(self, _):
    del _  # Only present so that the constructor is not nullary.
@dataclasses.dataclass(eq=False, frozen=False)
class ForTestingOrderedEffect5NoEq(effects.Effect):
  """A field-less dataclass effect with eq=False: instances compare by identity."""
# One instance of each testing effect, keyed by its class name. The key is the
# `effect_class_name` that testing_primitive_with_effect_p is bound with.
_testing_effects = dict(
    ForTestingOrderedEffect1=ForTestingOrderedEffect1(),
    ForTestingOrderedEffect2=ForTestingOrderedEffect2(),
    ForTestingUnorderedEffect1=ForTestingUnorderedEffect1(),
    ForTestingOrderedEffect4NoNullary=ForTestingOrderedEffect4NoNullary(42),
    ForTestingOrderedEffect5NoEq=ForTestingOrderedEffect5NoEq(),
)
# Register the effects: every testing effect is lowerable and allowed under
# control flow, remat and custom derivatives; those whose class name contains
# "Ordered" are additionally ordered.
for effect in _testing_effects.values():
  effect_class = effect.__class__
  effects.lowerable_effects.add_type(effect_class)
  effects.control_flow_allowed_effects.add_type(effect_class)
  effects.remat_allowed_effects.add_type(effect_class)
  effects.custom_derivatives_allowed_effects.add_type(effect_class)
  # Match on the bare class name, not `str(effect_class)` (which also embeds
  # the module path), the same way the lowering rule matches
  # `effect_class_name`. Note that "Unordered" does not contain "Ordered".
  if "Ordered" in effect_class.__name__:
    effects.ordered_effects.add_type(effect_class)
# A primitive that takes an effect_class_name kwarg with the name of the effect
# class and just doubles its argument.
testing_primitive_with_effect_p = core.Primitive("testing_primitive_with_effect")


def _testing_primitive_with_effect_abstract_eval(aval, *x, effect_class_name):
  # The output aval is the first input aval; the effect set holds the single
  # testing effect registered under `effect_class_name`.
  effect_set = {_testing_effects[effect_class_name]}
  return aval, effect_set


testing_primitive_with_effect_p.def_effectful_abstract_eval(
    _testing_primitive_with_effect_abstract_eval)


def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
  """Lower the primitive to `a + a`, threading a token for ordered effects.

  When `effect_class_name` contains "Ordered", the input token for that effect
  is forwarded unchanged as its output token; otherwise no tokens are set.
  """
  if "Ordered" in effect_class_name:
    ordered_effect = _testing_effects[effect_class_name]
    token_in = ctx.tokens_in.get(ordered_effect)
    ctx.set_tokens_out(mlir.TokenSet({ordered_effect: token_in}))
  return [mlir.hlo.add(a, a)]


mlir.register_lowering(testing_primitive_with_effect_p,
                       lowering_testing_primitive_with_effect)
## Setup for multi-platform lowering
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
def _testing_multi_platform_func(x, *,
                                 effect_class_name: str | None = None):
  """Add a platform-specific constant to `x`, chosen at lowering time.

  Behaves like `x + 2 * _testing_multi_platform_to_add[platform]`, where
  `platform` is "tpu", "cuda" or "rocm" when lowering for one of those, and
  "cpu" for any other platform.

  Args:
    x: the operand.
    effect_class_name: if given, the doubling is done by binding
      `testing_primitive_with_effect_p` with this effect name (a key of
      `_testing_effects`); otherwise the doubled constant is used directly.
  """
  def for_platform(platform: str):
    to_add = _testing_multi_platform_to_add[platform]
    if effect_class_name is not None:
      # The primitive doubles its argument, so this also yields 2 * to_add.
      return testing_primitive_with_effect_p.bind(
          to_add, effect_class_name=effect_class_name)
    return 2. * to_add

  platform_value = lax.platform_dependent(
      tpu=lambda: for_platform("tpu"),
      cuda=lambda: for_platform("cuda"),
      rocm=lambda: for_platform("rocm"),
      default=lambda: for_platform("cpu"),
  )
  return x + platform_value
def _testing_multi_platform_fun_expected(x,
                                         platform: str | None = None):
  """Reference result for the multi-platform testing function.

  Args:
    x: the operand.
    platform: platform whose constant to use; when falsy (e.g. None), the
      value of `jtu.device_under_test()` is used instead.

  Returns:
    `x + 2. * _testing_multi_platform_to_add[p]`, with `p` the canonicalized
    platform name.
  """
  if not platform:
    platform = jtu.device_under_test()
  to_add = _testing_multi_platform_to_add[xb.canonicalize_platform(platform)]
  return x + 2. * to_add
def get_exported(fun: Callable, vjp_order: int = 0,
                 **export_kwargs) -> Callable[..., export.Exported]:
  """Like `export.export`, but round-trips the result through serialization.

  Args:
    fun: the function to export, passed to `export.export` (which rejects
      functions that are not the result of `jax.jit`).
    vjp_order: forwarded to `Exported.serialize`; presumably the number of
      VJP levels serialized along with the primal (callers use 1 for `grad`
      and 3 for third-order derivatives).
    **export_kwargs: forwarded to `export.export` (e.g. `platforms`,
      `disabled_checks`).

  Returns:
    A function taking the same arguments as `export.export(fun)` and
    returning an `export.Exported`. When `CAN_SERIALIZE` is true, the
    exported object is serialized with `vjp_order` and deserialized before it
    is returned; otherwise the freshly exported object is returned as is.
  """
  def serde_exported(*fun_args, **fun_kwargs) -> export.Exported:
    exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
    if not CAN_SERIALIZE:
      # No flatbuffers available: test the in-memory object instead.
      return exp
    serialized = exp.serialize(vjp_order=vjp_order)
    return export.deserialize(serialized)
  return serde_exported
# Run tests with the maximum supported version by default
@jtu.with_config(jax_export_calling_convention_version=export.maximum_supported_calling_convention_version)
class JaxExportTest(jtu.JaxTestCase):
@classmethod
def setUpClass(cls):
  """Record in `cls.platforms` which of "cpu", "gpu", "tpu" have devices."""
  def has_devices(backend):
    # jax.devices raises RuntimeError for an unavailable backend.
    try:
      jax.devices(backend)
    except RuntimeError:
      return False
    return True

  cls.platforms = [backend for backend in ("cpu", "gpu", "tpu")
                   if has_devices(backend)]
  super().setUpClass()
def test_basic_export_only(self):
  """Check the metadata of an exported `sin` over f32[4] (no call)."""
  @jax.jit
  def my_fun(x):
    return jnp.sin(x)

  arg_spec = jax.ShapeDtypeStruct((4,), dtype=np.float32)
  exp = get_exported(my_fun)(arg_spec)
  self.assertEqual("my_fun", exp.fun_name)
  default_platform = xb.canonicalize_platform(jax.default_backend())
  self.assertEqual((default_platform,), exp.platforms)
  self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
  f32_4 = core.ShapedArray((4,), dtype=np.float32)
  self.assertEqual((f32_4,), exp.in_avals)
  self.assertEqual((f32_4,), exp.out_avals)
def test_pytree_export_only(self):
  """Check in/out trees and avals of an export with pytree args and kwargs."""
  a = np.arange(4, dtype=np.float32)
  b = np.arange(6, dtype=np.float32)

  def f(a_b_pair, *, a, b):
    return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

  call_args = ((a, b),)
  call_kwargs = dict(a=a, b=b)
  exp = get_exported(jax.jit(f), platforms=("cpu",))(*call_args, **call_kwargs)
  self.assertEqual(exp.platforms, ("cpu",))

  aval_a, aval_b = (core.ShapedArray(v.shape, v.dtype) for v in (a, b))
  self.assertEqual(exp.in_tree, jax.tree.flatten((call_args, call_kwargs))[1])
  self.assertEqual(exp.in_avals, (aval_a, aval_b) * 2)
  self.assertEqual(exp.out_tree,
                   jax.tree.flatten(f(*call_args, **call_kwargs))[1])
  self.assertEqual(exp.out_avals, (aval_a, aval_b) * 3)
def test_basic(self):
  """Calling an exported `jnp.sin` matches calling `jnp.sin` directly."""
  x = np.arange(4, dtype=np.float32)
  f = jnp.sin
  expected = f(x)
  self.assertAllClose(expected, get_exported(f)(x).call(x))
def test_jit_static_arg(self):
  """Static arguments are baked in at export time and omitted at call time."""
  x = np.arange(4, dtype=np.float32)

  with self.subTest("static_argnames"):
    @functools.partial(jax.jit, static_argnames=["c"])
    def f(x, *, c):
      return c * jnp.sin(x)

    exp_f = get_exported(f)(x, c=0.1)
    # Only the dynamic argument is passed to the exported function.
    self.assertAllClose(f(x, c=0.1), exp_f.call(x))

  with self.subTest("static_argnums"):
    @functools.partial(jax.jit, static_argnums=[1])
    def g(x, c):
      return c * jnp.sin(x)

    exp_g = get_exported(g)(x, 0.1)
    self.assertAllClose(g(x, 0.1), exp_g.call(x))
def test_export_error_no_jit(self):
  """`export.export` rejects a callable that is not the result of `jax.jit`."""
  # A plain (non-jitted) lambda cannot be exported; it has to be wrapped in
  # `jax.jit` first.
  with self.assertRaisesRegex(ValueError,
                              "Function to be exported must be the result of `jit`"):
    _ = export.export(lambda x: jnp.sin(x))
def test_call_exported_lambda(self):
  """A jitted lambda can be exported and the result called."""
  # When we export a lambda, the exported.fun_name is not a valid MLIR function name
  x = np.arange(4, dtype=np.float32)
  f = jax.jit(lambda x: jnp.sin(x))
  exp_f = get_exported(f)(x)
  self.assertAllClose(f(x), exp_f.call(x))
def test_call_name_conflict(self):
  """Calling an exported module next to its source jit must not clash on private names."""
  @jax.jit
  def inner(x):
    # The lowering will contain a _where private function
    return jnp.where(x > 0, jnp.ones_like(x), jnp.zeros_like(x))

  x = jnp.arange(-20, 20, dtype=np.int32)
  exp_inner = export.export(inner)(x)
  inner_module_text = str(exp_inner.mlir_module())
  self.assertIn("@_where(", inner_module_text)

  @jax.jit
  def outer(x):
    # There should be no conflict on _where
    return inner(exp_inner.call(x))

  # Exporting `outer` without an exception is the assertion.
  export.export(outer)(x)
def test_call_twice_exported(self):
  """An exported function can be exported and called twice inside a jitted function."""
  def f(x):
    return jnp.sin(x)

  x = np.arange(4, dtype=np.float32)

  @jax.jit
  def f1(y):
    exp_f = get_exported(jax.jit(f))(y)
    first = exp_f.call(y)
    second = exp_f.call(y)
    return first + second

  self.assertAllClose(2. * f(x), f1(x))
def test_unused_args(self):
  """Arguments the function ignores are still accepted by the exported call."""
  f = jax.jit(lambda x, y: jnp.sin(x))
  args = (np.arange(4, dtype=np.float32), np.arange(6, dtype=np.float32))
  exp_f = get_exported(f)(*args)
  self.assertAllClose(f(*args), exp_f.call(*args))
def test_pytree(self):
  """Pytree arguments (positional and keyword) and results survive export."""
  a = np.arange(4, dtype=np.float32)
  b = np.arange(6, dtype=np.float32)

  def f(a_b_pair, a, b):
    return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))

  def invoke(fn):
    return fn((a, b), a=a, b=b)

  exp_f = get_exported(jax.jit(f))((a, b), a=a, b=b)
  self.assertAllClose(invoke(f), invoke(exp_f.call))
def test_pytree_namedtuple(self):
  """A namedtuple registered for serialization round-trips through export."""
  T = collections.namedtuple("SomeType", ("a", "b", "c"))
  export.register_namedtuple_serialization(
      T, serialized_name="test_pytree_namedtuple.SomeType")
  x = T(a=1, b=2, c=3)

  def f(x):
    return (x, x)  # return 2 copies, to check that types are shared

  exp = export.export(jax.jit(f))(x)
  res = exp.call(x)
  self.assertEqual(tree_util.tree_structure(res),
                   tree_util.tree_structure((x, x)))
  for item in res:
    self.assertEqual(type(item), type(x))

  exp2 = export.deserialize(exp.serialize())
  self.assertEqual(exp2.in_tree, exp.in_tree)
  self.assertEqual(exp2.out_tree, exp.out_tree)
  res2 = exp2.call(x)
  self.assertEqual(tree_util.tree_structure(res2),
                   tree_util.tree_structure(res))
def test_pytree_namedtuple_error(self):
  """Error cases for namedtuple / pytree-node serialization registration."""
  SomeType = collections.namedtuple("SomeType", ("a", "b"))
  x = SomeType(a=1, b=2)

  # Serializing an export that mentions an unregistered namedtuple fails.
  with self.assertRaisesRegex(
      ValueError,
      "Cannot serialize .* unregistered type .*SomeType"):
    export.export(jax.jit(lambda x: x))(x).serialize()

  # Without `from_children`, the type must already be a registered pytree
  # node (per the expected error message).
  with self.assertRaisesRegex(
      ValueError,
      "If `from_children` is not present.* must call.*register_pytree_node"
  ):
    export.register_pytree_node_serialization(
        SomeType,
        serialized_name="test_pytree_namedtuple.SomeType_V2",
        serialize_auxdata=lambda x: b"",
        deserialize_auxdata=lambda b: None)

  # A non-namedtuple class (here `str`) is rejected by the namedtuple API.
  with self.assertRaisesRegex(ValueError,
                              "Use .*register_pytree_node_serialization"):
    export.register_namedtuple_serialization(str, serialized_name="n/a")

  export.register_namedtuple_serialization(
      SomeType, serialized_name="test_pytree_namedtuple_error.SomeType")

  # The same class cannot be registered again under another name ...
  with self.assertRaisesRegex(
      ValueError,
      "Duplicate serialization registration .*test_pytree_namedtuple_error.SomeType"
  ):
    export.register_namedtuple_serialization(
        SomeType, serialized_name="test_pytree_namedtuple_error.OtherType")

  # ... and the same serialized name cannot be reused for another class.
  with self.assertRaisesRegex(
      ValueError,
      "Duplicate serialization registration for serialized_name.*test_pytree_namedtuple_error.SomeType"
  ):
    export.register_namedtuple_serialization(
        collections.namedtuple("SomeOtherType", ("a", "b")),
        serialized_name="test_pytree_namedtuple_error.SomeType")
def test_pytree_custom_types(self):
  """An OrderedDict and a user pytree class with serialized auxdata round-trip."""
  x1 = collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)])

  @tree_util.register_pytree_node_class
  class CustomType:
    def __init__(self, a: int, b: CustomType | None, string: str):
      self.a = a
      self.b = b
      self.string = string

    def tree_flatten(self):
      # Children are (a, b); `string` is carried as auxiliary data.
      return ((self.a, self.b), self.string)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
      return cls(*children, aux_data)

  export.register_pytree_node_serialization(
      CustomType,
      serialized_name="test_pytree_custom_types.CustomType",
      serialize_auxdata=lambda aux: aux.encode("utf-8"),
      deserialize_auxdata=lambda b: b.decode("utf-8"))

  x2 = CustomType(4, 5, "foo")

  def f(x1, x2):
    return (x1, x2, x1, x2)  # return 2 copies, to check that types are shared

  expected = (x1, x2, x1, x2)
  exp = export.export(jax.jit(f))(x1, x2)
  res = exp.call(x1, x2)
  self.assertEqual(tree_util.tree_structure(res),
                   tree_util.tree_structure(expected))
  for got, want in zip(res, expected):
    self.assertEqual(type(got), type(want))

  exp2 = export.deserialize(exp.serialize())
  self.assertEqual(exp2.in_tree, exp.in_tree)
  self.assertEqual(exp2.out_tree, exp.out_tree)
  res2 = exp2.call(x1, x2)
  self.assertEqual(tree_util.tree_structure(res2),
                   tree_util.tree_structure(res))
def test_error_wrong_intree(self):
  """Calling with a different pytree structure than at export time fails."""
  def f(a_b_pair, *, c):
    return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c

  a = b = c = np.arange(4, dtype=np.float32)
  exp_f = get_exported(jax.jit(f))((a, b), c=c)
  # Exported as f((a, b), c=array); called as f(a, b, c=(a, b)).
  with self.assertRaisesRegex(
      ValueError,
      "The invocation args and kwargs must have the same pytree structure"):
    exp_f.call(a, b, c=(a, b))
def test_error_wrong_avals(self):
  """Shape, rank and dtype mismatches at call time name the offending argument."""
  def f(a, *, b):  # a: f32[4] and b: f32[4]
    return jnp.sin(a) + jnp.cos(b)

  f32_4 = np.arange(4, dtype=np.float32)
  f32_6 = np.arange(6, dtype=np.float32)
  exp_f = get_exported(jax.jit(f))(f32_4, b=f32_4)
  # (expected error regex, positional `a`, keyword `b`)
  bad_calls = [
      (r"Shape mismatch for args\[0\].shape\[0\]", f32_6, f32_4),
      (r"Shape mismatch for kwargs\['b'\].shape\[0\]", f32_4, f32_6),
      (r"Rank mismatch for args\[0\]", f32_4.reshape((1, 4)), f32_4),
      (r"Dtype mismatch for args\[0\]", f32_4.astype(np.float16), f32_4),
  ]
  for error_re, a_arg, b_arg in bad_calls:
    with self.assertRaisesRegex(ValueError, error_re):
      exp_f.call(a_arg, b=b_arg)
def test_default_export_platform(self):
  """The default export platform is the device under test (cuda/rocm for GPUs)."""
  expected_platform = jtu.device_under_test()
  if expected_platform == "gpu":
    expected_platform = "rocm" if jtu.is_device_rocm() else "cuda"
  self.assertEqual(export.default_export_platform(), expected_platform)

  exp = export.export(jnp.sin)(1.)
  self.assertEqual(exp.platforms, (export.default_export_platform(),))
@jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["platform"],
    kwargs=[dict(platform=p)
            for p in ("cpu", "cuda", "rocm", "tpu")])
def test_error_wrong_platform(self, platform):
  """Calling an artifact exported for another platform fails unless the check is disabled."""
  # Decide on skipping before doing any export work: when `platform` is the
  # platform under test there is no mismatch to exercise.
  if xb.canonicalize_platform(jtu.device_under_test()) == platform:
    raise unittest.SkipTest("Uninteresting scenario")
  a = np.arange(4, dtype=np.float32)
  exp_f = get_exported(jnp.sin, platforms=(platform,))(a)
  with self.assertRaisesRegex(
      ValueError, "Function .* was exported for platform"):
    exp_f.call(a)

  # Now try with the platform check disabled
  exp_f_no_platform_check = get_exported(
      jnp.sin, platforms=(platform,),
      disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
  res = exp_f_no_platform_check.call(a)
  self.assertAllClose(res, jnp.sin(a))
@jtu.parameterized_filterable(
    testcase_name=lambda kw: kw["dialect"],
    kwargs=[dict(dialect=dialect)
            for dialect in ("stablehlo",)]
)
def test_error_disallowed_custom_call(self, dialect):
  """A custom call whose target export does not accept is rejected unless disabled."""
  # If we use hlo.custom_call we detect invalid custom call targets.
  # Set up a primitive with custom lowering rules
  test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
  test_primitive.def_abstract_eval(lambda in_aval: in_aval)

  def test_primitive_lowering(ctx, arg):
    custom_call_op = {"stablehlo": hlo.CustomCallOp}[dialect]
    return custom_call_op([arg.type], [arg], "disallowed_call_target").results

  mlir.register_lowering(test_primitive, test_primitive_lowering)
  self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))

  def make_fun():
    # A fresh jitted function for each export attempt.
    return jax.jit(lambda a: a + test_primitive.bind(a))

  a = np.arange(3, dtype=np.float32)
  with self.assertRaisesRegex(ValueError,
                              "Cannot serialize code with custom calls whose targets .*"):
    get_exported(make_fun())(a)

  # Now try again with the safety check disabled
  allow_target = export.DisabledSafetyCheck.custom_call("disallowed_call_target")
  exp = get_exported(make_fun(), disabled_checks=[allow_target])(a)
  self.assertIn("disallowed_call_target", exp.mlir_module())
def test_lowering_parameters_for_export(self):
  """The lowering rule sees `for_export` / `export_ignore_forward_compatibility`.

  A custom primitive's lowering rule records both lowering parameters in
  `context`. The test checks them for a plain jit call, for a `lower` call
  that hits the jit cache, for `export.export`, and for `export.export` under
  `config.export_ignore_forward_compatibility(True)`.
  """
  # Test that we propagate properly the LoweringParameters.for_export
  test_primitive = core.Primitive("_test_primitive_for_export")
  test_primitive.def_abstract_eval(lambda in_aval: in_aval)
  # Store here the context for lowering
  context = {}
  def test_primitive_lowering(ctx, arg):
    # Record the lowering parameters each time this rule actually runs.
    context["for_export"] = ctx.module_context.lowering_parameters.for_export
    context["export_ignore_forward_compatibility"] = ctx.module_context.lowering_parameters.export_ignore_forward_compatibility
    return mlir.hlo.AddOp(arg, arg).results
  mlir.register_lowering(test_primitive, test_primitive_lowering)
  self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))
  f = jax.jit(test_primitive.bind)
  a = np.arange(3, dtype=np.float32)
  context.clear()
  res = f(a) # Works with JIT
  self.assertAllClose(res, a + a)
  self.assertEqual(context,
                   dict(for_export=False,
                        export_ignore_forward_compatibility=False))
  context.clear()
  f.lower(a) # Works with most AOT
  # The above was cached
  # (the lowering rule did not run again, so `context` is still empty).
  self.assertEqual(context, {})
  # `context` is known to be empty here, so no clear() is needed before export.
  _ = export.export(f)(a)
  self.assertEqual(context,
                   dict(for_export=True,
                        export_ignore_forward_compatibility=False))
  context.clear()
  with config.export_ignore_forward_compatibility(True):
    _ = export.export(f)(a)
  self.assertEqual(context,
                   dict(for_export=True,
                        export_ignore_forward_compatibility=True))
def test_grad(self):
  """jax.grad through an export with vjp_order=1 matches native jax.grad."""
  f = lambda x: jnp.sum(jnp.sin(x))
  x = np.arange(4, dtype=np.float32)
  f1 = get_exported(jax.jit(f), vjp_order=1)(x).call
  expected = jax.grad(f)(x)
  self.assertAllClose(expected, jax.grad(f1)(x))
def test_higher_order_grad(self):
  """1st, 2nd and 3rd derivatives through a vjp_order=3 export match native JAX."""
  f = lambda x: x ** 3
  x = np.float32(4.)
  f1 = get_exported(jax.jit(f), vjp_order=3)(x).call
  native, exported = f, f1
  for _ in range(3):
    native, exported = jax.grad(native), jax.grad(exported)
    self.assertAllClose(native(x), exported(x))
@jtu.parameterized_filterable(
    kwargs=[dict(poly_shape=True), dict(poly_shape=False)])
def test_grad_int(self, poly_shape):
  """1st- and 2nd-order VJPs through an export with an integer input.

  Reference VJPs come from native JAX. The function is exported via
  `get_exported` with vjp_order=2 (optionally with symbolic shapes), and the
  VJPs of the exported call must agree with the native ones.
  """
  def f(xi, xf):
    return (2 * xi.T, xf.T * xf.T)
  xi = np.arange(6, dtype=np.int32).reshape((2, 3))
  xf = np.arange(12, dtype=np.float32).reshape((3, 4))
  # Native JAX 1st order vjp
  (f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf)
  f_outi_ct = np.ones(f_outi.shape,
                      dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype))
  f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype)
  xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct))
  # Native JAX 2nd order vjp
  res, f_vjp2 = jax.vjp(f_vjp, (f_outi_ct, f_outf_ct))
  self.assertAllClose(res, (xi_ct, xf_ct))
  (f_outi_ct2, f_outf_ct2), = f_vjp2((xi_ct, xf_ct))
  if poly_shape:
    args = export.symbolic_args_specs([xi, xf], shapes_specs=["2, a", "a, 4"])
  else:
    args = (xi, xf)
  exp = get_exported(jax.jit(f), vjp_order=2)(*args)
  fr = exp.call
  res = fr(xi, xf)
  self.assertAllClose(res, (f_outi, f_outf))
  # Reloaded 1st order vjp
  (fr_outi, fr_outf), fr_vjp = jax.vjp(fr, xi, xf)
  self.assertAllClose(fr_outi, f_outi)
  self.assertAllClose(fr_outf, f_outf)
  xri_ct, xrf_ct = fr_vjp((f_outi_ct, f_outf_ct))
  self.assertAllClose(xri_ct, xi_ct)
  self.assertAllClose(xrf_ct, xf_ct)
  # Reloaded 2nd order vjp. Bound to `fr_vjp2` so that the native `f_vjp2`
  # above is not shadowed by the reloaded one.
  res, fr_vjp2 = jax.vjp(fr_vjp, (f_outi_ct, f_outf_ct))
  self.assertAllClose(res, (xi_ct, xf_ct))
  (fr_outi_ct2, fr_outf_ct2), = fr_vjp2((xi_ct, xf_ct))
  self.assertAllClose(fr_outi_ct2, f_outi_ct2)
  self.assertAllClose(fr_outf_ct2, f_outf_ct2)
def test_pytree_vjp(self):
  """The VJP of a pytree-in/pytree-out export matches the native VJP."""
  def f(a_b_pair, *, a, b):
    return (dict(res=a_b_pair, a=2. * a, b=3. * b),
            jnp.sin(4. * a))

  a = np.arange(4, dtype=np.float32)
  b = np.arange(6, dtype=np.float32)
  exp_f = get_exported(jax.jit(f), vjp_order=1)((a, b), a=a, b=b)
  # The primal output has the same pytree structure as the cotangent.
  out_ct = f((a, b), a=a, b=b)

  # jax.vjp differentiates w.r.t. positional arguments, so wrap both versions
  # as keyword-free functions of (a, b).
  def positional(fn):
    return lambda a, b: fn((a, b), a=a, b=b)

  jax_vjp = jax.vjp(positional(f), a, b)[1](out_ct)
  exp_vjp = jax.vjp(positional(exp_f.call), a, b)[1](out_ct)
  self.assertAllClose(jax_vjp, exp_vjp)
def test_roundtrip(self):
  """An exported function can be called inside another function that is exported."""
  def f1(x):
    return jnp.sin(x)

  a = np.arange(4, dtype=np.float32)
  exp_f1 = get_exported(jax.jit(f1))(a)

  def f2(x):
    twice = exp_f1.call(exp_f1.call(x))
    return jnp.cos(twice)

  exp_f2 = get_exported(jax.jit(f2))(a)
  expected = jnp.cos(jnp.sin(jnp.sin(a)))
  self.assertAllClose(expected, exp_f2.call(a))
def test_poly_export_only(self):
  """Inspect avals and MLIR text of a shape-polymorphic export (no call)."""
  a = np.arange(12, dtype=np.float32).reshape((3, 4))

  def f(a, b):  # a: f32[2w,h]  b: f32[w,h]
    return jnp.concatenate([a, b], axis=0)

  scope = export.SymbolicScope()

  def spec(shape_str):
    # Both specs share `scope`, so their `w` and `h` are the same symbols.
    return jax.ShapeDtypeStruct(export.symbolic_shape(shape_str, scope=scope),
                                a.dtype)

  exp = get_exported(jax.jit(f))(spec("(2*w, h)"), spec("(w, h)"))
  for expected_shape, aval in (("(2*w, h)", exp.in_avals[0]),
                               ("(w, h)", exp.in_avals[1]),
                               ("(3*w, h)", exp.out_avals[0])):
    self.assertEqual(expected_shape, str(aval.shape))

  # Peek at the module
  module_str = exp.mlir_module()
  self.assertEqual(config.jax_export_calling_convention_version.value >= 7,
                   "shape_assertion" in module_str)
  self.assertIn("jax.uses_shape_polymorphism = true", module_str)
  wrapped_main_expected_re = (
      r"@_wrapped_jax_export_main\("
      r"%arg0: tensor<i..> {jax.global_constant = \"h\".*"
      r"%arg1: tensor<i..> {jax.global_constant = \"w\".*"
      r"%arg2: tensor<\?x\?xf32>"
  )
  self.assertRegex(module_str, wrapped_main_expected_re)

  # Look for private inner functions that are generated to compute the
  # dimension variables and shape assertions. All those functions must
  # have jax.global_constant attributes on all the arguments.
  private_funcs = re.findall(r"func.func private @([\w]+)\((.+)\) ->",
                             module_str)
  for func_name, func_args in private_funcs:
    if func_name != "_wrapped_jax_export_main":
      self.assertEqual(len(re.findall(r"%arg\d+", func_args)),
                       len(re.findall(r"jax.global_constant = ", func_args)))
def test_poly_pytree_export_only(self):
  """Symbolic shapes flow through positional and keyword arguments."""
  a = np.arange(12, dtype=np.float32).reshape((3, 4))

  def f(a0, a1, *, ak):
    return jnp.concatenate([a0, a1, ak], axis=0)

  spec = jax.ShapeDtypeStruct(export.symbolic_shape("(w, h)"), a.dtype)
  exp = get_exported(jax.jit(f))(spec, spec, ak=spec)
  self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
  self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
def test_poly_export_error_symbolic_scope(self):
  """Symbolic shapes from different scopes cannot be mixed in one export."""
  a = np.arange(12, dtype=np.float32).reshape((3, 4))

  def f(x, y):
    return jnp.concatenate([x, y], axis=1)

  # Without an explicit `scope=`, the two symbolic_shape calls do not share a
  # scope, which is what the expected error reports for `w`.
  x_spec, y_spec = (
      jax.ShapeDtypeStruct(export.symbolic_shape(shape_str), a.dtype)
      for shape_str in ("(w, h1)", "(w, h2)"))
  expected_error = re.compile(
      "Invalid mixing of symbolic scopes when exporting f.*"
      r"Expected current \(from args\[0\]\) scope .*"
      r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)
  with self.assertRaisesRegex(ValueError, expected_error):
    get_exported(jax.jit(f))(x_spec, y_spec)
def test_poly_export_callable_with_no_name(self):
  """Exporting a jitted callable *instance* with symbolic shapes works."""
  # This was reported by a user
  class MyCallable:
    def __call__(self, x):
      return jnp.sin(x)

    # `lower` and `trace` make it look like a jitted function; both forward
    # to `jax.jit(self.__call__)`.
    def lower(self, x, _experimental_lowering_parameters=None):
      return jax.jit(self.__call__).lower(
          x,
          _experimental_lowering_parameters=_experimental_lowering_parameters)

    def trace(self, x, _experimental_lowering_parameters=None):
      return jax.jit(self.__call__).trace(
          x,
          _experimental_lowering_parameters=_experimental_lowering_parameters)

  a, = export.symbolic_shape("a,")
  arg_spec = jax.ShapeDtypeStruct((a, a), dtype=np.float32)
  # No error
  _ = get_exported(jax.jit(MyCallable()))(arg_spec)
@jtu.parameterized_filterable(
    kwargs=[
        dict(v=v)
        for v in range(export.minimum_supported_calling_convention_version - 1,
                       export.maximum_supported_calling_convention_version + 2)])
def test_poly_basic_versions(self, v: int):
  """Export and call under each calling convention version, plus one out of range on each side."""
  is_supported = (export.minimum_supported_calling_convention_version <= v
                  <= export.maximum_supported_calling_convention_version)
  with config.jax_export_calling_convention_version(v):
    logging.info(
        "Using JAX calling convention version %s",
        config.jax_export_calling_convention_version.value)
    with contextlib.ExitStack() as e:
      if not is_supported:
        e.enter_context(self.assertRaisesRegex(
            ValueError,
            f"The requested export calling convention version {v} is outside the range of supported versions"))
      arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("w, h"), np.float32)
      exp = get_exported(jnp.sin)(arg_spec)
      x = np.arange(30, dtype=np.float32).reshape((5, 6))
      res = exp.call(x)
      self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use export.call and we also run the shape check
# module.
@jtu.parameterized_filterable(
    testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}",  # type: ignore
    kwargs=[
        dict(poly_spec="3,4,12", arg_shape=(3, 4, 12)),
        dict(poly_spec="3,4,12", arg_shape=(3, 4, 13),
             # The shape check module does not test constant dimensions
             expect_error=re.escape(
                 r"Shape mismatch for args[0].shape[2] (expected same constant)")),
        dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 12)),
        dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 12)),
        dict(poly_spec="3,4,a+1", arg_shape=(3, 4, 1),
             expect_error=re.escape(
                 "Expected value >= 1 for dimension variable 'a'. "
                 "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 1). "
                 "Obtained dimension variables: 'a' = 0"
             )),
        dict(poly_spec="3,4,6*a", arg_shape=(3, 4, 13),
             expect_error=re.escape(
                 "Division had remainder 1 when computing the value of 'a'"
             )),
        dict(poly_spec="3,a,a+8", arg_shape=(3, 4, 13),
             expect_error=re.escape(
                 "Found inconsistency between dimension size "
                 "args[0].shape[2] (= 13) and the specification 'a + 8' (= 12)"
             )),
    ])
def test_poly_shape_checks(
    self, poly_spec="3,a,a+8",
    arg_shape=(3, 4, 12), arg_dtype=np.float32,
    expect_error=None):  # If given, error from running the exported module
  """Call an f32[poly_spec] export with a concrete `arg_shape` argument.

  When `expect_error` is given, the call must raise an exception whose message
  matches it; otherwise the result must match calling `f` directly.
  """
  def f(x):  # x: f32[poly_spec]
    return jnp.reshape(x, (-1, x.shape[1]))
  disabled_checks = ()
  exp_f = get_exported(jax.jit(f), disabled_checks=disabled_checks)(
      jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), np.float32))
  # Only the fully-constant spec needs no global constants.
  self.assertEqual(exp_f.uses_global_constants, poly_spec != "3,4,12")
  arg = np.arange(np.prod(arg_shape),
                  dtype=arg_dtype).reshape(arg_shape)  # arg : arg_dtype[arg_shape]
  with contextlib.ExitStack() as stack:
    if expect_error is not None:
      # enter_context (unlike push) calls the assertion context's __enter__
      # before registering its __exit__.
      stack.enter_context(self.assertRaisesRegex(Exception, expect_error))
    assert core.is_constant_shape(arg.shape)
    res = exp_f.call(arg)
  if not expect_error:
    self.assertAllClose(res, f(arg))
# An inner function is exported with polymorphic shapes inner_poly_spec, and
# is called from an outer function, which is exported with outer_poly_spec.
@jtu.parameterized_filterable(
  testcase_name=lambda kw:f"inner={kw['inner_poly_spec']}_outer={kw['outer_poly_spec']}",  # type: ignore
  #one_containing="",
  # By default arg_shape = (3, 4, 12) for both the outer and the inner function.
  # The inner function is exported for f32.
  kwargs=[
    # Both inner and outer are static shapes
    dict(inner_poly_spec="3,4,12", outer_poly_spec="3,4,12"),
    # Inner has poly shapes but outer has static shapes. When we call inner
    # we do the shape constraint checking
    dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,12"),
    dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,12"),
    dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,12",
         expect_error_outer_exp=re.escape(
           "Found inconsistency between dimension size "
           "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
    dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,12",
         expect_error_outer_exp=re.escape(
           "Division had remainder 2 when computing the value of 'a'")),
    dict(inner_poly_spec="3,4,12+a", outer_poly_spec="3,4,12",
         expect_error_outer_exp=re.escape(
           "Expected value >= 1 for dimension variable 'a'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (3, 4, a + 12). "
           "Obtained dimension variables: 'a' = 0 from specification "
           "'a + 12' for dimension args[0].shape[2] (= 12)")),
    # Both inner and outer have poly shapes.
    dict(inner_poly_spec="3,a,b", outer_poly_spec="3,4,c"),
    dict(inner_poly_spec="3,4,3*a", outer_poly_spec="3,4,6*c"),
    dict(inner_poly_spec="3,a,a+8", outer_poly_spec="3,c+2,c+10"),
    dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,4,c",
         expect_error_outer_exp=re.escape(
           "Expected value >= 1 for dimension variable 'b'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
           "Obtained dimension variables: 'a' = 4 from specification "
           "'a' for dimension args[0].shape[1] (= 4), "
           "'b' = c - 4 from specification 'b + a' for dimension args[0].shape[2] (= c),")),
    dict(inner_poly_spec="3,a,a", outer_poly_spec="3,4,c",
         expect_error_outer_exp=re.escape(
           "Found inconsistency between dimension size "
           "args[0].shape[2] (= c) and the specification 'a' (= 4)")),
    dict(inner_poly_spec="3,a,a", arg_shape=(3, 4),
         outer_poly_spec="3,c",
         expect_error_outer_exp=r"Rank mismatch for args\[0\]"),
    dict(inner_poly_spec="3,a,a+b", arg_dtype=np.int32,
         outer_poly_spec="3,c,d",
         expect_error_outer_exp=r"Dtype mismatch for args\[0\]"),
    dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,c",
         expect_error_outer_exp=re.escape(
           "Division had remainder mod(c, 5) when computing the value of 'a'")),
    dict(inner_poly_spec="3,a,a+b", outer_poly_spec="3,c,c",
         expect_error_outer_exp=re.escape(
           "Expected value >= 1 for dimension variable 'b'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (3, a, b + a). "
           "Obtained dimension variables: 'a' = c from "
           "specification 'a' for dimension args[0].shape[1] (= c), "
           "'b' = 0 from specification 'b + a' for dimension args[0].shape[2] (= c)")),
    dict(inner_poly_spec="3,a,a+b", outer_poly_spec="c,4,12",
         expect_error_outer_exp=re.escape(
           "Shape mismatch for args[0].shape[0] (expected same constant)")),
    dict(inner_poly_spec="3,4,5*a", outer_poly_spec="3,4,25*c",
         expect_error_run=re.escape(
           "Division had remainder 12 when computing the value of 'c'")),
    dict(inner_poly_spec="3,a,b", outer_poly_spec="3,c+4,12",
         expect_error_run=re.escape(
           "Expected value >= 1 for dimension variable 'c'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (3, c + 4, 12). "
           "Obtained dimension variables: 'c' = 0")),
    dict(inner_poly_spec="3,a,a", outer_poly_spec="3,a,a",
         expect_error_run=re.escape(
           "Found inconsistency between dimension size "
           "args[0].shape[2] (= 12) and the specification 'a' (= 4)")),
  ])
def test_poly_shape_checks_nested(
    self, inner_poly_spec="3,4,5*a",
    arg_shape=(3, 4, 12), arg_dtype=np.float32,
    outer_poly_spec="3,4,25*c",
    expect_error_outer_exp=None,
    expect_error_run=None):
  """Calls an exported polymorphic function from inside another export.

  `inner` is exported for float32 with `inner_poly_spec`. `outer` calls the
  exported `inner` (and `inner` directly) and is itself exported with
  `outer_poly_spec` and the dtype of the test argument.

  Args:
    inner_poly_spec: symbolic shape used to export `inner`.
    arg_shape: shape of the concrete test argument.
    arg_dtype: dtype of the concrete test argument.
    outer_poly_spec: symbolic shape used to export `outer`.
    expect_error_outer_exp: if given, a regex for the `ValueError` expected
      while exporting `outer`.
    expect_error_run: if given, a regex for the exception expected while
      calling the exported `outer` on the concrete argument.
  """
  # Polymorphic export called with static or polymorphic shapes
  def inner(x):  # x: inner_poly_spec
    return jnp.reshape(x, (-1, x.shape[1]))

  arg = np.arange(np.prod(arg_shape),
                  dtype=arg_dtype).reshape(arg_shape)  # arg: arg_dtype[arg_shape]
  inner_exp = get_exported(jax.jit(inner))(
      jax.ShapeDtypeStruct(export.symbolic_shape(inner_poly_spec), np.float32))
  self.assertEqual(inner_exp.uses_global_constants,
                   (inner_poly_spec != "3,4,12"))

  def outer(x):  # x: outer_poly_spec
    # Use an addition to test that the shapes are refined properly for the
    # result of the call_exported.
    return inner_exp.call(x) + inner(x)

  with contextlib.ExitStack() as stack:
    if expect_error_outer_exp is not None:
      # enter_context runs the full context-manager protocol (__enter__ and
      # __exit__); push would register only __exit__.
      stack.enter_context(
          self.assertRaisesRegex(ValueError, expect_error_outer_exp))
    # Call it after exporting again, with polymorphic shapes
    outer_exp = get_exported(jax.jit(outer))(
        jax.ShapeDtypeStruct(export.symbolic_shape(outer_poly_spec), arg.dtype))
  if expect_error_outer_exp is not None:
    return

  self.assertEqual(outer_exp.uses_global_constants,
                   (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))

  with contextlib.ExitStack() as stack:
    if expect_error_run is not None:
      stack.enter_context(self.assertRaisesRegex(Exception, expect_error_run))
    res = outer_exp.call(arg)
  if expect_error_run is not None:
    return
  # outer adds the exported inner to a direct call of inner, so the result
  # is expected to be twice inner(arg).
  self.assertAllClose(2. * inner(arg), res)
# Tests details of the shape constraints errors
# This test exists also in shape_poly_test.py. Here we test the
# call_exported error reporting.
@jtu.parameterized_filterable(
  testcase_name=lambda kw: kw["shape"],  # assume "shape" is unique
  kwargs=[
    dict(shape=(8, 2, 9),  # a = 2, b = 3, c = 4
         poly_spec="(a + 2*b, a, a + b + c)"),
    dict(shape=(2, 2, 6),  # a = 2, b = 0, c = 4
         poly_spec="(a + 2*b, a, a + b + c)",
         expect_error=(
           "Input shapes do not match the polymorphic shapes specification. "
           "Expected value >= 1 for dimension variable 'b'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
           "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
           "'b' = 0 from specification '2*b + a' for dimension args[0].shape[0] (= 2), . "
           "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
         )),
    dict(shape=(3, 2, 6),  # a = 2, b = 0.5, c = 3.5 - b is not integer
         poly_spec="(a + 2*b, a, a + b + c)",
         expect_error=(
           "Input shapes do not match the polymorphic shapes specification. "
           "Division had remainder 1 when computing the value of 'b'. "
           "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, c + b + a). "
           "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), . "
           "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
         )),
    dict(shape=(8, 2, 6),  # a = 2; dim 0 implies b = 3, dim 2 implies b = 4 - inconsistency
         poly_spec="(a + 2*b, a, a + b)",
         expect_error=(
           "Input shapes do not match the polymorphic shapes specification. "
           "Found inconsistency between dimension size args[0].shape[0] (= 8) and the specification '2*b + a' (= 10). "
           "Using the following polymorphic shapes specifications: args[0].shape = (2*b + a, a, b + a). "
           "Obtained dimension variables: 'a' = 2 from specification 'a' for dimension args[0].shape[1] (= 2), "
           "'b' = 4 from specification 'b + a' for dimension args[0].shape[2] (= 6), . "
           "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-assertion-errors for more details."
         )),
    dict(shape=(7, 2, 36),  # a = 2, b = 3, c = 6 - cannot solve c
         poly_spec="(2 * a + b, a, c * c)",
         expect_error=(
           "Cannot solve for values of dimension variables {'c'}. "
           "We can only solve linear uni-variate constraints. "
           "Using the following polymorphic shapes specifications: args[0].shape = (b + 2*a, a, c^2). "
           "Unprocessed specifications: 'c^2' for dimension size args[0].shape[2]. "
           "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details."
         )),
  ])
def test_shape_constraints_errors(self, *,
                                  shape, poly_spec: str,
                                  expect_error: str | None = None):
  """Checks the full text of shape-constraint errors from export/call.

  A trivial function is exported with the symbolic shape `poly_spec` and
  called on a float32 array of the concrete `shape`. If `expect_error` is
  given, it is a literal (unescaped) message that must appear in the
  exception raised by either the export or the call.
  """
  def f_jax(x):  # x: f32[poly_spec]
    return 0.

  x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
  with contextlib.ExitStack() as stack:
    if expect_error is not None:
      # enter_context runs the full context-manager protocol (__enter__ and
      # __exit__); push would register only __exit__. The expected message
      # is literal text, so it is escaped before regex matching.
      stack.enter_context(
          self.assertRaisesRegex(Exception, re.escape(expect_error)))
    exp = get_exported(jax.jit(f_jax))(
        jax.ShapeDtypeStruct(export.symbolic_shape(poly_spec), x.dtype))
    exp.call(x)
def test_poly_booleans(self):
# For booleans we use a special case ConvertOp to cast to and from
# dynamic shapes arguments.
@jax.jit
def f_jax(x): # x: bool[b]
return jnp.logical_not(x)
x = np.array([True, False, True, False], dtype=np.bool_)