forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
logging_test.py
315 lines (268 loc) · 11.3 KB
/
logging_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import io
import logging
import platform
import re
import shlex
import subprocess
import sys
import tempfile
import textwrap
import unittest
import jax
import jax._src.test_util as jtu
from jax._src import xla_bridge
from jax._src.logging_config import _default_TF_CPP_MIN_LOG_LEVEL
# Note: importing absltest causes an extra absl root log handler to be
# registered, which causes extra debug log messages. We don't expect users to
# import absl logging, so it should only affect this test. We need to use
# absltest.main and config.parse_flags_with_absl() in order for jax_test flag
# parsing to work correctly with bazel (otherwise we could avoid importing
# absltest/absl logging altogether).
from absl.testing import absltest
# Let absl parse the command line into jax.config at import time; together with
# absltest.main this is what makes jax_test flag parsing work under bazel.
jax.config.parse_flags_with_absl()
@contextlib.contextmanager
def jax_debug_log_modules(value):
  """Temporarily sets the `jax_debug_log_modules` config option to `value`.

  The setting that was active on entry is written back when the `with` block
  exits, whether or not the body raised.

  The option has no context manager of its own because it is not thread-safe;
  tests are single-threaded, so providing one here is fine.
  """
  previous = jax.config.jax_debug_log_modules
  with contextlib.ExitStack() as restore:
    jax.config.update("jax_debug_log_modules", value)
    # Registered only after the update succeeded, so a failing update leaves
    # nothing to undo.
    restore.callback(jax.config.update, "jax_debug_log_modules", previous)
    yield
@contextlib.contextmanager
def jax_logging_level(value):
# jax_logging_level doesn't have a context manager, because it's
# not thread-safe. But since tests are always single-threaded, we
# can define one here.
original_value = jax.config.jax_logging_level
jax.config.update("jax_logging_level", value)
try:
yield
finally:
jax.config.update("jax_logging_level", original_value)
@contextlib.contextmanager
def capture_jax_logs():
  """Captures what is logged through the "jax" logger while the block runs.

  Yields:
    An `io.StringIO` fed by a temporary `logging.StreamHandler` attached to
    the "jax" logger. The handler is detached again when the block exits,
    whether or not the body raised.
  """
  jax_logger = logging.getLogger("jax")
  buffer = io.StringIO()
  stream_handler = logging.StreamHandler(buffer)
  with contextlib.ExitStack() as cleanup:
    jax_logger.addHandler(stream_handler)
    cleanup.callback(jax_logger.removeHandler, stream_handler)
    yield buffer
class LoggingTest(jtu.JaxTestCase):
  """Tests for JAX's logging configuration."""

  def _run_python(self, program, *env_assignments):
    """Runs `program` with `python -c` in a fresh interpreter.

    Args:
      program: Python source text. Leading whitespace is stripped from every
        line first, so callers may indent the text to match the surrounding
        code (and therefore cannot use indented blocks inside it).
      *env_assignments: `NAME=value` strings applied to the child's
        environment through the `env` command.

    Returns:
      The `subprocess.CompletedProcess`, with stdout and stderr captured as
      text.
    """
    # sys.executable may be None or an empty string when Python cannot
    # determine the path of its own binary.
    if not sys.executable:
      self.skipTest("test requires access to python binary")
    script = re.sub(r"^\s+", "", program, flags=re.MULTILINE)
    # Pass an argument list rather than a quoted command string, so neither
    # the interpreter path nor the script has to survive shell-style quoting.
    cmd = [sys.executable, "-c", script]
    if env_assignments:
      cmd = ["env", *env_assignments, *cmd]
    return subprocess.run(cmd, capture_output=True, text=True)

  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_no_log_spam(self):
    """Checks that a basic JAX script prints nothing unexpected.

    A small script is run in a fresh interpreter; every line it writes to
    stdout or stderr must be on the allowlist.
    """
    if jtu.is_cloud_tpu() and xla_bridge._backends:
      self.skipTest(
          "test requires fresh process on Cloud TPU because only one process "
          "can use the TPU at a time")
    if not sys.executable:
      self.skipTest("test requires access to python binary")

    # Save script in file to fix the problem with
    # `tsl::Env::Default()->GetExecutablePath()` not working properly with
    # command flag.
    with tempfile.NamedTemporaryFile(
        mode="w+", encoding="utf-8", suffix=".py"
    ) as f:
      f.write(textwrap.dedent("""
        import jax
        jax.device_count()
        f = jax.jit(lambda x: x + 1)
        f(1)
        f(2)
        jax.numpy.add(1, 1)
      """))
      # Push the buffered text to disk before the child opens the file by
      # name; otherwise it may read an empty script and the test would pass
      # vacuously.
      f.flush()

      python = sys.executable
      self.assertIn("python", python)
      # Make sure C++ logging is at default level for the test process.
      proc = subprocess.run([python, f.name], capture_output=True)

      lines = proc.stdout.split(b"\n")
      lines.extend(proc.stderr.split(b"\n"))
      allowlist = [
          b"",
          (
              b"An NVIDIA GPU may be present on this machine, but a"
              b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
          ),
      ]
      unexpected = [line for line in lines if line not in allowlist]
      self.assertEmpty(unexpected)

  def test_debug_logging(self):
    """Checks that `jax_debug_log_modules` turns debug logging on and off."""
    # Warmup so we don't get "No GPU/TPU" warning later.
    jax.jit(lambda x: x + 1)(1)

    # Nothing logged by default (except warning messages, which we don't expect
    # here).
    with capture_jax_logs() as log_output:
      jax.jit(lambda x: x + 1)(1)
    self.assertEmpty(log_output.getvalue())

    # Turn on all debug logging.
    with jax_debug_log_modules("jax"):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertIn("Finished tracing + transforming", log_output.getvalue())
      self.assertIn("Compiling <lambda>", log_output.getvalue())

    # Turn off all debug logging.
    with jax_debug_log_modules(""):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertEmpty(log_output.getvalue())

    # Turn on one module.
    with jax_debug_log_modules("jax._src.dispatch"):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertIn("Finished tracing + transforming", log_output.getvalue())
      self.assertNotIn("Compiling <lambda>", log_output.getvalue())

    # Turn everything off again.
    with jax_debug_log_modules(""):
      with capture_jax_logs() as log_output:
        jax.jit(lambda x: x + 1)(1)
      self.assertEmpty(log_output.getvalue())

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_stderr_info_logging(self):
    """With JAX_LOGGING_LEVEL=INFO, stderr has INFO but no DEBUG output."""
    program = """
    import jax # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
    """

    # test INFO
    p = self._run_python(program, "JAX_LOGGING_LEVEL=INFO")
    log_output = p.stderr
    # Count only lines with content: str.split always returns at least one
    # element, so checking the raw split length could never fail.
    info_lines = [line for line in log_output.split("\n") if line.strip()]
    self.assertNotEmpty(info_lines)
    self.assertIn("INFO", log_output)
    self.assertNotIn("DEBUG", log_output)

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_stderr_debug_logging(self):
    """DEBUG output appears with JAX_LOGGING_LEVEL=DEBUG or debug modules."""
    program = """
    import jax # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
    """

    # test DEBUG
    p = self._run_python(program, "JAX_LOGGING_LEVEL=DEBUG")
    log_output = p.stderr
    self.assertIn("INFO", log_output)
    self.assertIn("DEBUG", log_output)

    # test JAX_DEBUG_LOG_MODULES
    p = self._run_python(program, "JAX_DEBUG_LOG_MODULES=jax")
    log_output = p.stderr
    self.assertIn("DEBUG", log_output)

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_toggling_logging_level(self):
    """Resetting `jax_logging_level` to None at runtime silences the logs."""
    separator = "---------------------------"
    program = f"""
    import sys
    import jax # this prints INFO logging from backend imports
    jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
    jax.config.update("jax_logging_level", None)
    sys.stderr.write("{separator}")
    jax.jit(lambda x: x)(1) # should not log anything now
    """

    p = self._run_python(program, "JAX_LOGGING_LEVEL=DEBUG")
    # Split stderr at the first separator: everything before it was written
    # while DEBUG logging was on, everything after it once the level was reset.
    log_output_verbose, found, log_output_silent = p.stderr.partition(separator)
    self.assertEqual(found, separator)
    self.assertIn("Finished tracing + transforming <lambda> for pjit",
                  log_output_verbose)
    self.assertEqual(log_output_silent, "")

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_double_logging_absent(self):
    """Combining both logging options must not print records twice."""
    program = """
    import jax # this prints INFO logging from backend imports
    jax.config.update("jax_debug_log_modules", "jax._src.compiler,jax._src.dispatch")
    jax.jit(lambda x: x)(1) # this prints logs to DEBUG (from compilation)
    """

    p = self._run_python(program, "JAX_LOGGING_LEVEL=DEBUG")
    log_output = p.stderr
    self.assertNotEmpty(log_output)
    log_lines = log_output.strip().split("\n")
    # only one tracing line should be printed, if there's more than one
    # then logs are printing duplicated
    tracing_lines = [line for line in log_lines
                     if "Finished tracing + transforming" in line]
    self.assertLen(tracing_lines, 1)

  @jtu.skip_on_devices("tpu")
  @unittest.skipIf(platform.system() == "Windows",
                   "Subprocess test doesn't work on Windows")
  def test_subprocess_cpp_logging_level(self):
    """JAX_LOGGING_LEVEL controls whether the coordination message shows."""
    program = """
    import sys
    import jax # this prints INFO logging from backend imports
    jax.distributed.initialize("127.0.0.1:12345", num_processes=1, process_id=0)
    """
    message = "Initializing CoordinationService"

    # verbose logging: DEBUG, INFO
    for level in ("DEBUG", "INFO"):
      p = self._run_python(program, f"JAX_LOGGING_LEVEL={level}")
      self.assertIn(message, p.stderr)

    # quiet logging: WARNING, None
    p = self._run_python(program, "JAX_LOGGING_LEVEL=WARNING")
    self.assertNotIn(message, p.stderr)

    # With JAX_LOGGING_LEVEL unset, the message is only required to be absent
    # when the default TF_CPP_MIN_LOG_LEVEL is 1 or higher.
    p = self._run_python(program)
    if int(_default_TF_CPP_MIN_LOG_LEVEL) >= 1:
      self.assertNotIn(message, p.stderr)
if __name__ == "__main__":
  # Run through absltest.main (rather than unittest.main) so that jax_test flag
  # parsing works under bazel, using JAX's own test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())