Skip to content

Commit

Permalink
multi_margin_loss: check weight shape, make contiguous on CPU, add tests (pytorch#104852)
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#104852
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jul 14, 2023
1 parent de67b52 commit 0a68882
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 24 deletions.
9 changes: 8 additions & 1 deletion aten/src/ATen/native/LossMulti.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ namespace {
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
Expand All @@ -64,6 +65,12 @@ namespace {
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}


Expand Down
26 changes: 12 additions & 14 deletions aten/src/ATen/native/LossMultiMargin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ void multi_margin_loss_out_cpu_template(
const Tensor& target,
int p,
const Scalar& margin,
const Tensor& weight,
const c10::optional<Tensor>& weight,
int64_t reduction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
const auto ndims = input.dim();

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target.dim() > 0) {
Expand All @@ -125,13 +125,17 @@ void multi_margin_loss_out_cpu_template(

auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
Tensor weight_contiguous;
if (weight && weight->defined()) {
weight_contiguous = weight->contiguous();
}

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
auto input_data = input_contiguous.data_ptr<scalar_t>();
auto target_data = target_contiguous.data_ptr<int64_t>();
auto weight_data =
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
weight_contiguous.defined() ? weight_contiguous.data_ptr<scalar_t>() : nullptr;
multi_margin_loss_cpu_kernel<scalar_t>(
output,
input_data,
Expand Down Expand Up @@ -219,7 +223,7 @@ void multi_margin_loss_backward_out_cpu_template(

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);
grad_input.resize_as_(input);
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

Expand Down Expand Up @@ -262,12 +266,9 @@ Tensor multi_margin_loss_cpu(
const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

auto output = at::empty({0}, input.options());
multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
Expand All @@ -277,13 +278,10 @@ Tensor multi_margin_loss_cpu(
Tensor& multi_margin_loss_cpu_out(const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction,
Tensor& output) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
return output;
Expand Down
13 changes: 10 additions & 3 deletions aten/src/ATen/native/cuda/MultiMarginLoss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ void multi_margin_loss_shape_check(
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
Expand All @@ -150,6 +151,12 @@ void multi_margin_loss_shape_check(
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}

} // namespace (anonymous)
Expand All @@ -163,7 +170,7 @@ Tensor& multi_margin_loss_cuda_out(

TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target_.dim() > 0) {
Expand Down Expand Up @@ -318,7 +325,7 @@ Tensor& multi_margin_loss_cuda_backward_out(
TORCH_CHECK(p == 1 || p == 2,
"multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);
resize_output(grad_input_, input_.sizes());

if (input_.numel() == 0) {
Expand Down
11 changes: 8 additions & 3 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def meta_index_select_out(self, dim, index, out):
return out.copy_(torch.index_select(self, dim, index))


def _multi_margin_loss_shape_check(ndims, input, target):
def _multi_margin_loss_shape_check(ndims, input, target, weight):
torch._check(
(ndims == 2 and input.size(1) != 0)
or (ndims == 1 and input.size(0) != 0)
Expand All @@ -344,6 +344,11 @@ def _multi_margin_loss_shape_check(ndims, input, target):
target.dim() <= 1 and target.numel() == nframe,
lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
)
if weight is not None:
torch._check(
weight.ndim <= 1 and weight.numel() == dim,
lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",
)

return nframe, dim

Expand All @@ -360,7 +365,7 @@ def meta_multi_margin_loss(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target)
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target, weight)
if reduction == Reduction.NONE.value and target.ndim > 0:
return input.new_empty(nframe)
else:
Expand All @@ -380,7 +385,7 @@ def meta_multi_margin_loss_backward(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
_multi_margin_loss_shape_check(ndims, input, target)
_multi_margin_loss_shape_check(ndims, input, target, weight)
return input.new_empty(input.shape)


Expand Down
27 changes: 24 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1), {}),
Expand All @@ -1405,6 +1406,7 @@ def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwa
((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
((M, S), make_target([M], low=0, high=S), {"weight": None}),
((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
Expand All @@ -1418,6 +1420,7 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **
yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1)),
Expand All @@ -1427,13 +1430,17 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **
)
ps = (1, 2)
margins = (0, 7, -3.14)
weights = (False, True)
reductions = (None, "none", "mean", "sum")

for (input_shape, target), p, margin, reduction in product(inputs, ps, margins, reductions):
kwargs = {"p": p, "margin": margin}
for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
input = _make_tensor(input_shape)
weight_shape = [input.size(-1)] if input.ndim > 0 else [1]
weight = make_weight(weight_shape, low=-10., high=10.) if weight else None
kwargs = {"p": p, "margin": margin, "weight": weight}
if reduction is not None:
kwargs["reduction"] = reduction
yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
yield SampleInput(input, args=(target,), kwargs=kwargs)


def error_inputs_multi_margin_loss(op, device, **kwargs):
Expand All @@ -1454,6 +1461,13 @@ def error_inputs_multi_margin_loss(op, device, **kwargs):
# invalid target dtype
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
# invalid weight
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
error_type=ValueError, error_regex='weight must be one-dimensional')
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
error_type=ValueError, error_regex='weight must be one-dimensional')
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
# invalid p
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
error_type=ValueError, error_regex='only p == 1 and p == 2 supported')
Expand Down Expand Up @@ -12795,6 +12809,13 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
sample_inputs_func=sample_inputs_multi_margin_loss,
reference_inputs_func=reference_inputs_multi_margin_loss,
error_inputs_func=error_inputs_multi_margin_loss,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
"TestJit",
"test_variant_consistency_jit",
),
),
),
OpInfo(
"nn.functional.multilabel_margin_loss",
Expand Down

0 comments on commit 0a68882

Please sign in to comment.