diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 6031f6e426682..a455c9c2be34e 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2833,7 +2833,6 @@ def forward(self, x):
     xfail('nn.functional.interpolate', 'area'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'linear'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'trilinear'),  # Cannot call sizes() on tensor with symbolic sizes/st...
-    xfail('nn.functional.multilabel_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.pixel_shuffle', ''),  # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta...
diff --git a/test/test_meta.py b/test/test_meta.py
index 62a2c854445ab..b0879b930b5cf 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -631,7 +631,6 @@ def run_meta_crossref(
     torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
     torch.nn.functional.ctc_loss : {f64, f32},
     torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
-    torch.nn.functional.multilabel_margin_loss : {f64, f32},
     torch.nn.functional.one_hot : {i64},
     torch._segment_reduce : {f64, f16, bf16, f32},
     torch.cholesky : {f64, f32, c128, c64},
@@ -719,7 +718,6 @@ def run_meta_crossref(
     torch.histc: {i16, i32, i64, i8},  # aten::histc, aten::histc.out
     torch.kthvalue: {f16},  # aten::kthvalue.values
     torch.median: {f16},  # aten::median, aten::median.dim_values
-    torch.nn.functional.multilabel_margin_loss: {bf16, f16},  # aten::multilabel_margin_loss_forward
     torch.ormqr: {f32, f64},  # aten::ormqr, aten::ormqr.out
 }
@@ -844,7 +842,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.median.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.median.dim : {i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8},
-    aten.multilabel_margin_loss_forward.default : {f32, f64},
     aten.nll_loss2d_forward.default : {bf16, f32, f64},
     aten.rrelu_with_noise.default : {bf16, f32, f64},
     aten.segment_reduce.default : {bf16, f32, f16, f64},
@@ -901,7 +898,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
     aten.log_sigmoid_forward.output : {bf16, f16, f64, f32},  # aten::log_sigmoid_forward.output
     aten.median.default: {f16},  # aten::median
     aten.median.dim: {f16},  # aten::median.dim_values
-    aten.multilabel_margin_loss_forward.default: {bf16, f16},  # aten::multilabel_margin_loss_forward
     aten.nll_loss2d_forward.default: {f16},  # aten::nll_loss2d_forward
     aten.ormqr.default: {f32, f64},  # aten::ormqr
     aten.ormqr.out: {f32, f64},  # aten::ormqr.out
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 482946aaf1cdf..d76ecd23b734b 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1540,7 +1540,6 @@ def f(a, b, c, d, e):
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
     xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
-    xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
    xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
    xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 6b0516c6650e3..59af8b9ce4ad5 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -377,6 +377,78 @@ def meta_multi_margin_loss_backward(
     return input.new_empty(input.shape)
 
 
+def _multilabel_margin_loss_shape_check(ndims, target_arg, input, target):
+    valid_inputs = (
+        (ndims == 2 and input.size(1) != 0)
+        or (ndims == 1 and input.size(0) != 0)
+        or ndims == 0
+    )
+    torch._check(
+        valid_inputs,
+        lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
+    )
+    if ndims <= 1:
+        nframe = 1
+        dim = 1 if ndims == 0 else input.size(0)
+        torch._check(
+            valid_inputs and target.ndim <= 1 and target.numel() == dim,
+            lambda: f"inconsistent size {target.shape} for {target_arg}",
+        )
+    else:
+        nframe = input.size(0)
+        dim = input.size(1)
+        torch._check(
+            valid_inputs
+            and target.ndim == 2
+            and target.size(0) == nframe
+            and target.size(1) == dim,
+            lambda: f"inconsistent size {target.shape} for {target_arg}",
+        )
+    return nframe, dim
+
+
+@register_meta(aten.multilabel_margin_loss_forward)
+@out_wrapper("output", "is_target")
+def meta_multilabel_margin_loss_forward(
+    input: Tensor,
+    target: Tensor,
+    reduction: int,
+) -> Tuple[Tensor, Tensor]:
+    target_arg = "argument #2 'target'"
+    ndims = input.ndim
+    nframe, _ = _multilabel_margin_loss_shape_check(ndims, target_arg, input, target)
+    if reduction != Reduction.NONE.value or target.ndim <= 1:
+        output = input.new_empty(())
+    else:
+        output = input.new_empty(nframe)
+    is_target = input.new_empty(target.shape)
+    return output, is_target
+
+
+@register_meta(aten.multilabel_margin_loss_backward)
+@out_wrapper()
+def meta_multilabel_margin_loss_backward(
+    grad_output: Tensor,
+    input: Tensor,
+    target: Tensor,
+    reduction: int,
+    is_target: Tensor,
+) -> Tensor:
+    target_arg = "argument #3 'target'"
+    is_target_arg = "argument #5 'is_target'"
+    ndims = input.ndim
+    _multilabel_margin_loss_shape_check(ndims, target_arg, input, target)
+    torch._check(
+        target.shape == is_target.shape,
+        lambda: (
+            f"Expected tensor for {target_arg} to have same size as tensor for {is_target_arg}"
+            f"; but {target.shape} does not equal {is_target.shape}"
+            f" (while checking arguments for multilabel_margin_loss_backward)"
+        ),
+    )
+    return input.new_empty(input.shape)
+
+
 @register_meta([aten.max.default, aten.max.unary_out])
 @out_wrapper()
 def meta_max(self):