# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from absl.testing import absltest
import jax.numpy as jnp
from jax.tools import jax_to_ir
from jax._src import test_util as jtu

try:
  import tensorflow as tf
except ImportError:
  tf = None  # type: ignore


def axpy(a, x, y):
  return a * x + y[:, jnp.newaxis]
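
# Shape note (this follows from the argument shapes used in the tests below,
# not from any assumption about jax_to_ir itself): with a: f32[], x: f32[128,2]
# and y: f32[128], the slice y[:, jnp.newaxis] has shape (128, 1) and
# broadcasts against x, so the lowered computation contains a broadcast
# alongside the multiply and add that the HLO and TF tests assert on, and the
# result has shape f32[128,2].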


class JaxToIRTest(absltest.TestCase):

  def test_jax_to_hlo_axpy(self):
    hlo_proto, hlo_text = jax_to_ir.jax_to_hlo(axpy, [
        ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ('a', jax_to_ir.parse_shape_str('f32[]')),
        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
    ])

    # Check that hlo_text contains a broadcast, add, and multiply.
    self.assertIn('broadcast', hlo_text)
    self.assertIn('add', hlo_text)
    self.assertIn('multiply', hlo_text)

    # Check that the HLO parameters are in the order we specified in the
    # jax_to_hlo call.
    self.assertIn('f32[128]{0} parameter(0)', hlo_text)
    self.assertIn('f32[] parameter(1)', hlo_text)
    self.assertIn('f32[128,2]{1,0} parameter(2)', hlo_text)

    # TODO(jlebar): Ideally we'd check that hlo_proto can be deserialized to a
    # valid HLO proto, but we don't seem to have access to hlo_pb2 at the
    # moment, so the best we seem to be able to do is check that it's nonempty.
    assert hlo_proto
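
  # For reference, a minimal sketch of the same call pattern outside the test
  # harness, kept as a comment. It assumes only the jax_to_hlo/parse_shape_str
  # signatures exercised above, that the first return value is the serialized
  # proto (as the non-empty check suggests), and '/tmp/axpy.hlo.pb' is a
  # hypothetical output path:
  #
  #   proto, text = jax_to_ir.jax_to_hlo(
  #       axpy, [('y', jax_to_ir.parse_shape_str('f32[128]')),
  #              ('a', jax_to_ir.parse_shape_str('f32[]')),
  #              ('x', jax_to_ir.parse_shape_str('f32[128,2]'))])
  #   with open('/tmp/axpy.hlo.pb', 'wb') as f:
  #     f.write(proto)  # `text` holds the human-readable HLO dump.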

  def test_jax_to_hlo_with_constants(self):
    def fn(a, b, x, y):
      return a / b * x + y

    _, hlo_text = jax_to_ir.jax_to_hlo(
        fn,
        input_shapes=[
            ('x', jax_to_ir.parse_shape_str('f32[128]')),
            ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ],
        constants={
            'a': 123456,
            'b': 4,
        })

    # Because we passed `a` and `b` as constants, they get constant-folded away
    # by Python/JAX to a/b = 30864.
    self.assertIn('constant(30864)', hlo_text)
    self.assertNotIn('123456', hlo_text)

  def test_parse_shape_str_invalid(self):
    with self.assertRaisesRegex(ValueError, 'Invalid shape.*foo'):
      jax_to_ir.parse_shape_str('foo[]')

  @unittest.skipIf(tf is None, 'TensorFlow not installed.')
  @jtu.ignore_warning(
      category=UserWarning,
      message='jax2tf.convert with native_serialization=False is deprecated.')
  def test_jax_to_tf_axpy(self):
    tf_proto, tf_text = jax_to_ir.jax_to_tf(axpy, [
        ('y', jax_to_ir.parse_shape_str('f32[128]')),
        ('a', jax_to_ir.parse_shape_str('f32[]')),
        ('x', jax_to_ir.parse_shape_str('f32[128,2]')),
    ])

    # Check that the TF debug text contains a broadcast, add, and multiply.
    self.assertIn('BroadcastTo', tf_text)
    self.assertIn('AddV2', tf_text)
    self.assertIn('Mul', tf_text)

    # Check that we can re-import our GraphDef.
    gdef = tf.compat.v1.GraphDef()
    gdef.ParseFromString(tf_proto)
    g = tf.Graph()
    with g.as_default():
      tf.import_graph_def(gdef, name='')

    # Check that the graph inputs and output are named as we specified.
    ops = {o.name: o for o in g.get_operations()
           if o.name in ('y', 'a', 'x', 'jax2tf_out')}
    self.assertLen(ops, 4)
    self.assertIdentityOp(ops['y'], [128], jnp.float32)
    self.assertIdentityOp(ops['a'], [], jnp.float32)
    self.assertIdentityOp(ops['x'], [128, 2], jnp.float32)
    self.assertIdentityOp(ops['jax2tf_out'], [128, 2], jnp.float32)

  def assertIdentityOp(self, op, expected_shape, expected_dtype):
    self.assertEqual(op.type, 'Identity')
    output, = op.outputs
    self.assertEqual(output.shape, expected_shape)
    self.assertEqual(output.dtype, expected_dtype)

  def test_parse_shape_str(self):
    self.assertParsedShape('f32[]', [], jnp.float32)
    self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
    self.assertParsedShape('pred[1]', [1], jnp.bool_)
    if hasattr(jnp, 'int2'):
      self.assertParsedShape('s2[1]', [1], jnp.int2)
    self.assertParsedShape('s4[1]', [1], jnp.int4)
    self.assertParsedShape('s8[1]', [1], jnp.int8)
    self.assertParsedShape('s16[1]', [1], jnp.int16)
    self.assertParsedShape('s32[1]', [1], jnp.int32)
    self.assertParsedShape('s64[1]', [1], jnp.int64)
    if hasattr(jnp, 'uint2'):
      self.assertParsedShape('u2[1]', [1], jnp.uint2)
    self.assertParsedShape('u4[1]', [1], jnp.uint4)
    self.assertParsedShape('u8[1]', [1], jnp.uint8)
    self.assertParsedShape('u16[1]', [1], jnp.uint16)
    self.assertParsedShape('u32[1]', [1], jnp.uint32)
    self.assertParsedShape('u64[1]', [1], jnp.uint64)
    self.assertParsedShape('f16[1]', [1], jnp.float16)
    self.assertParsedShape('f32[1]', [1], jnp.float32)
    self.assertParsedShape('f64[1]', [1], jnp.float64)
    self.assertParsedShape('bf16[1]', [1], jnp.bfloat16)
    self.assertParsedShape('c64[1]', [1], jnp.complex64)
    self.assertParsedShape('c128[1]', [1], jnp.complex128)

  def assertParsedShape(self, s: str, expected_shape, expected_dtype):
    p = jax_to_ir.parse_shape_str(s)
    self.assertEqual(p.shape, tuple(expected_shape))
    self.assertEqual(p.dtype, expected_dtype)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())