Skip to content

Commit

Permalink
[primTorch] Add decomp for soft_margin_loss (pytorch#83804)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Aug 31, 2022
1 parent 305af90 commit 71ce9cd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
26 changes: 26 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,32 @@ def binary_cross_entropy_backward(
return result


@register_decomposition(aten.soft_margin_loss)
@out_wrapper()
@pw_cast_for_opmath
def soft_margin_loss(
input: Tensor,
target: Tensor,
reduction: int = Reduction.MEAN.value,
) -> Tensor:
loss = torch.log1p(torch.exp(-input * target))
return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.soft_margin_loss_backward)
@pw_cast_for_opmath
def soft_margin_loss_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
reduction: int = Reduction.MEAN.value,
) -> Tensor:
grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
if reduction == Reduction.MEAN.value:
grad_input = grad_input / self.numel()
return grad_input


@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
x1_norm = x1.pow(2).sum(-1, True)
Expand Down
15 changes: 15 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7007,6 +7007,20 @@ def gen_shape_kwargs():
for input, target, kwargs in gen_shape_kwargs():
yield SampleInput(input, args=(target, ), kwargs=kwargs)

def error_inputs_soft_margin_loss(op_info, device, **kwargs):
make = partial(make_tensor, device=device, dtype=torch.float32)

# invalid reduction value
yield ErrorInput(SampleInput(make(5, 4), args=(make(5, 4),),
kwargs={'reduction': 'abc'}),
error_type=ValueError,
error_regex='abc is not a valid value for reduction')
# invalid input shapes
yield ErrorInput(SampleInput(make(5, 4), args=(make(5,),)),
error_regex=(r'The size of tensor a \(4\) must match the '
r'size of tensor b \(5\) at non-singleton '
r'dimension 1'))

def sample_inputs_triplet_margin_loss(op_info, device, dtype, requires_grad, with_distance=False, **kwargs):
make = partial(make_tensor, (S, M), device=device, dtype=dtype, requires_grad=requires_grad)

Expand Down Expand Up @@ -10836,6 +10850,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_forward_ad=True,
# doesn't support grad on target
sample_inputs_func=partial(sample_inputs_loss, rhs_requires_grad=False),
error_inputs_func=error_inputs_soft_margin_loss,
),
OpInfo('nn.functional.upsample_nearest',
supports_autograd=True,
Expand Down

0 comments on commit 71ce9cd

Please sign in to comment.