Skip to content

Commit

Permalink
Revert "Add deterministic path for CUDA cumsum (pytorch#136224)"
Browse files Browse the repository at this point in the history
This reverts commit d1bb8e8.

Reverted pytorch#136224 on behalf of https://github.com/atalman due to Break internal CI ([comment](pytorch#136224 (comment)))
  • Loading branch information
pytorchmergebot committed Sep 27, 2024
1 parent c2637a7 commit e9d2765
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 115 deletions.
24 changes: 1 addition & 23 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,33 +1739,11 @@ def test_nondeterministic_alert_EmbeddingBag_max(self, device):
'embedding_bag_backward_cuda_max',
torch.device(device).type == 'cuda')

@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
@onlyCUDA
def test_deterministic_cumsum(self, device):
test_cases = [
# size, dim
[(2, 3, 4), 0],
[(2, 3, 4), 1],
[(2, 3, 4), 2],
[(1000, 10, 2), 0],
]
for size, dim in test_cases:
input = 100 * torch.randn(*size, device=device)
with DeterministicGuard(True):
res0 = input.cumsum(dim)
for _ in range(3):
res1 = input.cumsum(dim)
self.assertEqual(res0, res1, atol=0, rtol=0)

res_cpu = input.cpu().cumsum(dim)
self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)


@dtypes(*all_types_and_complex_and(torch.bool))
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_cumsum(self, device, dtype):
input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
should_alert = False
should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex)

for op_call in [torch.Tensor.cumsum, torch.cumsum]:
self.check_nondeterministic_alert(
Expand Down
1 change: 0 additions & 1 deletion tools/pyi/gen_pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
"requires_grad",
"range",
# defined in functional
"cumsum",
"einsum",
# Somehow, these are defined in both _C and in functional. Ick!
"broadcast_tensors",
Expand Down
2 changes: 1 addition & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,6 @@ def use_deterministic_algorithms(
* :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
tensor
* :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
* :func:`torch.cumsum` when called on a CUDA tensor
* :func:`torch.gather` when called on a CUDA tensor that requires grad
* :func:`torch.index_add` when called on CUDA tensor
* :func:`torch.index_select` when attempting to differentiate a CUDA tensor
Expand Down Expand Up @@ -1282,6 +1281,7 @@ def use_deterministic_algorithms(
* :func:`torch.kthvalue` with called on a CUDA tensor
* :func:`torch.median` with indices output when called on a CUDA tensor
* :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
* :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex
* :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
* :func:`torch.Tensor.resize_` when called with a quantized tensor
Expand Down
31 changes: 0 additions & 31 deletions torch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,37 +846,6 @@ def symeig(self, eigenvectors=False):

return _symeig(self, eigenvectors=eigenvectors)

def cumsum(
self,
dim=None,
*,
dtype=None,
out=None,
axis=None,
):
r"""
cumsum(dim, dtype=None) -> Tensor
See :func:`torch.cumsum`
"""
if axis is not None and dim is not None:
raise RuntimeError("expected either 'dim' or 'axis' to be given, not both")
elif axis is not None:
dim = axis
if has_torch_function_unary(self):
return handle_torch_function(
Tensor.cumsum,
(self,),
self,
dim,
dtype=dtype,
out=out,
)
if out is None:
return torch.cumsum(self, dim, dtype=dtype)
else:
return torch.cumsum(self, dim, dtype=dtype, out=out)

def lu(self, pivot=True, get_infos=False):
r"""See :func:`torch.lu`"""
# If get_infos is True, then we don't need to check for errors and vice versa
Expand Down
9 changes: 9 additions & 0 deletions torch/_tensor_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,15 @@ def add_docstr_all(method, docstr):
""",
)

add_docstr_all(
"cumsum",
r"""
cumsum(dim, dtype=None) -> Tensor
See :func:`torch.cumsum`
""",
)

add_docstr_all(
"cumsum_",
r"""
Expand Down
32 changes: 32 additions & 0 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3317,6 +3317,38 @@ def merge_dicts(*dicts):
""".format(**reduceops_common_args),
)

add_docstr(
torch.cumsum,
r"""
cumsum(input, dim, *, dtype=None, out=None) -> Tensor
Returns the cumulative sum of elements of :attr:`input` in the dimension
:attr:`dim`.
For example, if :attr:`input` is a vector of size N, the result will also be
a vector of size N, with elements.
.. math::
y_i = x_1 + x_2 + x_3 + \dots + x_i
Args:
{input}
dim (int): the dimension to do the operation over
Keyword args:
{dtype}
{out}
Example::
>>> a = torch.randint(1, 20, (10,))
>>> a
tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10])
>>> torch.cumsum(a, dim=0)
tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93])
""".format(**reduceops_common_args),
)

add_docstr(
torch.count_nonzero,
r"""
Expand Down
58 changes: 0 additions & 58 deletions torch/functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
import importlib
import itertools
import operator
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
Expand Down Expand Up @@ -29,7 +28,6 @@
"block_diag",
"cdist",
"chain_matmul",
"cumsum",
"einsum",
"istft",
"lu",
Expand Down Expand Up @@ -2037,62 +2035,6 @@ def chain_matmul(*matrices, out=None):
return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined]


def cumsum(
self: Tensor,
dim: Optional[int] = None,
*,
dtype: Optional[torch.dtype] = None,
out: Optional[Tensor] = None,
axis: Optional[int] = None,
):
r"""
cumsum(input, dim, *, dtype=None, out=None) -> Tensor
Returns the cumulative sum of elements of :attr:`input` in the dimension
:attr:`dim`.
For example, if :attr:`input` is a vector of size N, the result will also be
a vector of size N, with elements.
.. math::
y_i = x_1 + x_2 + x_3 + \dots + x_i
Args:
input (Tensor): the input tensor.
dim (int): the dimension to do the operation over
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
If specified, the input tensor is casted to :attr:`dtype` before the operation
is performed. This is useful for preventing data type overflows. Default: None.
out (Tensor, optional): the output tensor.
Example::
>>> torch.manual_seed(0)
>>> a = torch.randint(1, 20, (10,))
>>> a
tensor([16, 5, 1, 1, 12, 8, 6, 10, 10, 5])
>>> torch.cumsum(a, dim=0)
tensor([16, 21, 22, 23, 35, 43, 49, 59, 69, 74])
"""
if axis is not None:
if dim is None:
dim = axis
else:
raise RuntimeError("expected either 'dim' or 'axis' to be given, not both")
if has_torch_function_unary(self):
return handle_torch_function(cumsum, (self,), self, dim, dtype=dtype, out=out)
if not torch.jit.is_scripting():
if torch.are_deterministic_algorithms_enabled() and self.is_cuda:
ref_func = importlib.import_module("torch._refs").cumsum
return ref_func(self, dim, dtype=dtype, out=out)
if out is None:
return _VF.cumsum(self, dim, dtype=dtype) # type: ignore[attr-defined]
else:
return _VF.cumsum(self, dim, dtype=dtype, out=out) # type: ignore[attr-defined]


def _lu_impl(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
r"""Computes the LU factorization of a matrix or batches of matrices
Expand Down
2 changes: 1 addition & 1 deletion torch/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.cummax: lambda input, dim, out=None: -1,
torch.cummin: lambda input, dim, out=None: -1,
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
torch.cumsum: lambda input, dim, out=None, dtype=None, axis=None: -1,
torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
torch.logcumsumexp: lambda input, dim, out=None: -1,
torch.deg2rad: lambda input, out=None: -1,
Expand Down

0 comments on commit e9d2765

Please sign in to comment.