Skip to content

Commit

Permalink
Add isinf and sign tensor functions (flashlight#703)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#703

Modeled after [`isnan`](https://numpy.org/doc/stable/reference/generated/numpy.isnan.html)  - add [`isinf`](https://numpy.org/doc/stable/reference/generated/numpy.isinf.html) - can change naming if others feel strongly.

Adds [`sign`](https://numpy.org/doc/stable/reference/generated/numpy.sign.html) which differs from [ArrayFire's](https://arrayfire.org/docs/group__arith__func__sign.htm) and requires an additional step to zero relevant indices.

Reviewed By: benoitsteiner

Differential Revision: D30022375

fbshipit-source-id: 562de5a36d7c71196f2a7663adaff7a8b50de440
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Aug 7, 2021
1 parent 839da54 commit 9acaa3c
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flashlight/fl/tensor/TensorBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class TensorBackend {
virtual Tensor
clip(const Tensor& tensor, const Tensor& low, const Tensor& high) = 0;
virtual Tensor isnan(const Tensor& tensor) = 0;
virtual Tensor isinf(const Tensor& tensor) = 0;
virtual Tensor sign(const Tensor& tensor) = 0;
virtual Tensor
where(const Tensor& condition, const Tensor& x, const Tensor& y) = 0;

Expand Down
8 changes: 8 additions & 0 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,14 @@ Tensor isnan(const Tensor& tensor) {
return tensor.backend().isnan(tensor);
}

Tensor isinf(const Tensor& tensor) {
return tensor.backend().isinf(tensor);
}

Tensor sign(const Tensor& tensor) {
return tensor.backend().sign(tensor);
}

Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) {
FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x, y);
return condition.backend().where(condition, x, y);
Expand Down
20 changes: 20 additions & 0 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,26 @@ Tensor clip(const Tensor& tensor, const double& low, const double& high);
*/
Tensor isnan(const Tensor& tensor);

/**
* Returns a boolean tensor which is true where the input tensor was infinity,
* and false otherwise.
*
* @param[in] tensor the input tensor
* @return a boolean tensor with true in positions that contained Inf in the
* input tensor
*/
Tensor isinf(const Tensor& tensor);

/**
* Returns a tensor that contains -1 if an element is less than 0, 0 if an
* element is 0, and 1 if an element is greater than zero. Returns NaN for NaN
* values.
*
* @param[in] tensor the input tensor
* @return a tensor containing element-wise sign values.
*/
Tensor sign(const Tensor& tensor);

/**
* Conditionally return elements from one of two tensors based on a condition.
*
Expand Down
10 changes: 10 additions & 0 deletions flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,16 @@ Tensor ArrayFireBackend::isnan(const Tensor& tensor) {
return toTensor<ArrayFireTensor>(af::isNaN(toArray(tensor)));
}

Tensor ArrayFireBackend::isinf(const Tensor& tensor) {
return toTensor<ArrayFireTensor>(af::isInf(toArray(tensor)));
}

Tensor ArrayFireBackend::sign(const Tensor& tensor) {
auto wSigned = 1 - 2 * af::sign(toArray(tensor));
wSigned(toArray(tensor) == 0) = 0;
return toTensor<ArrayFireTensor>(std::move(wSigned));
}

Tensor ArrayFireBackend::where(
const Tensor& condition,
const Tensor& x,
Expand Down
2 changes: 2 additions & 0 deletions flashlight/fl/tensor/backend/af/ArrayFireBackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class ArrayFireBackend : public TensorBackend {
Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high)
override;
Tensor isnan(const Tensor& tensor) override;
Tensor isinf(const Tensor& tensor) override;
Tensor sign(const Tensor& tensor) override;
Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y)
override;

Expand Down
20 changes: 20 additions & 0 deletions flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,26 @@ TEST(TensorBaseTest, isnan) {
fl::full(s, false).astype(fl::dtype::b8)));
}

TEST(TensorBaseTest, isinf) {
Shape s = {3, 3};
ASSERT_TRUE(allClose(
fl::isinf(fl::full(s, 1.) / 3),
fl::full(s, false).astype(fl::dtype::b8)));
ASSERT_TRUE(allClose(
fl::isinf(fl::full(s, 1.) / 0.),
fl::full(s, true).astype(fl::dtype::b8)));
}

TEST(TensorBaseTest, sign) {
auto vals = fl::rand({5, 5}) - 0.5;
vals(2, 2) = 0.;
auto signs = fl::sign(vals);
vals(vals > 0) = 1;
vals(vals == 0) = 0;
vals(vals < 0) = -1;
ASSERT_TRUE(allClose(signs, vals));
}

TEST(TensorBaseTest, where) {
auto a = Tensor::fromVector<int>({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto out = fl::where(a < 5, a, a * 10);
Expand Down

0 comments on commit 9acaa3c

Please sign in to comment.