diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 419ce12ee9fd39..9138a662d4dcf8 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1210,6 +1210,55 @@ def wrapper(func):
   return wrapper
 
 
+def set_xla_env_flag(func=None, flag=""):
+  """Decorator for setting XLA_FLAGS prior to running a test.
+
+  This function returns a decorator intended to be applied to test methods in
+  a `tf.test.TestCase` class. Doing so will allow users to set any xla flags
+  exposed via the XLA_FLAGS environment variable, execute the test, then reset
+  XLA_FLAGS to the state it was in prior to the test.
+
+  Example:
+
+  class MyTest(test.TestCase):
+
+    @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false')
+    def testFoo(self):
+      ...
+
+  Args:
+    func: The function to be wrapped.
+    flag: The xla flag to be set in the XLA_FLAGS env variable.
+
+  Returns:
+    The wrapped function.
+  """
+
+  def decorator(f):
+
+    @functools.wraps(f)
+    def decorated(*args, **kwargs):
+      original_xla_flags = os.environ.get("XLA_FLAGS")
+      new_xla_flags = flag
+      if original_xla_flags:
+        new_xla_flags = new_xla_flags + " " + original_xla_flags
+      os.environ["XLA_FLAGS"] = new_xla_flags
+      try:
+        return f(*args, **kwargs)
+      finally:
+        if original_xla_flags is None:
+          del os.environ["XLA_FLAGS"]
+        else:
+          os.environ["XLA_FLAGS"] = original_xla_flags
+
+    return decorated
+
+  if func is not None:
+    return decorator(func)
+
+  return decorator
+
+
 def build_as_function_and_v1_graph(func=None):
   """Run a test case in v1 graph mode and inside tf.function in eager mode.
 
diff --git a/tensorflow/python/kernel_tests/array_ops/denormal_test.py b/tensorflow/python/kernel_tests/array_ops/denormal_test.py
index f746be82411ef1..120df9bc238fa2 100644
--- a/tensorflow/python/kernel_tests/array_ops/denormal_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/denormal_test.py
@@ -14,9 +14,11 @@
 # ==============================================================================
 """Tests for denormal handling."""
 
-import numpy as np
+import os
 import platform
 
+import numpy as np
+
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
@@ -57,4 +59,14 @@ def testFlushDenormalsGPU(self):
 
 
 if __name__ == "__main__":
+  # When eager_op_as_function mode is enabled xla auto-clustering kicks in.
+  # By default xla does not enable flush-to-zero semantics in the GPU backend.
+  # This env flag has to be set before the test is set up. Setting it using the
+  # decorator does not seem to propagate the flag to all required locations.
+  original_xla_flags = os.environ.get("XLA_FLAGS")
+  new_xla_flags = "--xla_gpu_ftz=true"
+  if original_xla_flags:
+    new_xla_flags = new_xla_flags + " " + original_xla_flags
+  os.environ["XLA_FLAGS"] = new_xla_flags
+
   test.main()
diff --git a/tensorflow/python/kernel_tests/math_ops/cwise_ops_test.py b/tensorflow/python/kernel_tests/math_ops/cwise_ops_test.py
index 5b2eef928f1167..1b79d4857e0fcc 100644
--- a/tensorflow/python/kernel_tests/math_ops/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/math_ops/cwise_ops_test.py
@@ -741,6 +741,7 @@ def testShapeMismatch(self):
       array_ops.where(c, xt, yt)
 
 
+@test_util.with_eager_op_as_function
 class MinMaxOpTest(test.TestCase):
 
   def _compare(self, x, y, use_gpu):
@@ -762,6 +763,11 @@ def testBasic(self):
       self._compare(x.astype(t), y.astype(t), use_gpu=False)
       self._compare(x.astype(t), y.astype(t), use_gpu=True)
 
+  # When eager_op_as_function mode is enabled xla auto-clustering kicks in.
+  # By default xla enables fast min_max computations which do not propagate NaN.
+  # TODO(b/205140614): remove decorators once TF and XLA behaviour are the same.
+  @test_util.set_xla_env_flag(flag="--xla_cpu_enable_fast_min_max=false")
+  @test_util.set_xla_env_flag(flag="--xla_gpu_enable_fast_min_max=false")
  def testNaNPropagation(self):
     x = np.array([1., np.nan, 1., np.nan], dtype=np.float64)
     y = np.array([1., 1., np.nan, np.nan], dtype=np.float64)
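# Reviewer note (not part of the change): a minimal standalone sketch of the
# XLA_FLAGS save/append/restore pattern that the new set_xla_env_flag
# decorator and the denormal_test __main__ block both follow. The
# context-manager form and the name temp_xla_flag are illustrative
# assumptions only, not TensorFlow API; the flag value is taken from the
# diff above.
import contextlib
import os


@contextlib.contextmanager
def temp_xla_flag(flag):
  # Prepend `flag` to XLA_FLAGS, then restore the previous value on exit.
  original = os.environ.get("XLA_FLAGS")
  os.environ["XLA_FLAGS"] = flag if not original else flag + " " + original
  try:
    yield
  finally:
    if original is None:
      del os.environ["XLA_FLAGS"]
    else:
      os.environ["XLA_FLAGS"] = original


if __name__ == "__main__":
  with temp_xla_flag("--xla_gpu_enable_fast_min_max=false"):
    print(os.environ["XLA_FLAGS"])    # flag visible while the body runs
  print(os.environ.get("XLA_FLAGS"))  # prior value (or unset) is restored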