forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
debug_nans_test.py
237 lines (191 loc) · 6.96 KB
/
debug_nans_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for --debug_nans."""
from absl.testing import absltest, parameterized
import jax
import numpy as np
from unittest import SkipTest
from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
import jax._src.lib
from jax.config import config
# Called once at import time, before the test classes below are defined.
# Presumably this lets absl command-line flags (e.g. --jax_debug_nans, per the
# module docstring) populate JAX's config; confirm against jax.config.
config.parse_flags_with_absl()
class DebugNaNsTest(jtu.JaxTestCase):
  """Tests for the `jax_debug_nans` config option.

  `setUp` turns `jax_debug_nans` on for every test and `tearDown` restores the
  value it had before, so individual tests may toggle the option freely.
  """

  def setUp(self):
    # Chain to JaxTestCase.setUp so whatever per-test setup the base class
    # performs is not skipped by this override.
    super().setUp()
    self.cfg = config._read("jax_debug_nans")
    config.update("jax_debug_nans", True)

  def tearDown(self):
    # Undo our config change first, then let the base class tear down
    # (reverse order of setUp).
    config.update("jax_debug_nans", self.cfg)
    super().tearDown()

  def testSinc(self):
    # Regression test for #6936
    self.assertEqual(jnp.sinc(0.0), 1.0)

  def testSingleResultPrimitiveNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jnp.tanh(A)
    ans.block_until_ready()

  def testMultipleResultPrimitiveNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans, _ = jnp.linalg.eigh(A)
    ans.block_until_ready()

  def testJitComputationNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jax.jit(jnp.tanh)(A)
    ans.block_until_ready()

  def testJitComputationNaN(self):
    A = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = jax.jit(lambda x: 0. / x)(A)
      ans.block_until_ready()

  def testJitComputationNaNContextManager(self):
    # Turn the global flag off (tearDown restores the saved value) and call f
    # twice without NaN checking, presumably so the cached dispatch path is
    # warm; the context manager must still catch the NaN afterwards.
    config.update("jax_debug_nans", False)
    A = jnp.array(0.)
    f = jax.jit(lambda x: 0. / x)
    f(A)
    f(A)
    with self.assertRaises(FloatingPointError):
      with jax.debug_nans(True):
        ans = f(A)
      ans.block_until_ready()

  def testSingleResultPrimitiveNaN(self):
    A = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = 0. / A
      ans.block_until_ready()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit._name}", "jit": jit}
      for jit in jtu.JIT_IMPLEMENTATION))
  def testCallDeoptimized(self, jit):
    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.nan, lambda _: 2., operand=None)

    # This makes sure, when using the C++ jit, that the Python code has been
    # run to compile, and the next call won't go through `cache_miss`.
    f(2)
    # 'cond' not 'xla_call'
    msg = r"invalid value \(nan\) encountered in cond"
    with self.assertRaisesRegex(FloatingPointError, msg):
      f(1)

  def testPmap(self):
    nan_msg = r"invalid value \(nan\) encountered in parallel computation"
    pmap_funcs = [api._python_pmap, api._cpp_pmap]

    for pmap in pmap_funcs:
      # Report each pmap implementation separately instead of stopping at the
      # first failing one.
      with self.subTest(pmap=pmap):
        f = pmap(lambda x: 0. / x)
        # For the Cpp pmap, the first execution always goes through Python.
        f(jnp.array([1.]))

        with self.assertRaisesRegex(FloatingPointError, nan_msg):
          ans = f(jnp.array([0.]))
          ans.block_until_ready()

        # A 2-element pmap needs two *local* devices (see the jax.pmap docs),
        # so gate on the local count rather than the global device count.
        if jax.local_device_count() >= 2:
          with self.assertRaisesRegex(FloatingPointError, nan_msg):
            ans = f(jnp.array([1., 0.]))
            ans.block_until_ready()

  def testPmapNoNaN(self):
    ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
    ans.block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testXmap(self):
    f = jax.experimental.maps.xmap(
        lambda x: 0. / x,
        in_axes=['i'],
        out_axes=['i'],
        axis_resources={'i': 'x'})

    with jax.experimental.maps.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
      with self.assertRaisesRegex(
          FloatingPointError,
          r"invalid value \(nan\) encountered in parallel computation"):
        ans = f(jnp.array([0.]))
        ans.block_until_ready()

    # The mesh below is built from local devices, so check the local count;
    # otherwise a multi-process run could silently get a 1-device mesh.
    if jax.local_device_count() >= 2:
      with jax.experimental.maps.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
        with self.assertRaises(FloatingPointError):
          ans = f(jnp.array([1., 0.]))
          ans.block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testPjit(self):
    # The mesh is built from jax.local_devices()[:2], so require two local
    # devices rather than two devices globally.
    if jax.local_device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    p = jax.experimental.PartitionSpec('x')
    f = pjit.pjit(lambda x: 0. / x,
                  in_axis_resources=p,
                  out_axis_resources=p)

    with jax.experimental.maps.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
      with self.assertRaises(FloatingPointError):
        ans = f(jnp.array([0., 1.]))
        ans.block_until_ready()
# TODO(skye): add parallel inf tests, ideally by factoring out test logic
class DebugInfsTest(jtu.JaxTestCase):
  """Tests for the `jax_debug_infs` config option.

  `setUp` turns `jax_debug_infs` on for every test and `tearDown` restores the
  value it had before. The last two tests enable `jax.debug_nans` explicitly
  and exercise that option instead.
  """

  def setUp(self):
    # Chain to JaxTestCase.setUp so whatever per-test setup the base class
    # performs is not skipped by this override.
    super().setUp()
    self.cfg = config._read("jax_debug_infs")
    config.update("jax_debug_infs", True)

  def tearDown(self):
    # Undo our config change first, then let the base class tear down
    # (reverse order of setUp).
    config.update("jax_debug_infs", self.cfg)
    super().tearDown()

  def testSingleResultPrimitiveNoInf(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jnp.tanh(A)
    ans.block_until_ready()

  def testMultipleResultPrimitiveNoInf(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans, _ = jnp.linalg.eigh(A)
    ans.block_until_ready()

  def testJitComputationNoInf(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jax.jit(jnp.tanh)(A)
    ans.block_until_ready()

  def testSingleResultPrimitiveInf(self):
    A = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = 1. / A
      ans.block_until_ready()

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit._name}", "jit": jit}
      for jit in jtu.JIT_IMPLEMENTATION))
  def testCallDeoptimized(self, jit):
    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.inf, lambda _: 2., operand=None)

    # This makes sure, when using the C++ jit, that the Python code has been
    # run to compile, and the next call won't go through `cache_miss`.
    f(2)
    # 'cond' not 'xla_call'
    msg = r"invalid value \(inf\) encountered in cond"
    with self.assertRaisesRegex(FloatingPointError, msg):
      f(1)

  # NOTE(review): the two tests below exercise `jax.debug_nans`, not
  # `jax_debug_infs`; consider moving them to DebugNaNsTest.
  def testDebugNansDoesntCorruptCaches(self):
    # https://github.com/google/jax/issues/6614
    @jax.jit
    def f(x):
      return jnp.divide(x, x)

    # f(0.) is 0/0, so each pass is expected to report a NaN via
    # FloatingPointError, which is deliberately ignored. Any *other* exception
    # (presumably how a corrupted cache would surface on the second pass)
    # propagates and fails the test.
    for _ in range(2):
      try:
        with jax.debug_nans(True):
          jax.grad(f)(0.)
      except FloatingPointError:
        pass

  def testDebugNansDoesntReturnDeoptimizedResult(self):
    @jax.jit
    def f(x):
      x + 2  # avoid trivial dispatch path by adding some eqn
      return jnp.nan

    with self.assertRaisesRegex(FloatingPointError, "de-optimized"):
      with jax.debug_nans(True):
        f(3)
if __name__ == '__main__':
  # Run this module's tests through absltest using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)