forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
debugger_test.py
353 lines (308 loc) · 10.1 KB
/
debugger_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import re
import textwrap
import unittest
from typing import IO, Sequence, Tuple
from absl.testing import absltest
import jax
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import debugger
from jax._src import lib as jaxlib
from jax._src import test_util as jtu
import jax.numpy as jnp
import numpy as np
# Let absl command-line flags populate JAX's config options for this test run.
config.parse_flags_with_absl()
def make_fake_stdin_stdout(
    commands: Sequence[str]) -> Tuple[IO[str], io.StringIO]:
  """Build an in-memory stdin/stdout pair for scripting a debugger session.

  Args:
    commands: Debugger commands in the order they should be read; each one is
      written as its own newline-terminated line.

  Returns:
    A ``(stdin, stdout)`` tuple. ``stdin`` contains the scripted commands and
    is positioned at the start so it can be read immediately; ``stdout`` is an
    empty buffer that captures whatever the debugger prints.
  """
  script = "".join(command + "\n" for command in commands)
  # A StringIO created from an initial value starts at position 0, so no
  # explicit seek is needed before handing it to the debugger.
  return io.StringIO(script), io.StringIO()
def _format_multiline(text):
return textwrap.dedent(text).lstrip()
# Restore callback returned by `jtu.set_host_platform_device_count`; set in
# setUpModule and invoked in tearDownModule.
prev_xla_flags = None


def setUpModule():
  """Configure two host-platform (CPU) devices for every test in this module."""
  global prev_xla_flags
  # Only the host-platform (CPU) device count is controlled here; the return
  # value is kept so the previous configuration can be restored afterwards.
  prev_xla_flags = jtu.set_host_platform_device_count(2)


def tearDownModule():
  """Undo setUpModule's device-count change.

  Resetting matters in case other test modules run in the same process.
  """
  restore_flags = prev_xla_flags
  restore_flags()
# TODO(sharadmv): remove jaxlib guards for TPU tests when jaxlib minimum
# version is >= 0.3.15
# Backends on which every debugger test below is skipped. TPU is excluded for
# jaxlib releases older than 0.3.15 (presumably they lack the support these
# tests rely on there).
disabled_backends = [] if jaxlib.version >= (0, 3, 15) else ["tpu"]
class CliDebuggerTest(jtu.JaxTestCase):
  """End-to-end tests for the command-line (``backend="cli"``) debugger.

  Each test scripts a debugger session by feeding commands through an
  in-memory stdin (built by ``make_fake_stdin_stdout``). It then compares
  what the debugger wrote to the in-memory stdout against an expected
  transcript. ``jax.effects_barrier()`` is called before stdout is inspected,
  presumably so that breakpoint callbacks issued from compiled code have
  finished running.
  """

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_eof(self):
    """Running out of scripted input inside the debugger raises SystemExit."""
    stdin, stdout = make_fake_stdin_stdout([])

    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    with self.assertRaises(SystemExit):
      f(2.)
      jax.effects_barrier()

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_continue(self):
    """`c` leaves the debugger after only the banner and one prompt."""
    stdin, stdout = make_fake_stdin_stdout(["c"])

    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    f(2.)
    jax.effects_barrier()
    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) """)
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_print_value(self):
    """Outside of jit, `p x` prints the argument as a DeviceArray."""
    stdin, stdout = make_fake_stdin_stdout(["p x", "c"])

    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) DeviceArray(2., dtype=float32)
    (jdb) """)
    f(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_print_value_in_jit(self):
    """Inside jit, `p x` prints the argument as a NumPy array."""
    stdin, stdout = make_fake_stdin_stdout(["p x", "c"])

    @jax.jit
    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) array(2., dtype=float32)
    (jdb) """)
    f(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_print_multiple_values(self):
    """`p x, y` prints a tuple of both local values."""
    stdin, stdout = make_fake_stdin_stdout(["p x, y", "c"])

    @jax.jit
    def f(x):
      y = x + 1.
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) (array(2., dtype=float32), array(3., dtype=float32))
    (jdb) """)
    f(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_print_context(self):
    """`l` lists the source around the breakpoint, marking the current line."""
    stdin, stdout = make_fake_stdin_stdout(["l", "c"])

    @jax.jit
    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    f(2.)
    jax.effects_barrier()
    # Listed source lines carry this file's own indentation, so each one is
    # matched with a leading `\s*` rather than assuming a fixed indent depth.
    expected = _format_multiline(r"""
    Entering jdb:
    \(jdb\) > .*debugger_test\.py\([0-9]+\)
    \s*@jax\.jit
    \s*def f\(x\):
    \s*y = jnp\.sin\(x\)
    ->\s*debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
    \s*return y
    .*
    \(jdb\) """)
    self.assertRegex(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_print_backtrace(self):
    """`bt` prints a traceback."""
    stdin, stdout = make_fake_stdin_stdout(["bt", "c"])

    @jax.jit
    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    expected = _format_multiline(r"""
    Entering jdb:.*
    \(jdb\) Traceback:.*
    """)
    f(2.)
    jax.effects_barrier()
    self.assertRegex(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_can_work_with_multiple_stack_frames(self):
    """`u`/`d` move between frames, and `l`/`p` operate on the selected one."""
    stdin, stdout = make_fake_stdin_stdout(["l", "u", "p x", "d", "c"])

    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    @jax.jit
    def g(x):
      y = f(x)
      return jnp.exp(y)

    # As in test_debugger_can_print_context, listed source lines are matched
    # with a leading `\s*` so the pattern is independent of indentation depth.
    expected = _format_multiline(r"""
    Entering jdb:
    \(jdb\) > .*debugger_test\.py\([0-9]+\)
    \s*def f\(x\):
    \s*y = jnp\.sin\(x\)
    ->\s*debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
    \s*return y
    .*
    \(jdb\) > .*debugger_test\.py\([0-9]+\).*
    \s*@jax\.jit
    \s*def g\(x\):
    ->\s*y = f\(x\)
    \s*return jnp\.exp\(y\)
    .*
    \(jdb\) array\(2\., dtype=float32\)
    \(jdb\) > .*debugger_test\.py\([0-9]+\)
    \s*def f\(x\):
    \s*y = jnp\.sin\(x\)
    ->\s*debugger\.breakpoint\(stdin=stdin, stdout=stdout, backend="cli"\)
    \s*return y
    .*
    \(jdb\) """)
    g(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertRegex(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_can_use_multiple_breakpoints(self):
    """Two ordered breakpoints (inner `f`, then outer `g`) each open a session."""
    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])

    def f(x):
      y = x + 1.
      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
                          backend="cli")
      return y

    @jax.jit
    def g(x):
      y = f(x) * 2.
      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True,
                          backend="cli")
      return jnp.exp(y)

    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) array(3., dtype=float32)
    (jdb) Entering jdb:
    (jdb) array(6., dtype=float32)
    (jdb) """)
    g(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_works_with_vmap(self):
    """Under vmap the breakpoint opens one session per batch element."""
    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])
    # On TPU, the breakpoints can be reordered inside of vmap but can be fixed
    # by ordering sends.
    # TODO(sharadmv): change back to ordered = False when sends are ordered
    ordered = jax.default_backend() == "tpu"

    def f(x):
      y = x + 1.
      debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=ordered,
                          backend="cli")
      return 2. * y

    @jax.jit
    @jax.vmap
    def g(x):
      y = f(x)
      return jnp.exp(y)

    expected = _format_multiline(r"""
    Entering jdb:
    (jdb) array(1., dtype=float32)
    (jdb) Entering jdb:
    (jdb) array(2., dtype=float32)
    (jdb) """)
    g(jnp.arange(2., dtype=jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_works_with_pmap(self):
    """Under pmap the breakpoint opens one session per device (needs >= 2)."""
    if jax.local_device_count() < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")

    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])

    def f(x):
      y = jnp.sin(x)
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    @jax.pmap
    def g(x):
      y = f(x)
      return jnp.exp(y)

    # The per-device values are not pinned; the devices may report in any
    # order.
    expected = _format_multiline(r"""
    Entering jdb:
    \(jdb\) array\(.*, dtype=float32\)
    \(jdb\) Entering jdb:
    \(jdb\) array\(.*, dtype=float32\)
    \(jdb\) """)
    g(jnp.arange(2., dtype=jnp.float32))
    jax.effects_barrier()
    self.assertRegex(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_works_with_pjit(self):
    """Under pjit (TPU only) a single session shows the full sharded value."""
    if jax.default_backend() != "tpu":
      raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")

    stdin, stdout = make_fake_stdin_stdout(["p y", "c"])

    def f(x):
      y = x + 1
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      return y

    def g(x):
      y = f(x)
      return jnp.exp(y)

    g = pjit.pjit(g, in_axis_resources=pjit.PartitionSpec("dev"),
                  out_axis_resources=pjit.PartitionSpec("dev"))
    with maps.Mesh(np.array(jax.devices()), ["dev"]):
      # `y` is `x + 1` for x = arange(8); its repr is escaped for regex use.
      arr = (1 + np.arange(8)).astype(np.int32)
      expected = _format_multiline(r"""
      Entering jdb:
      \(jdb\) {}
      \(jdb\) """.format(re.escape(repr(arr))))
      g(jnp.arange(8, dtype=jnp.int32))
      jax.effects_barrier()
      self.assertRegex(stdout.getvalue(), expected)

  @jtu.skip_on_devices(*disabled_backends)
  def test_debugger_uses_local_before_global_scope(self):
    """`p foo` resolves the breakpoint frame's own local `foo`."""
    stdin, stdout = make_fake_stdin_stdout(["p foo", "c"])

    foo = "outer"

    def f(x):
      foo = "inner"
      debugger.breakpoint(stdin=stdin, stdout=stdout, backend="cli")
      del foo
      return x

    del foo
    expected = _format_multiline(r"""
    Entering jdb:
    \(jdb\) 'inner'
    \(jdb\) """)
    f(2.)
    jax.effects_barrier()
    self.assertRegex(stdout.getvalue(), expected)
if __name__ == '__main__':
  # Run this module's tests through absltest, using JAX's custom test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())