Add extra constraints on tests when executing in auto-clustering mode to make sure we are testing the right behaviour.

PiperOrigin-RevId: 407617760
Change-Id: I6234ac941d9025e2928954452d04a474a98b35b1
sagunb authored and tensorflower-gardener committed Nov 4, 2021
1 parent eb36589 commit de6fb84
Showing 3 changed files with 68 additions and 1 deletion.
49 changes: 49 additions & 0 deletions tensorflow/python/framework/test_util.py
@@ -1210,6 +1210,55 @@ def wrapper(func):
  return wrapper


def set_xla_env_flag(func=None, flag=""):
  """Decorator for setting XLA_FLAGS prior to running a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will allow users to set any xla flags
  exposed via the XLA_FLAGS environment variable, execute the test, then reset
  the XLA_FLAGS to the state it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false')
    def testFoo(self):
      ...

  Args:
    func: The function to be wrapped.
    flag: The xla flag to be set in the XLA_FLAGS env variable.

  Returns:
    The wrapped function.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(*args, **kwargs):
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = flag
      if original_xla_flags:
        new_xla_flags = new_xla_flags + " " + original_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags
      try:
        return f(*args, **kwargs)
      finally:
        if original_xla_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = original_xla_flags

    return decorated

  if func is not None:
    return decorator(func)

  return decorator


def build_as_function_and_v1_graph(func=None):
"""Run a test case in v1 graph mode and inside tf.function in eager mode.
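
For reference, a minimal usage sketch of the new decorator (hypothetical test class, not part of this commit). Stacking set_xla_env_flag applications sets several flags for one test: each application prepends its flag to XLA_FLAGS and restores the previous value once the test finishes, mirroring the usage added to cwise_ops_test.py below.

import os

from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class XlaFlagUsageExample(test.TestCase):

  # While this test runs, XLA_FLAGS contains both flags (plus whatever value
  # the environment already held); the original value is restored afterwards.
  @test_util.set_xla_env_flag(flag="--xla_cpu_enable_fast_min_max=false")
  @test_util.set_xla_env_flag(flag="--xla_gpu_enable_fast_min_max=false")
  def testWithFastMinMaxDisabled(self):
    self.assertIn("--xla_cpu_enable_fast_min_max=false",
                  os.environ.get("XLA_FLAGS", ""))
    self.assertIn("--xla_gpu_enable_fast_min_max=false",
                  os.environ.get("XLA_FLAGS", ""))


if __name__ == "__main__":
  test.main()
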
14 changes: 13 additions & 1 deletion tensorflow/python/kernel_tests/array_ops/denormal_test.py
@@ -14,9 +14,11 @@
# ==============================================================================
"""Tests for denormal handling."""

import numpy as np
import os
import platform

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@@ -57,4 +59,14 @@ def testFlushDenormalsGPU(self):


if __name__ == "__main__":
  # When eager_op_as_function mode is enabled, xla auto-clustering kicks in.
  # By default xla does not enable flush-to-zero semantics in the GPU backend.
  # This env flag has to be set before the test is set up. Setting it using
  # the decorator does not seem to propagate the flag to all required
  # locations.
  original_xla_flags = os.environ.get("XLA_FLAGS")
  new_xla_flags = "--xla_gpu_ftz=true"
  if original_xla_flags:
    new_xla_flags = new_xla_flags + " " + original_xla_flags
  os.environ["XLA_FLAGS"] = new_xla_flags

  test.main()
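
For context on the --xla_gpu_ftz=true flag set above, a small numpy illustration (not part of this commit) of the denormal values this test exercises: flush-to-zero semantics force subnormal floats, i.e. nonzero values smaller than the smallest normal float32, to zero.

import numpy as np

smallest_normal = np.finfo(np.float32).tiny  # ~1.18e-38, smallest normal float32
subnormal = np.float32(smallest_normal) / np.float32(2)  # ~5.9e-39, a denormal
print(subnormal > 0)  # True: without flush-to-zero, denormals stay representable
# With flush-to-zero enabled (e.g. --xla_gpu_ftz=true for the XLA GPU backend),
# an op producing such a value is expected to return 0.0 instead.
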
6 changes: 6 additions & 0 deletions tensorflow/python/kernel_tests/math_ops/cwise_ops_test.py
@@ -741,6 +741,7 @@ def testShapeMismatch(self):
        array_ops.where(c, xt, yt)


@test_util.with_eager_op_as_function
class MinMaxOpTest(test.TestCase):

  def _compare(self, x, y, use_gpu):
@@ -762,6 +763,11 @@ def testBasic(self):
      self._compare(x.astype(t), y.astype(t), use_gpu=False)
      self._compare(x.astype(t), y.astype(t), use_gpu=True)

  # When eager_op_as_function mode is enabled, xla auto-clustering kicks in.
  # By default xla enables fast min_max computations, which do not propagate
  # NaN.
  # TODO(b/205140614): remove decorators once TF and XLA behaviour are the
  # same.
  @test_util.set_xla_env_flag(flag="--xla_cpu_enable_fast_min_max=false")
  @test_util.set_xla_env_flag(flag="--xla_gpu_enable_fast_min_max=false")
  def testNaNPropagation(self):
    x = np.array([1., np.nan, 1., np.nan], dtype=np.float64)
    y = np.array([1., 1., np.nan, np.nan], dtype=np.float64)
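
To make explicit the semantics testNaNPropagation checks, a short numpy illustration (not part of this commit; numpy's minimum/maximum is presumably the reference the test compares against): with fast min/max disabled, minimum and maximum should return NaN whenever either operand is NaN.

import numpy as np

x = np.array([1., np.nan, 1., np.nan], dtype=np.float64)
y = np.array([1., 1., np.nan, np.nan], dtype=np.float64)
print(np.minimum(x, y))  # [ 1. nan nan nan]
print(np.maximum(x, y))  # [ 1. nan nan nan]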
