Skip to content

Commit

Permalink
[OpInfo] add reference and error inputs for multi_margin_loss (pyto…
Browse files Browse the repository at this point in the history
…rch#104850)

Pull Request resolved: pytorch#104850
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jul 14, 2023
1 parent d06e1df commit 0c89596
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 14 deletions.
6 changes: 0 additions & 6 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9308,7 +9308,6 @@ def v(fn):

v(lambda: F.nll_loss(input, target, reduction=reduction))
v(lambda: F.cross_entropy(input, target, reduction=reduction))
v(lambda: F.multi_margin_loss(input, target, reduction=reduction))

v(lambda: F.kl_div(input, input, reduction=reduction))
v(lambda: F.huber_loss(input, input, reduction=reduction))
Expand Down Expand Up @@ -11444,11 +11443,6 @@ def test_batchnorm_update_stats(self, device):
with torch.backends.cudnn.flags(enabled=False):
self._test_batchnorm_update_stats(device)

def test_multi_margin_loss_errors(self, device):
self.assertRaises(RuntimeError,
lambda: nn.functional.multi_margin_loss(torch.randn(5, device=device),
torch.zeros(3, device=device)))

@onlyCPU
@dtypes(torch.bfloat16, torch.float16)
def test_activations_bfloat16_half_cpu(self, device, dtype):
Expand Down
7 changes: 0 additions & 7 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,13 +732,6 @@ def test_scalar_check(self, device):
self.assertRaises(RuntimeError,
lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='sum'))

# multi_margin_loss
for input in (zero_d, one_d, torch.randn(1, 1, device=device)):
for target in (torch.tensor(0, device=device), torch.tensor([0], device=device)):
self.assertEqual(target.shape, torch.nn.functional.multi_margin_loss(input, target, reduction='none').shape)
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)

# Test that `torch._check_tensor_all` raises errors in the correct cases
def test_check_tensor_all(self, device):
default_message = 'Expected cond to be True'
Expand Down
57 changes: 56 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,6 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
for shape in cases:
yield SampleInput(make_arg(shape))

# TODO: add reduction kwargs
def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
Expand All @@ -1404,12 +1403,66 @@ def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwa
((S,), make_target([], low=0, high=S), {"p": 1}),
((S,), make_target([1], low=0, high=S), {"p": 2}),
((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
((M, S), make_target([M], low=0, high=S), {"weight": None}),
((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
)

for input_shape, target, kwargs in inputs:
yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)


def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1)),
((S,), make_target([], low=0, high=S)),
((S,), make_target([1], low=0, high=S)),
((M, S), make_target([M], low=0, high=S)),
)
ps = (1, 2)
margins = (0, 7, -3.14)
reductions = (None, "none", "mean", "sum")

for (input_shape, target), p, margin, reduction in product(inputs, ps, margins, reductions):
kwargs = {"p": p, "margin": margin}
if reduction is not None:
kwargs["reduction"] = reduction
yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)


def error_inputs_multi_margin_loss(op, device, **kwargs):
make_input = partial(make_tensor, device=device, dtype=torch.float32)
# invalid reduction
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'reduction': 'abc'}),
error_type=ValueError, error_regex='abc is not a valid value for reduction')
# invalid input
yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError,
error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
yield ErrorInput(SampleInput(make_input(0,), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError,
error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
# invalid target
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={}),
error_type=RuntimeError,
error_regex=(
r'inconsistent target size, expected 5 but got \[5, 4\]'
if torch.device(device).type == 'cuda' else
r'inconsistent target size, got: \[5, 4\]'))
# invalid target dtype
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
# invalid p
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
error_type=ValueError, error_regex='only p == 1 and p == 2 supported')


def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
inputs = (
((), (0,), True),
Expand Down Expand Up @@ -12744,6 +12797,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_out=False,
supports_gradgrad=False,
sample_inputs_func=sample_inputs_multi_margin_loss,
reference_inputs_func=reference_inputs_multi_margin_loss,
error_inputs_func=error_inputs_multi_margin_loss,
),
OpInfo(
"nn.functional.multilabel_margin_loss",
Expand Down

0 comments on commit 0c89596

Please sign in to comment.