Revert D29794958 + compilation fix (pytorch#61937)
Summary:
This PR un-reverts pytorch#61475 and fixes compilation with MSVC, which does not recognize ISO C++ alternative operator spellings (e.g. `or` instead of `||`) in its default mode.
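For context, the failure mode looks like the following minimal sketch. This is not code from the PR; `is_valid_dx` is a hypothetical helper, and the behavior described assumes MSVC's default mode (no `/permissive-`, no `<ciso646>`), where ISO C++ alternative tokens are not treated as keywords:

```cpp
#include <iostream>

// Hypothetical helper, for illustration only.
bool is_valid_dx(bool is_complex, bool is_boolean) {
  // return !(is_complex or is_boolean);  // accepted by GCC/Clang, rejected by MSVC's default mode
  return !(is_complex || is_boolean);     // portable spelling, as used by the fix
}

int main() {
  std::cout << std::boolalpha << is_valid_dx(false, false) << '\n';  // prints "true"
}
```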

Pull Request resolved: pytorch#61937

Reviewed By: albanD

Differential Revision: D29805941

Pulled By: malfet

fbshipit-source-id: 01e5963c6717c1b44b260300d87ba0bf57f26ce9
malfet authored and facebook-github-bot committed Jul 21, 2021
1 parent a152c12 commit 604f503
Showing 10 changed files with 177 additions and 55 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -265,6 +265,8 @@ namespace c10 {
 _(aten, linalg_householder_product) \
 _(aten, transpose) \
 _(aten, transpose_) \
+_(aten, trapz) \
+_(aten, trapezoid) \
 _(aten, unsqueeze_) \
 _(aten, __getitem__) \
 _(aten, _set_item) \
36 changes: 28 additions & 8 deletions aten/src/ATen/native/Integration.cpp
@@ -3,6 +3,7 @@
 #include <ATen/WrapDimUtils.h>
 #include <ATen/core/DimVector.h>
 #include <c10/util/Exception.h>
+#include <c10/core/ScalarType.h>
 
 namespace at {
 namespace native {
@@ -16,16 +17,17 @@ namespace {
 //
 // TODO: if we extend TensorIterator to accept 3 inputs,
 // we can probably make this a bit more performant.
-Tensor do_trapz(const Tensor& y, const Tensor& dx, int64_t dim) {
+Tensor do_trapezoid(const Tensor& y, const Tensor& dx, int64_t dim) {
   Tensor left = y.slice(dim, 0, -1);
   Tensor right = y.slice(dim, 1);
 
+  // If the dimensions of 'dx' and '(left + right)' do not match
+  // broadcasting is attempted here.
   return ((left + right) * dx).sum(dim) / 2.;
 }
 
 // When dx is constant, the above formula simplifies
 // to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2]
-Tensor do_trapz(const Tensor& y, double dx, int64_t dim) {
+Tensor do_trapezoid(const Tensor& y, double dx, int64_t dim) {
   return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * (0.5)) * dx;
 }
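The simplification quoted in the comment follows because, with constant spacing, every interior sample is shared by two trapezoids while the endpoints appear only once:

```math
\sum_{i=1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i+1})
  = \frac{\Delta x}{2} \left( 2\sum_{i=1}^{n} y_i - y_1 - y_n \right)
  = \Delta x \left[ \sum_{i=1}^{n} y_i - \frac{y_1 + y_n}{2} \right]
```

which is exactly what the constant-`dx` overload computes via `y.sum(dim)` and the two endpoint `select` calls.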

@@ -38,35 +40,53 @@ Tensor zeros_like_except(const Tensor& y, int64_t dim) {
 
 }
 
-Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
+Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) {
   dim = maybe_wrap_dim(dim, y);
   // asking for the integral with zero samples is a bit nonsensical,
   // but we'll return "0" to match numpy behavior.
   if (y.size(dim) == 0) {
     return zeros_like_except(y, dim);
   }
+  TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "trapezoid: received a bool input for `x` or `y`, but bool is not supported")
   Tensor x_viewed;
   if (x.dim() == 1) {
-    TORCH_CHECK(x.size(0) == y.size(dim), "trapz: There must be one `x` value for each sample point");
+    // This step takes 'x' with dimension (n,), and returns 'x_view' with
+    // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that 'x'
+    // can be broadcasted later to match 'y'.
+    // Note: This behavior differs from numpy in that numpy tries to
+    // broadcast 'dx', but this tries to broadcast 'x' to match 'y' instead.
+    TORCH_CHECK(x.size(0) == y.size(dim), "trapezoid: There must be one `x` value for each sample point");
     DimVector sizes(y.dim(), 1);
     sizes[dim] = x.size(0);
     x_viewed = x.view(sizes);
   } else {
     x_viewed = x;
   }
+  // Note the .slice operation reduces the dimension along 'dim' by 1.
+  // The sizes of other dimensions are untouched.
   Tensor x_left = x_viewed.slice(dim, 0, -1);
   Tensor x_right = x_viewed.slice(dim, 1);
 
   Tensor dx = x_right - x_left;
-  return do_trapz(y, dx, dim);
+  return do_trapezoid(y, dx, dim);
 }
 
-Tensor trapz(const Tensor& y, double dx, int64_t dim) {
+Tensor trapezoid(const Tensor& y, const Scalar& dx, int64_t dim) {
   // see above
   if (y.size(dim) == 0) {
     return zeros_like_except(y, dim);
   }
-  return do_trapz(y, dx, dim);
+  TORCH_CHECK(y.scalar_type() != kBool, "trapezoid: received a bool input for `y`, but bool is not supported")
+  TORCH_CHECK(!(dx.isComplex() || dx.isBoolean()), "trapezoid: Currently, we only support dx as a real number.");
+  return do_trapezoid(y, dx.toDouble(), dim);
 }
 
+Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
+  return at::native::trapezoid(y, x, dim);
+}
+
+Tensor trapz(const Tensor& y, double dx, int64_t dim) {
+  return at::native::trapezoid(y, dx, dim);
+}
+
 }} // namespace at::native
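A quick sketch of how the new checks surface to a caller. This is not part of the diff; it assumes a libtorch build in which the dispatcher-generated `at::trapezoid` mirrors the schemas added to native_functions.yaml below:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // bool `y` is rejected by the new TORCH_CHECK in both overloads
  try {
    at::trapezoid(at::ones({4}, at::kBool), /*dx=*/1);
  } catch (const c10::Error&) {
    std::cout << "bool y rejected\n";
  }
  // only real `dx` is supported; complex (and bool) Scalars throw
  try {
    at::trapezoid(at::ones({4}), at::Scalar(c10::complex<double>(0.0, 1.0)));
  } catch (const c10::Error&) {
    std::cout << "complex dx rejected\n";
  }
}
```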
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4289,6 +4289,10 @@
   dispatch:
     CompositeExplicitAutograd: rot90
 
+- func: trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+
+- func: trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+
 - func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
 
 - func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
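A minimal usage sketch of the two overloads these schemas generate (the exact generated C++ signatures are inferred, not shown in this diff); the expected values match the docstring examples further down, and `trapz` agrees because it now forwards to `trapezoid`:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor y = at::tensor({1.0, 5.0, 10.0});
  at::Tensor x = at::tensor({1.0, 3.0, 6.0});

  // trapezoid.x: arbitrary sample positions
  std::cout << at::trapezoid(y, x).item<double>() << '\n';         // 28.5

  // trapezoid.dx: constant spacing (keyword-only in Python, positional in C++)
  std::cout << at::trapezoid(y, /*dx=*/2).item<double>() << '\n';  // 21

  // trapz forwards to trapezoid, so the results agree
  std::cout << at::trapz(y, x).item<double>() << '\n';             // 28.5
}
```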
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -571,6 +571,7 @@ BLAS and LAPACK Operations
     symeig
     lobpcg
     trapz
+    trapezoid
     triangular_solve
     vdot

9 changes: 0 additions & 9 deletions test/test_autograd.py
@@ -2789,15 +2789,6 @@ def test_cat_empty(self):
                               lambda a, b: torch.cat((a, b)),
                               True, f_args_variable, f_args_tensor, check_forward_ad=True)
 
-    def test_trapz(self):
-        f_args_variable = (torch.randn(2, 3, dtype=torch.double, requires_grad=True),
-                           torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]], dtype=torch.double, requires_grad=True))
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_trapz", "trapz",
-                              lambda y, x: torch.trapz(y, x),
-                              True, f_args_variable, f_args_tensor)
-
-
     def test_var_mean_differentiable(self):
         dim = [2, 4]
         keepdim = False
6 changes: 3 additions & 3 deletions test/test_binary_ufuncs.py
@@ -2536,17 +2536,17 @@ def _test_atan2(x, y, expected, device, dtype):
         _test_atan2(1, -1, math.pi / -4 , device, dtype)
         _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype)
 
-    def test_trapz(self, device):
+    def test_trapezoid(self, device):
         def test_dx(sizes, dim, dx, device):
             t = torch.randn(sizes, device=device)
-            actual = torch.trapz(t, dx=dx, dim=dim)
+            actual = torch.trapezoid(t, dx=dx, dim=dim)
             expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)
             self.assertEqual(expected.shape, actual.shape)
             self.assertEqual(expected, actual, exact_dtype=False)
 
         def test_x(sizes, dim, x, device):
             t = torch.randn(sizes, device=device)
-            actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim)
+            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
             expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)
             self.assertEqual(expected.shape, actual.shape)
             self.assertEqual(expected, actual.cpu(), exact_dtype=False)
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -101,7 +101,7 @@
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
-    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view'
+    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'conj_physical_', '_neg_view'
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
137 changes: 103 additions & 34 deletions torch/_torch_docs.py
@@ -10724,54 +10724,123 @@ def merge_dicts(*dicts):
         [2, 2],
         [2, 3],
         [3, 3]])
 """)
 
-add_docstr(torch.trapz,
+add_docstr(torch.trapezoid,
            r"""
-trapz(y, x, *, dim=-1) -> Tensor
+trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor
-Estimate :math:`\int y\,dx` along `dim`, using the trapezoid rule.
+Computes the `trapezoidal rule <https://en.wikipedia.org/wiki/Trapezoidal_rule>`_ along
+:attr:`dim`. By default the spacing between elements is assumed to be 1, but
+:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be
+used to specify arbitrary spacing along :attr:`dim`.
-Arguments:
-    y (Tensor): The values of the function to integrate
-    x (Tensor): The points at which the function `y` is sampled.
-        If `x` is not in ascending order, intervals on which it is decreasing
-        contribute negatively to the estimated integral (i.e., the convention
-        :math:`\int_a^b f = -\int_b^a f` is followed).
-    dim (int): The dimension along which to integrate.
-        By default, use the last dimension.
-Returns:
-    A Tensor with the same shape as the input, except with `dim` removed.
-    Each element of the returned tensor represents the estimated integral
-    :math:`\int y\,dx` along `dim`.
+Assuming :attr:`y` is a one-dimensional tensor with elements :math:`{y_0, y_1, ..., y_{n-1}}`,
+the default computation is
-Example::
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1})
+    \end{aligned}
-    >>> y = torch.randn((2, 3))
-    >>> y
-    tensor([[-2.1156,  0.6857, -0.2700],
-            [-1.2145,  0.5540,  2.0431]])
-    >>> x = torch.tensor([[1, 3, 4], [1, 2, 3]])
-    >>> torch.trapz(y, x)
-    tensor([-1.2220,  0.9683])
+When :attr:`dx` is specified the computation becomes
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1})
+    \end{aligned}
+effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified,
+assuming :attr:`x` is also a one-dimensional tensor with
+elements :math:`{x_0, x_1, ..., x_{n-1}}`, the computation becomes
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1})
+    \end{aligned}
-.. function:: trapz(y, *, dx=1, dim=-1) -> Tensor
+When :attr:`y` is two or more dimensions, this computation is performed independently
+along dimension :attr:`dim`. If :attr:`x` is also specified and is one-dimensional,
+then that dimension defines the spacing for each computation.
+If :attr:`x` is also specified and is not one-dimensional, then it is broadcast to
+the shape of :attr:`y` and the corresponding sizes are used for each computation.
+See the examples below for details.
-As above, but the sample points are spaced uniformly at a distance of `dx`.
+.. note::
+    The trapezoidal rule is a technique for approximating the definite integral of a function
+    by averaging its left and right Riemann sums. The approximation becomes more accurate as
+    the resolution of the partition increases.
 Arguments:
-    y (Tensor): The values of the function to integrate
+    y (Tensor): Values to use when computing the trapezoidal rule.
+    x (Tensor): If specified, defines spacing between values as specified above.
-Keyword args:
-    dx (float): The distance between points at which `y` is sampled.
-    dim (int): The dimension along which to integrate.
-        By default, use the last dimension.
+Keyword arguments:
+    dx (float): constant spacing between values. If neither :attr:`x` nor :attr:`dx`
+        is specified then this defaults to 1. Effectively multiplies the result by its value.
+    dim (int): The dimension along which to compute the trapezoidal rule.
+        The last (innermost) dimension by default.
+Examples::
+    >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1
+    >>> y = torch.tensor([1, 5, 10])
+    >>> torch.trapezoid(y)
+    tensor(10.5)
+    >>> # Computes the same trapezoidal rule directly to verify
+    >>> (1 + 2 * 5 + 10) / 2
+    10.5
+    >>> # Computes the trapezoidal rule in 1D with constant spacing of 2
+    >>> # NOTE: the result is the same as before, but multiplied by 2
+    >>> torch.trapezoid(y, dx=2)
+    tensor(21.)
+    >>> # Computes the trapezoidal rule in 1D with arbitrary spacing
+    >>> x = torch.tensor([1, 3, 6])
+    >>> torch.trapezoid(y, x)
+    tensor(28.5)
+    >>> # Computes the same trapezoidal rule directly to verify
+    >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2
+    28.5
+    >>> # Computes the trapezoidal rule for each row of a 3x3 matrix
+    >>> y = torch.arange(9).reshape(3, 3)
+    >>> y
+    tensor([[0, 1, 2],
+            [3, 4, 5],
+            [6, 7, 8]])
+    >>> torch.trapezoid(y)
+    tensor([ 2.,  8., 14.])
+    >>> # Computes the trapezoidal rule for each column of the matrix
+    >>> torch.trapezoid(y, dim=0)
+    tensor([ 6.,  8., 10.])
+    >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix
+    >>> # with the same arbitrary spacing
+    >>> y = torch.ones(3, 3)
+    >>> x = torch.tensor([1, 3, 6])
+    >>> torch.trapezoid(y, x)
+    tensor([5., 5., 5.])
+    >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix
+    >>> # with different arbitrary spacing per row
+    >>> y = torch.ones(3, 3)
+    >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]])
+    >>> torch.trapezoid(y, x)
+    tensor([2., 4., 6.])
+""")
 
+add_docstr(torch.trapz,
+           r"""
+trapz(y, x, *, dim=-1) -> Tensor
+Alias for :func:`torch.trapezoid`.
-Returns:
-    A Tensor with the same shape as the input, except with `dim` removed.
-    Each element of the returned tensor represents the estimated integral
-    :math:`\int y\,dx` along `dim`.
 """)
 
 add_docstr(torch.repeat_interleave,
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -944,6 +944,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.trace: lambda input: -1,
     torch.transpose: lambda input, dim0, dim1: -1,
     torch.trapz: lambda y, x=None, dim=-1: -1,
+    torch.trapezoid: lambda y, x=None, dim=-1: -1,
     torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
     torch.tril: lambda input, diagonal=0, out=None: -1,
     torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
34 changes: 34 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -2477,6 +2477,32 @@ def generator():
 
     return list(generator())
 
+def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
+    y_shape_x_shape_and_kwargs = [
+        ((2, 3), (2, 3), {}),
+        ((2, 3), (2, 3), {'dim': 1}),
+        ((6,), (6,), {}),
+        ((6,), None, {}),
+        # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad
+        # See Issue #61619
+        # ((6, 0), (6, 0), {}),
+        ((2, 3), (1, 3), {}),
+        ((3, 3), (3, 3), {}),
+        ((3, 3), (3, 3), {'dim': -2}),
+        ((5,), None, {'dx': 2.0}),
+        ((2, 2), None, {'dx': 3.0})
+    ]
+    samples = []
+    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
+        y_tensor = make_tensor(y_shape, device, dtype, low=None, high=None,
+                               requires_grad=requires_grad)
+        if x_shape is not None:
+            x_tensor = make_tensor(x_shape, device, dtype, low=None, high=None,
+                                   requires_grad=requires_grad)
+            samples.append(SampleInput(y_tensor, args=(x_tensor,), kwargs=kwarg))
+        else:
+            samples.append(SampleInput(y_tensor, kwargs=kwarg))
+    return samples
 
 def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
     shapes_and_axes = [
@@ -7427,6 +7453,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            sample_inputs_func=sample_repeat_tile),
+    OpInfo('trapz',  # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid'
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           sample_inputs_func=sample_trapezoid),
+    OpInfo('trapezoid',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           sample_inputs_func=sample_trapezoid),
     OpInfo('unsqueeze',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
