Added correct isinf handling for Integral tensors (#15489)
Summary:
Currently, calling torch.isinf on an integral tensor raises "RuntimeError: value cannot be converted to type int16_t without overflow: inf".
This PR suppresses the error and returns false (0) for all integral tensors, making the behavior consistent with np.isinf.
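
For illustration, a minimal sketch of the intended behavior on a build that includes this patch (the exact printed repr may vary across PyTorch versions):

import torch

# Floating-point input: isinf flags +inf/-inf entries and returns a uint8 (Byte) tensor.
x = torch.tensor([1.0, float('inf'), 2.0, float('-inf'), float('nan')])
print(torch.isinf(x))   # expected: tensor([0, 1, 0, 1, 0], dtype=torch.uint8)

# Integral input: previously this raised the overflow RuntimeError; with this change
# it returns an all-zero uint8 tensor, matching np.isinf.
y = torch.tensor([1, 2, 3], dtype=torch.int16)
print(torch.isinf(y))   # expected: tensor([0, 0, 0], dtype=torch.uint8)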
Pull Request resolved: pytorch/pytorch#15489

Reviewed By: zou3519

Differential Revision: D13540786

Pulled By: flashhack

fbshipit-source-id: e730dea849da6a59f3752d347bcfbadfd12c6483
frankz-ai authored and facebook-github-bot committed Dec 26, 2018
1 parent d602ddc commit d4712ee
Showing 3 changed files with 21 additions and 2 deletions.
3 changes: 3 additions & 0 deletions test/test_cuda.py
@@ -938,6 +938,9 @@ def test_type_conversions_same_gpu(self):
     def test_neg(self):
         _TestTorchMixin._test_neg(self, lambda t: t.cuda())
 
+    def test_isinf(self):
+        _TestTorchMixin._test_isinf(self, lambda t: t.cuda())
+
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_arithmetic_large_tensor(self):
         x = torch.empty(2**30, device='cuda')
18 changes: 16 additions & 2 deletions test/test_torch.py
@@ -5316,9 +5316,23 @@ def test_isfinite_int(self):
         x = torch.tensor([1, 2, 3])
         self.assertEqual(torch.isfinite(x), torch.ByteTensor([1, 1, 1]))
 
+    @staticmethod
+    def _test_isinf(self, cast):
+        t1 = cast(torch.Tensor([1, inf, 2, -inf, nan]))
+        t2 = cast(torch.ByteTensor([1, 2, 3]))
+        t3 = cast(torch.CharTensor([1, 2, 3]))
+        t4 = cast(torch.ShortTensor([1, 2, 3]))
+        t5 = cast(torch.IntTensor([1, 2, 3]))
+        t6 = cast(torch.LongTensor([1, 2, 3]))
+        self.assertEqual(torch.isinf(t1), cast(torch.ByteTensor([0, 1, 0, 1, 0])))
+        self.assertEqual(torch.isinf(t2), cast(torch.ByteTensor([0, 0, 0])))
+        self.assertEqual(torch.isinf(t3), cast(torch.ByteTensor([0, 0, 0])))
+        self.assertEqual(torch.isinf(t4), cast(torch.ByteTensor([0, 0, 0])))
+        self.assertEqual(torch.isinf(t5), cast(torch.ByteTensor([0, 0, 0])))
+        self.assertEqual(torch.isinf(t6), cast(torch.ByteTensor([0, 0, 0])))
+
     def test_isinf(self):
-        x = torch.Tensor([1, inf, 2, -inf, nan])
-        self.assertEqual(torch.isinf(x), torch.ByteTensor([0, 1, 0, 1, 0]))
+        self._test_isinf(self, lambda t: t)
 
     def test_isnan(self):
         x = torch.Tensor([1, nan, 2])
2 changes: 2 additions & 0 deletions torch/functional.py
@@ -240,6 +240,8 @@ def isinf(tensor):
"""
if not isinstance(tensor, torch.Tensor):
raise ValueError("The argument is not a tensor", str(tensor))
if tensor.dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
return torch.zeros_like(tensor, dtype=torch.uint8)
return tensor.abs() == inf


