Remove reliance on Tensorflow enable_numpy_behavior for some ops. (keras-team#18994)

keras-team#18947

Some ops would only work properly if `enable_numpy_behavior` was turned on. This reimplements those ops using standard TF operations (see the usage sketch after the file summary below).
hertschuh authored Dec 26, 2023
1 parent df10cb2 commit 8e897fb
Showing 4 changed files with 45 additions and 29 deletions.
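For illustration only (not part of this commit): a minimal usage sketch on the TensorFlow backend, assuming Keras 3's `keras.ops`, showing that the NaN/Inf checks now work without first calling `np_config.enable_numpy_behavior()`.

import tensorflow as tf
from keras import ops

# Float inputs: built on plain tf.math.is_finite / is_inf / is_nan.
x = tf.constant([1.0, float("inf"), float("nan")])
print(ops.isfinite(x))  # [ True False False]
print(ops.isinf(x))     # [False  True False]
print(ops.isnan(x))     # [False False  True]

# Integer inputs short-circuit: they can never be NaN or infinite.
i = tf.constant([1, 2, 3])
print(ops.isfinite(i))  # [ True  True  True]
print(ops.isnan(i))     # [False False False]
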
57 changes: 42 additions & 15 deletions keras/backend/tensorflow/numpy.py
@@ -809,17 +809,30 @@ def isclose(x1, x2):

@sparse.densifying_unary(True)
def isfinite(x):
return tfnp.isfinite(x)
# `tfnp.isfinite` requires `enable_numpy_behavior`, so we reimplement it.
x = convert_to_tensor(x)
dtype_as_dtype = tf.as_dtype(x.dtype)
if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:
return tf.ones(x.shape, tf.bool)
return tf.math.is_finite(x)


def isinf(x):
# TODO: tfnp.isinf will get python bool when input is a scalar, so we
# need the extra `convert_to_tensor`
return convert_to_tensor(tfnp.isinf(x))
# `tfnp.isinf` requires `enable_numpy_behavior`, so we reimplement it.
x = convert_to_tensor(x)
dtype_as_dtype = tf.as_dtype(x.dtype)
if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:
return tf.zeros(x.shape, tf.bool)
return tf.math.is_inf(x)


def isnan(x):
return tfnp.isnan(x)
# `tfnp.isnan` requires `enable_numpy_behavior`, so we reimplement it.
x = convert_to_tensor(x)
dtype_as_dtype = tf.as_dtype(x.dtype)
if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:
return tf.zeros(x.shape, tf.bool)
return tf.math.is_nan(x)


def less(x1, x2):
@@ -1039,23 +1052,20 @@ def moveaxis(x, source, destination):

def nan_to_num(x):
x = convert_to_tensor(x)
dtype = standardize_dtype(x.dtype)

# tf.bool doesn't support max and min
if dtype == "bool":
x = tf.where(tfnp.isnan(x), tf.constant(False, x.dtype), x)
x = tf.where(tfnp.isinf(x) & (x > 0), tf.constant(True, x.dtype), x)
x = tf.where(tfnp.isinf(x) & (x < 0), tf.constant(False, x.dtype), x)
dtype = x.dtype
dtype_as_dtype = tf.as_dtype(dtype)
if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:
return x

# Replace NaN with 0
x = tf.where(tfnp.isnan(x), tf.constant(0, x.dtype), x)
x = tf.where(tf.math.is_nan(x), tf.constant(0, dtype), x)

# Replace positive infinity with dtype.max
x = tf.where(tfnp.isinf(x) & (x > 0), tf.constant(x.dtype.max, x.dtype), x)
x = tf.where(tf.math.is_inf(x) & (x > 0), tf.constant(dtype.max, dtype), x)

# Replace negative infinity with dtype.min
x = tf.where(tfnp.isinf(x) & (x < 0), tf.constant(x.dtype.min, x.dtype), x)
x = tf.where(tf.math.is_inf(x) & (x < 0), tf.constant(dtype.min, dtype), x)

return x

@@ -1439,7 +1449,24 @@ def tensordot(x1, x2, axes=2):

@sparse.elementwise_unary
def round(x, decimals=0):
return tfnp.round(x, decimals=decimals)
# `tfnp.round` requires `enable_numpy_behavior`, so we reimplement it.
if decimals == 0:
return tf.round(x)
x_dtype = x.dtype
if tf.as_dtype(x_dtype).is_integer:
# int
if decimals > 0:
return x
# temporarily convert to floats
factor = tf.cast(math.pow(10, decimals), config.floatx())
x = tf.cast(x, config.floatx())
else:
# float
factor = tf.cast(math.pow(10, decimals), x.dtype)
x = tf.multiply(x, factor)
x = tf.round(x)
x = tf.divide(x, factor)
return tf.cast(x, x_dtype)


def tile(x, repeats):
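The non-zero `decimals` branch of `round` above scales by 10**decimals, rounds, then unscales. A standalone sketch of that idea (the helper name `round_to_decimals` is made up here, not part of the commit):

import math
import tensorflow as tf

def round_to_decimals(x, decimals=0):
    # Scale by 10**decimals, round to the nearest integer, then unscale.
    if decimals == 0:
        return tf.round(x)
    factor = tf.cast(math.pow(10, decimals), x.dtype)
    return tf.round(x * factor) / factor

print(round_to_decimals(tf.constant([3.14159, 2.71828]), 2))   # [3.14 2.72]
print(round_to_decimals(tf.constant([1234.0, 5678.0]), -2))    # [1200. 5700.]
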
5 changes: 0 additions & 5 deletions keras/metrics/confusion_metrics_test.py
@@ -4,7 +4,6 @@
import pytest
from absl import logging
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config

from keras import layers
from keras import metrics
@@ -13,10 +12,6 @@
from keras import testing
from keras.metrics import metrics_utils

# TODO: remove reliance on this (or alternatively, turn it on by default).
# This is no longer needed with tf-nightly.
np_config.enable_numpy_behavior()


class FalsePositivesTest(testing.TestCase):
def test_config(self):
4 changes: 0 additions & 4 deletions keras/ops/nn_test.py
@@ -1,7 +1,6 @@
import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config

from keras import backend
from keras import layers
@@ -25,9 +24,6 @@
from keras.ops import numpy as knp
from keras.testing.test_utils import named_product

# TODO: remove reliance on this (or alternatively, turn it on by default).
np_config.enable_numpy_behavior()


class NNOpsDynamicShapeTest(testing.TestCase, parameterized.TestCase):
def test_relu(self):
8 changes: 3 additions & 5 deletions keras/ops/numpy_test.py
@@ -1,10 +1,10 @@
import contextlib
import itertools
import math

import numpy as np
import pytest
from absl.testing import parameterized
from tensorflow.python.ops.numpy_ops import np_config

import keras
from keras import backend
@@ -15,9 +15,6 @@
from keras.ops import numpy as knp
from keras.testing.test_utils import named_product

# TODO: remove reliance on this (or alternatively, turn it on by default).
np_config.enable_numpy_behavior()


class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
def test_add(self):
@@ -4013,7 +4010,8 @@ def create_sparse_tensor(x, indices_from=None, start=0, delta=2):
if indices_from is not None:
indices = indices_from.indices
else:
flat_indices = np.arange(start, x.size, delta)
size = math.prod(x.shape)
flat_indices = np.arange(start, size, delta)
indices = np.stack(np.where(np.ones_like(x)), axis=1)[flat_indices]

if backend.backend() == "tensorflow":
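An aside on the test change above (not from the diff): computing the element count from the shape with `math.prod` works for both NumPy arrays and tf tensors, whereas `.size` is NumPy-specific and previously depended on numpy behavior being enabled for tensors.

import math
import numpy as np
import tensorflow as tf

x_np = np.ones((2, 3))
x_tf = tf.ones((2, 3))

# Same element count either way, no enable_numpy_behavior needed.
print(math.prod(x_np.shape))  # 6
print(math.prod(x_tf.shape))  # 6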
