[pt2] add metas for multilabel_margin_loss ops (pytorch#104388)
Pull Request resolved: pytorch#104388
Approved by: https://github.com/ezyang
nkaretnikov authored and pytorchmergebot committed Jul 5, 2023
1 parent a3aa4da commit c00dd43
Showing 4 changed files with 72 additions and 6 deletions.
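For context: a meta function computes an op's output shapes and dtypes without touching real data. The PT2 stack (FakeTensor, AOTAutograd, make_fx with symbolic shapes) needs one for every operator it traces, which is why the three test files below can drop their expected failures once torch/_meta_registrations.py gains kernels for multilabel_margin_loss_forward/backward. A minimal sketch of what the new registrations enable, assuming a build that includes this commit (shapes and values are illustrative):

import torch

# multilabel_margin_loss_forward returns (output, is_target); on the
# meta device only shapes/dtypes are computed, no data is allocated.
x = torch.randn(3, 4, device="meta")
t = torch.zeros(3, 4, dtype=torch.int64, device="meta")

# reduction=0 (Reduction.NONE): per-sample losses of shape (nframe,).
out, is_target = torch.ops.aten.multilabel_margin_loss_forward(x, t, 0)
print(out.shape, is_target.shape)  # torch.Size([3]) torch.Size([3, 4])

# reduction=1 (mean): the output collapses to a 0-dim scalar.
out, _ = torch.ops.aten.multilabel_margin_loss_forward(x, t, 1)
print(out.shape)  # torch.Size([])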
test/functorch/test_aotdispatch.py: 1 change (0 additions, 1 deletion)
@@ -2833,7 +2833,6 @@ def forward(self, x):
     xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
-    xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
     xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
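With the meta kernels in place, AOTAutograd can trace both the forward and backward of this loss through FakeTensors, so the xfail above is no longer needed. A rough sketch of the kind of program the test exercises (aot_function and nop are from functorch.compile; this is an illustration, not the test itself):

import torch
import torch.nn.functional as F
from functorch.compile import aot_function, nop

def f(x, t):
    return F.multilabel_margin_loss(x, t)

# AOTAutograd traces f (forward and backward) with FakeTensors, which
# dispatch to the meta kernels added by this commit; nop leaves the
# captured graphs uncompiled.
fn = aot_function(f, nop)
x = torch.randn(3, 4, requires_grad=True)
t = torch.randint(0, 4, (3, 4))
fn(x, t).backward()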
test/test_meta.py: 4 changes (0 additions, 4 deletions)
@@ -631,7 +631,6 @@ def run_meta_crossref(
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
     torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
-    torch.nn.functional.multilabel_margin_loss : {f64, f32},
     torch.nn.functional.one_hot : {i64},
     torch._segment_reduce : {f64, f16, bf16, f32},
     torch.cholesky : {f64, f32, c128, c64},
@@ -719,7 +718,6 @@ def run_meta_crossref(
     torch.histc: {i16, i32, i64, i8}, # aten::histc, aten::histc.out
     torch.kthvalue: {f16}, # aten::kthvalue.values
     torch.median: {f16}, # aten::median, aten::median.dim_values
-    torch.nn.functional.multilabel_margin_loss: {bf16, f16}, # aten::multilabel_margin_loss_forward
     torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out
 }

@@ -844,7 +842,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.median.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.median.dim : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8},
-    aten.multilabel_margin_loss_forward.default : {f32, f64},
     aten.nll_loss2d_forward.default : {bf16, f32, f64},
     aten.rrelu_with_noise.default : {bf16, f32, f64},
     aten.segment_reduce.default : {bf16, f32, f16, f64},
@@ -901,7 +898,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
     aten.median.default: {f16}, # aten::median
     aten.median.dim: {f16}, # aten::median.dim_values
-    aten.multilabel_margin_loss_forward.default: {bf16, f16}, # aten::multilabel_margin_loss_forward
     aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward
     aten.ormqr.default: {f32, f64}, # aten::ormqr
     aten.ormqr.out: {f32, f64}, # aten::ormqr.out
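The deleted entries above sat in expected-failure sets for the meta crossref tests; with the registrations below, the crossref machinery can compare meta results against eager for this op. Roughly what it verifies, as a hand-rolled sketch:

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)
t = torch.randint(0, 4, (3, 4))

eager = F.multilabel_margin_loss(x, t, reduction="none")
meta = F.multilabel_margin_loss(x.to("meta"), t.to("meta"), reduction="none")

# The meta kernel must agree with eager on shape and dtype.
assert meta.shape == eager.shape and meta.dtype == eager.dtype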
test/test_proxy_tensor.py: 1 change (0 additions, 1 deletion)
@@ -1540,7 +1540,6 @@ def f(a, b, c, d, e):
     xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
     xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
-    xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
     xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
     xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
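test/test_proxy_tensor.py drives ops through make_fx with symbolic shapes, where every traced op needs a symbolic meta function; the xfail above goes away for the same reason. A minimal sketch, assuming this build:

import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, t):
    return F.multilabel_margin_loss(x, t)

# tracing_mode="symbolic" gives the inputs symbolic sizes, so the graph
# can only be built if every op it hits has a symbolic meta function.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4), torch.randint(0, 4, (3, 4)))
print(gm.graph)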
torch/_meta_registrations.py: 72 changes (72 additions, 0 deletions)
@@ -377,6 +377,78 @@ def meta_multi_margin_loss_backward(
     return input.new_empty(input.shape)
 
 
+def _multilabel_margin_loss_shape_check(ndims, target_arg, input, target):
+    valid_inputs = (
+        (ndims == 2 and input.size(1) != 0)
+        or (ndims == 1 and input.size(0) != 0)
+        or ndims == 0
+    )
+    torch._check(
+        valid_inputs,
+        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
+    )
+    if ndims <= 1:
+        nframe = 1
+        dim = 1 if ndims == 0 else input.size(0)
+        torch._check(
+            valid_inputs and target.ndim <= 1 and target.numel() == dim,
+            lambda: f"inconsistent size {target.shape} for {target_arg}",
+        )
+    else:
+        nframe = input.size(0)
+        dim = input.size(1)
+        torch._check(
+            valid_inputs
+            and target.ndim == 2
+            and target.size(0) == nframe
+            and target.size(1) == dim,
+            lambda: f"inconsistent size {target.shape} for {target_arg}",
+        )
+    return nframe, dim
+
+
+@register_meta(aten.multilabel_margin_loss_forward)
+@out_wrapper("output", "is_target")
+def meta_multilabel_margin_loss_forward(
+    input: Tensor,
+    target: Tensor,
+    reduction: int,
+) -> Tuple[Tensor, Tensor]:
+    target_arg = "argument #2 'target'"
+    ndims = input.ndim
+    nframe, _ = _multilabel_margin_loss_shape_check(ndims, target_arg, input, target)
+    if reduction != Reduction.NONE.value or target.ndim <= 1:
+        output = input.new_empty(())
+    else:
+        output = input.new_empty(nframe)
+    is_target = input.new_empty(target.shape)
+    return output, is_target
+
+
+@register_meta(aten.multilabel_margin_loss_backward)
+@out_wrapper()
+def meta_multilabel_margin_loss_backward(
+    grad_output: Tensor,
+    input: Tensor,
+    target: Tensor,
+    reduction: int,
+    is_target: Tensor,
+) -> Tensor:
+    target_arg = "argument #3 'target'"
+    is_target_arg = "argument #5 'is_target'"
+    ndims = input.ndim
+    _multilabel_margin_loss_shape_check(ndims, target_arg, input, target)
+    torch._check(
+        target.shape == is_target.shape,
+        lambda: (
+            f"Expected tensor for {target_arg} to have same size as tensor for {is_target_arg}"
+            f"; but {target.shape} does not equal {is_target.shape}"
+            f" (while checking arguments for multilabel_margin_loss_backward)"
+        ),
+    )
+    return input.new_empty(input.shape)
+
+
 @register_meta([aten.max.default, aten.max.unary_out])
 @out_wrapper()
 def meta_max(self):
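The shared shape check mirrors the eager kernel's validation, so malformed inputs fail at trace time with the same message they would produce at runtime. A small sketch of the error path (again assuming this build):

import torch

x = torch.randn(3, 4, device="meta")
bad_t = torch.zeros(3, 5, dtype=torch.int64, device="meta")  # size(1) mismatch: 5 != 4

try:
    torch.ops.aten.multilabel_margin_loss_forward(x, bad_t, 0)
except RuntimeError as e:
    print(e)  # inconsistent size torch.Size([3, 5]) for argument #2 'target'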
