diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 5bf6feedecc7f..6c4c21bd1aa8c 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -30,13 +30,32 @@ DEFINE_DISPATCH(mse_stub);
 DEFINE_DISPATCH(mse_backward_stub);
 
 Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
+  auto targ_dim = target.dim();
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      targ_dim == 1 || targ_dim == 0,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  if (targ_dim == 1) {
+    TORCH_CHECK(
+        input1.dim() == 2,
+        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  } else {
+    TORCH_CHECK(
+        input1.dim() == 1,
+        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  }
 
-  auto prod_sum = (input1 * input2).sum(1);
-  auto mag_square1 = (input1 * input1).sum(1) + EPSILON;
-  auto mag_square2 = (input2 * input2).sum(1) + EPSILON;
+  auto prod_sum = (input1 * input2).sum(targ_dim);
+  auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
+  auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
   auto denom = (mag_square1 * mag_square2).sqrt_();
   auto cos = prod_sum / denom;
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 1a6416d72e65a..f5b435f0d0d24 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9480,7 +9480,7 @@ def test_cosine_embedding_loss_margin_no_reduce(self):
                          loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
                                                                    margin=0.5, reduction='none'))
 
-    def test_cosine_embedding_loss_invalid_target_shape(self):
+    def test_cosine_embedding_loss_invalid_shape(self):
         input1 = torch.randn(15, 10)
         input2 = torch.randn(15, 10)
         target = torch.randn(15, 1).sign()
@@ -9488,6 +9488,12 @@ def test_cosine_embedding_loss_invalid_target_shape(self):
         with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
             F.cosine_embedding_loss(input1, input2, target)
 
+        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
+            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))
+
+        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
+            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))
+
     def test_margin_ranking_loss_no_reduce(self):
         input1 = torch.randn(15).mul_(10).requires_grad_()
         input2 = torch.randn(15).mul_(10).requires_grad_()
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index d72c614c88048..e0989e5b44f28 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1236,9 +1236,9 @@ class CosineEmbeddingLoss(_Loss):
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
 
     Shape:
-        - Input1: :math:`(N, D)`, where `N` is the batch size and `D` is the embedding dimension.
-        - Input2: :math:`(N, D)`, same shape as Input1.
-        - Target: :math:`(N)`.
+        - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension.
+        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1.
+        - Target: :math:`(N)` or :math:`()`.
         - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar.
     """
     __constants__ = ['margin', 'reduction']
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index b22b6ab1d2ec5..f3cdc478ac71d 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -5384,7 +5384,22 @@ def single_batch_reference_criterion_fn(*args):
     The output is squeezed to compare with the no-batch input.
     """
     criterion = args[-1]
-    single_batch_input_args = [input.unsqueeze(0) for input in args[:-1]]
+
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(xs):
+        result = []
+        if isinstance(xs, (list, tuple)):
+            for x in xs:
+                result.extend(flatten(x))
+        else:
+            result.append(xs)
+        return result
+
+    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
 
     output = criterion(*single_batch_input_args)
     reduction = get_reduction(criterion)
@@ -5421,6 +5436,7 @@ def single_batch_reference_criterion_fn(*args):
     ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
     ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
+    ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
 ]
 classification_criterion_no_batch_extra_info: Dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},