Remove optional type for ord parameter in vector_norm (pytorch#57662)
Summary:
As per the discussion in pytorch#57127 (comment).

Note that we cannot remove the optional type from the `dim` parameter, because the default behavior is to flatten the input tensor, which cannot easily be expressed by any value other than `None` (see the sketch below).
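
A minimal sketch of that flattening behavior (assuming a PyTorch build where `torch.linalg.vector_norm` is available):

```python
import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# With dim=None (the default), the input is flattened before the norm is
# computed, so the result matches the norm of the tensor's 1-D view.
default = torch.linalg.vector_norm(x)                       # dim=None
flattened = torch.linalg.vector_norm(x.reshape(-1), dim=0)
assert torch.allclose(default, flattened)
```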

### BC Breaking Note
This PR changes the `ord` parameter of `torch.linalg.vector_norm` so that it no longer accepts `None`. The new default of `2` computes the same norm as the previous default of `None`.
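
A sketch of the before/after behavior (illustrative only; the exact exception raised for `ord=None` depends on the argument-parsing layer, so both likely types are caught here):

```python
import torch

x = torch.randn(3)

# The new default ord=2 computes the same norm as the old default ord=None.
print(torch.linalg.vector_norm(x))          # 2-norm of x
print(torch.linalg.vector_norm(x, ord=2))   # same value

# After this change, explicitly passing ord=None is rejected.
try:
    torch.linalg.vector_norm(x, ord=None)
except (TypeError, RuntimeError) as err:
    print("ord=None is no longer accepted:", err)
```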

Pull Request resolved: pytorch#57662

Reviewed By: albanD, mruberry

Differential Revision: D28228870

Pulled By: heitorschueroff

fbshipit-source-id: 040fd8055bbe013f64d3c8409bbb4b2c87c99d13
heitorschueroff authored and facebook-github-bot committed May 7, 2021
1 parent cb1272a commit 1f1e2da
Showing 11 changed files with 37 additions and 35 deletions.
15 changes: 8 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -2253,7 +2253,8 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, const op
// 'ord' is int or None
std::vector<int64_t> dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim);
if (!opt_num_ord.has_value() || dim_.size() == 1) {
Tensor result_ = at::linalg_vector_norm(self, opt_num_ord, opt_dim, keepdim, opt_dtype);
Tensor result_ = at::linalg_vector_norm(
self, opt_num_ord.value_or(2), opt_dim, keepdim, opt_dtype);
// TODO: Resize and copy should be avoided with
// https://github.com/pytorch/pytorch/issues/52712
at::native::resize_output(result, result_.sizes());
@@ -2268,11 +2269,11 @@ static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, const op
return result;
}

static Tensor& linalg_vector_norm_impl(const Tensor& self, const optional<Scalar>& opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
static Tensor& linalg_vector_norm_impl(const Tensor& self, const Scalar& scalar_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
// Casting a large integer to a double will introduce some error, but for
// practical purposes, it won't matter since a large order will usually
// give an infinite result
auto ord = opt_ord.value_or(2).toDouble();
auto ord = scalar_ord.toDouble();

TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
"linalg.vector_norm only supports CPU and CUDA device types, but got: ",
@@ -2343,14 +2344,14 @@ static Tensor& linalg_vector_norm_impl(const Tensor& self, const optional<Scalar
return result;
}

Tensor linalg_vector_norm(const Tensor& self, const optional<Scalar>& opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
Tensor linalg_vector_norm(const Tensor& self, const Scalar& ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
ScalarType out_dtype = opt_dtype.value_or(toValueType(self.scalar_type()));
Tensor result = create_reduction_result(self, opt_dim.value_or(IntArrayRef{}), keepdim, out_dtype);
return at::native::linalg_vector_norm_impl(self, opt_ord, opt_dim, keepdim, opt_dtype, result);
return at::native::linalg_vector_norm_impl(self, ord, opt_dim, keepdim, opt_dtype, result);
}

Tensor& linalg_vector_norm_out(const Tensor& self, const optional<Scalar>& opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
return at::native::linalg_vector_norm_impl(self, opt_ord, opt_dim, keepdim, opt_dtype, result);
Tensor& linalg_vector_norm_out(const Tensor& self, const Scalar& ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
return at::native::linalg_vector_norm_impl(self, ord, opt_dim, keepdim, opt_dtype, result);
}

// Numerical or None norms
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -9577,13 +9577,13 @@
python_module: linalg
variants: function

- func: linalg_vector_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
- func: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_vector_norm

- func: linalg_vector_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
dispatch:
CPU, CUDA: linalg_vector_norm_out
@@ -84,6 +84,7 @@
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
("aten::_amp_update_scale", datetime.date(2021, 6, 1)),
("aten::randperm", datetime.date(9999, 1, 1)),
("aten::linalg_vector_norm", datetime.date(2021, 5, 15)),
]

def allow_listed(schema, allow_list):
10 changes: 5 additions & 5 deletions test/test_linalg.py
@@ -1336,7 +1336,7 @@ def get_value_type(dtype):
def test_vector_norm(self, device, dtype):
# This test compares torch.linalg.vector_norm's output with
# torch.linalg.norm given a flattened tensor
ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf, None]
ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
input_sizes = [
(10, ),
(4, 5),
@@ -1362,7 +1362,7 @@ def run_test_case(input, ord, dim, keepdim, norm_dtype):
msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}'
error_msg = None
if input.numel() == 0:
if ord is not None and ord < 0:
if ord < 0:
error_msg = r'linalg.vector_norm of negative order cannot be performed on an empty tensor'
elif ord == inf and (dim is None or input.size(dim) == 0):
error_msg = (
@@ -1478,7 +1478,7 @@ def run_test_case(input, p, dim, keepdim):
torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
self.assertEqual(result, result_out, msg=msg)

ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
S = 10
test_cases = [
# input size, p settings, dim
@@ -1506,7 +1506,7 @@ def run_test_case(input, p, dim, keepdim):
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 2e-5})
def test_norm_matrix(self, device, dtype):
def run_test_case(input, p, dim, keepdim):
def run_test_case(input, ord, dim, keepdim):
result = torch.linalg.norm(input, ord, dim, keepdim)
input_numpy = input.cpu().numpy()
result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
@@ -1518,7 +1518,7 @@ def run_test_case(input, p, dim, keepdim):
torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
self.assertEqual(result, result_out, msg=msg)

ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro', None]
ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
S = 10
test_cases = [
# input size, p settings, dim
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -901,7 +901,7 @@
- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)

- name: linalg_vector_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim)

- name: _pdist_forward(Tensor self, float p=2) -> Tensor
16 changes: 8 additions & 8 deletions torch/csrc/api/include/torch/linalg.h
@@ -88,12 +88,12 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt
return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}

inline Tensor vector_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return torch::linalg_vector_norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
inline Tensor vector_norm(const Tensor& self, Scalar ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
}

inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return torch::linalg_vector_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return torch::linalg_vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}

inline Tensor matrix_power(const Tensor& self, int64_t n) {
@@ -315,12 +315,12 @@ inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, opt
}

/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.vector_norm
inline Tensor vector_norm(const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return detail::vector_norm(self, opt_ord, opt_dim, keepdim, opt_dtype);
inline Tensor vector_norm(const Tensor& self, Scalar ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype);
}

inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, optional<Scalar> opt_ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return detail::vector_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype);
inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, optional<IntArrayRef> opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
return detail::vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype);
}

/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power
4 changes: 2 additions & 2 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -217,10 +217,10 @@ Tensor norm_backward(Tensor grad, const Tensor& self, const optional<Scalar> & p
return self_scaled * scale_v;
}

Tensor linalg_vector_norm_backward(Tensor grad, const Tensor& self, const optional<Scalar>& opt_ord, Tensor norm, const optional<IntArrayRef>& opt_dim, bool keepdim) {
Tensor linalg_vector_norm_backward(Tensor grad, const Tensor& self, const Scalar& scalar_ord, Tensor norm, const optional<IntArrayRef>& opt_dim, bool keepdim) {
size_t ndim = self.sizes().size();
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto ord = opt_ord.value_or(2.0).toDouble();
auto ord = scalar_ord.toDouble();
auto dim = opt_dim.value_or(IntArrayRef({}));
Tensor self_scaled;
Tensor scale_v;
2 changes: 1 addition & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -50,7 +50,7 @@ Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim
Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims);
at::Tensor norm_backward(const at::Tensor & grad, const at::Tensor & self, const optional<at::Scalar> & p_, const at::Tensor & norm);
at::Tensor norm_backward(at::Tensor grad, const at::Tensor & self, const optional<at::Scalar> & p_, at::Tensor norm, at::IntArrayRef dim, bool keepdim);
at::Tensor linalg_vector_norm_backward(at::Tensor grad, const at::Tensor & self, const optional<at::Scalar> & opt_ord, at::Tensor norm, const c10::optional<at::IntArrayRef> & opt_dim, bool keepdim);
at::Tensor linalg_vector_norm_backward(at::Tensor grad, const at::Tensor & self, const at::Scalar & ord, at::Tensor norm, const c10::optional<at::IntArrayRef> & opt_dim, bool keepdim);
at::Tensor pow_backward(at::Tensor grad, const at::Tensor & self, const at::Scalar & exponent_);
at::Tensor pow_backward_self(at::Tensor grad, const at::Tensor & self, const at::Tensor & exponent);
at::Tensor pow_backward_exponent(at::Tensor grad, const at::Tensor& self, const at::Tensor& exponent, at::Tensor result);
6 changes: 3 additions & 3 deletions torch/linalg/__init__.py
@@ -948,7 +948,7 @@
""")

vector_norm = _add_docstr(_linalg.linalg_vector_norm, r"""
linalg.vector_norm(A, ord=None, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
linalg.vector_norm(A, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor
Computes a vector norm.
@@ -966,7 +966,7 @@
====================== ========================================================
:attr:`ord` vector norm
====================== ========================================================
`None` (default) `2`-norm
`2` (default) `2`-norm
`inf` `max(abs(x))`
`-inf` `min(abs(x))`
`0` `sum(x != 0)`
@@ -977,7 +977,7 @@
Args:
A (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions.
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None`
ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2`
dim (int, Tuple[int], optional): dimensions over which to compute
the norm. See above for the behavior when :attr:`dim`\ `= None`.
Default: `None`
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -729,7 +729,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.nonzero: lambda input, as_tuple=False: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
10 changes: 5 additions & 5 deletions torch/testing/_internal/common_methods_invocations.py
@@ -612,8 +612,8 @@ def sample_inputs_linalg_vector_norm(op_info, device, dtype, requires_grad, **kw

test_cases = [
# input size, ord, dim args
(size_1D, None, None),
(size_1D, None, (0,)),
(size_1D, 2, None),
(size_1D, 2, (0,)),
(size_1D, 0, None),
(size_1D, 0, (0,)),
(size_1D, 0.9, None),
@@ -627,9 +627,9 @@ def sample_inputs_linalg_vector_norm(op_info, device, dtype, requires_grad, **kw
(size_1D, -inf, None),
(size_1D, -inf, (0,)),

(size_2D, None, None),
(size_2D, None, (0,)),
(size_2D, None, (-1, 0)),
(size_2D, 2, None),
(size_2D, 2, (0,)),
(size_2D, 2, (-1, 0)),
(size_2D, 0, None),
(size_2D, 0, (0,)),
(size_2D, 0, (-1, 0)),
