[nn] no batch dim support: CosineEmbeddingLoss (pytorch#64590)
Summary:
Reference: pytorch#60585

TODO
* [x] Add tests

Pull Request resolved: pytorch#64590

Reviewed By: H-Huang

Differential Revision: D30900775

Pulled By: jbschlosser

fbshipit-source-id: d24e72787017e79afbf8f04a94901a290485b81a
kshitij12345 authored and facebook-github-bot committed Sep 13, 2021
1 parent 2ae938e commit 01e92f2
Showing 4 changed files with 51 additions and 10 deletions.
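For a quick sense of the user-facing change: after this commit, `F.cosine_embedding_loss` (and the `nn.CosineEmbeddingLoss` module) accepts unbatched 1D inputs with a 0D target, in addition to the existing 2D-input/1D-target form. A minimal sketch; the shapes mirror the tests in this commit and the specific sizes are only illustrative:

import torch
import torch.nn.functional as F

# Batched form (unchanged): 2D inputs, 1D target of +1/-1 labels.
x1 = torch.randn(15, 10)
x2 = torch.randn(15, 10)
y = torch.randn(15).sign()
batched_loss = F.cosine_embedding_loss(x1, x2, y)

# New unbatched form enabled by this commit: 1D inputs, 0D target.
a = torch.randn(10)
b = torch.randn(10)
t = torch.tensor(1.)
unbatched_loss = F.cosine_embedding_loss(a, b, t)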
29 changes: 24 additions & 5 deletions aten/src/ATen/native/Loss.cpp
@@ -30,13 +30,32 @@ DEFINE_DISPATCH(mse_stub);
 DEFINE_DISPATCH(mse_backward_stub);
 
 Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
+  auto targ_dim = target.dim();
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      targ_dim == 1 || targ_dim == 0,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  if (targ_dim == 1) {
+    TORCH_CHECK(
+        input1.dim() == 2,
+        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  } else {
+    TORCH_CHECK(
+        input1.dim() == 1,
+        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  }
 
-  auto prod_sum = (input1 * input2).sum(1);
-  auto mag_square1 = (input1 * input1).sum(1) + EPSILON;
-  auto mag_square2 = (input2 * input2).sum(1) + EPSILON;
+  auto prod_sum = (input1 * input2).sum(targ_dim);
+  auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
+  auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
   auto denom = (mag_square1 * mag_square2).sqrt_();
   auto cos = prod_sum / denom;
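The C++ change above reduces over `targ_dim` (1 for batched 2D inputs, 0 for unbatched 1D inputs) instead of the previously hard-coded dim 1. A rough Python equivalent of that cosine computation, for illustration only; `eps` stands in for the C++ `EPSILON` constant and its value here is a placeholder:

import torch

def cosine_sim_sketch(input1, input2, targ_dim, eps=1e-8):
    # Mirrors the updated C++: reduce over the target-dependent dim.
    prod_sum = (input1 * input2).sum(targ_dim)
    mag_square1 = (input1 * input1).sum(targ_dim) + eps
    mag_square2 = (input2 * input2).sum(targ_dim) + eps
    denom = (mag_square1 * mag_square2).sqrt()
    return prod_sum / denom

# Batched: 1D target, reduce over dim 1 -> per-sample similarity of shape (N,).
print(cosine_sim_sketch(torch.randn(4, 8), torch.randn(4, 8), targ_dim=1).shape)
# Unbatched: 0D target, reduce over dim 0 -> scalar similarity.
print(cosine_sim_sketch(torch.randn(8), torch.randn(8), targ_dim=0).shape)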
8 changes: 7 additions & 1 deletion test/test_nn.py
@@ -9480,14 +9480,20 @@ def test_cosine_embedding_loss_margin_no_reduce(self):
                              loss_reference_fns['CosineEmbeddingLoss'](input1, input2, target,
                                                                        margin=0.5, reduction='none'))
 
-    def test_cosine_embedding_loss_invalid_target_shape(self):
+    def test_cosine_embedding_loss_invalid_shape(self):
         input1 = torch.randn(15, 10)
         input2 = torch.randn(15, 10)
         target = torch.randn(15, 1).sign()
 
         with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
             F.cosine_embedding_loss(input1, input2, target)
 
+        with self.assertRaisesRegex(RuntimeError, "1D target tensor expects 2D input tensors"):
+            F.cosine_embedding_loss(torch.randn(10), torch.randn(10), torch.randn(10))
+
+        with self.assertRaisesRegex(RuntimeError, "0D target tensor expects 1D input tensors"):
+            F.cosine_embedding_loss(torch.randn(2, 5), torch.randn(2, 5), torch.randn(()))
+
     def test_margin_ranking_loss_no_reduce(self):
         input1 = torch.randn(15).mul_(10).requires_grad_()
         input2 = torch.randn(15).mul_(10).requires_grad_()
6 changes: 3 additions & 3 deletions torch/nn/modules/loss.py
@@ -1236,9 +1236,9 @@ class CosineEmbeddingLoss(_Loss):
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
 
     Shape:
-        - Input1: :math:`(N, D)`, where `N` is the batch size and `D` is the embedding dimension.
-        - Input2: :math:`(N, D)`, same shape as Input1.
-        - Target: :math:`(N)`.
+        - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension.
+        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1.
+        - Target: :math:`(N)` or :math:`()`.
         - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar.
     """
     __constants__ = ['margin', 'reduction']
18 changes: 17 additions & 1 deletion torch/testing/_internal/common_nn.py
@@ -5384,7 +5384,22 @@ def single_batch_reference_criterion_fn(*args):
     The output is squeezed to compare with the no-batch input.
     """
     criterion = args[-1]
-    single_batch_input_args = [input.unsqueeze(0) for input in args[:-1]]
+
+    def unsqueeze_inp(inp):
+        if isinstance(inp, (list, tuple)):
+            return [t.unsqueeze(0) for t in inp]
+        return inp.unsqueeze(0)
+
+    def flatten(xs):
+        result = []
+        if isinstance(xs, (list, tuple)):
+            for x in xs:
+                result.extend(flatten(x))
+        else:
+            result.append(xs)
+        return result
+
+    single_batch_input_args = flatten([unsqueeze_inp(input) for input in args[:-1]])
 
     output = criterion(*single_batch_input_args)
     reduction = get_reduction(criterion)
@@ -5421,6 +5436,7 @@ def single_batch_reference_criterion_fn(*args):
     ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
     ('SoftMarginLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('NLLLoss', lambda: F.log_softmax(torch.randn(3), dim=0), lambda: torch.tensor(1)),
+    ('CosineEmbeddingLoss', lambda: (torch.randn(9), torch.randn(9)), lambda: torch.tensor(1)),
 ]
 classification_criterion_no_batch_extra_info: Dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},
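The `unsqueeze_inp`/`flatten` helpers added above let criteria that take multiple inputs, such as the new `CosineEmbeddingLoss` no-batch entry whose input constructor returns a tuple of two tensors, be unsqueezed element-wise and then spliced back into positional arguments. A small standalone sketch of that behavior; the helpers are copied from the diff and the example arguments are illustrative:

import torch

def unsqueeze_inp(inp):
    if isinstance(inp, (list, tuple)):
        return [t.unsqueeze(0) for t in inp]
    return inp.unsqueeze(0)

def flatten(xs):
    result = []
    if isinstance(xs, (list, tuple)):
        for x in xs:
            result.extend(flatten(x))
    else:
        result.append(xs)
    return result

# CosineEmbeddingLoss no-batch sample: a tuple of two 1D inputs plus a 0D target.
args = [(torch.randn(9), torch.randn(9)), torch.tensor(1)]
single_batch = flatten([unsqueeze_inp(a) for a in args])
print([t.shape for t in single_batch])
# -> [torch.Size([1, 9]), torch.Size([1, 9]), torch.Size([1])]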
