[OpInfo] add reference and error inputs for multilabel_margin_loss (pytorch#105523)

Pull Request resolved: pytorch#105523
Approved by: https://github.com/ezyang
nkaretnikov authored and pytorchmergebot committed Jul 23, 2023
1 parent bba06ad commit eac9e1b
Showing 3 changed files with 64 additions and 25 deletions.
1 change: 0 additions & 1 deletion test/test_nn.py
@@ -9231,7 +9231,6 @@ def v(fn):

        zeros = torch.zeros_like(input).to(torch.int64)
        v(lambda: F.multilabel_soft_margin_loss(input, zeros, reduction=reduction))
-        v(lambda: F.multilabel_margin_loss(input, zeros, reduction=reduction))

        v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
        v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
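For context on the smoke test removed above: multilabel_margin_loss treats each row of the target as a list of class indices terminated by the first -1, and sums hinge terms over (non-target, target) index pairs, normalized by the number of classes. A minimal worked example (values chosen here for illustration; not part of the commit):

import torch
import torch.nn.functional as F

x = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
# The row lists target class indices; entries after the first -1 are ignored,
# so the targets here are classes 3 and 0.
y = torch.tensor([[3, 0, -1, 1]])

# sum over target j and non-target i of max(0, 1 - (x[y[j]] - x[i])),
# divided by the number of classes:
# (0.4 + 0.6 + 1.1 + 1.3) / 4 = 0.85
loss = F.multilabel_margin_loss(x, y)  # tensor(0.8500), 'mean' reduction by default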
17 changes: 0 additions & 17 deletions test/test_torch.py
@@ -714,23 +714,6 @@ def test_scalar_check(self, device):
        self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='mean').shape)
        self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='sum').shape)

-        # multilabel_margin_loss
-        for input in (zero_d, one_d, torch.randn(1, 1, device=device)):
-            for target in (torch.tensor(0, device=device), torch.tensor([0], device=device), torch.tensor([[0]], device=device)):
-                if (input.dim() <= 1 and target.dim() <= 1) or (input.dim() == 2 and target.dim() == 2):
-                    output_shape = (target.shape[0],) if target.dim() == 2 else ()
-                    self.assertEqual(output_shape,
-                                     torch.nn.functional.multilabel_margin_loss(input, target, reduction='none').shape)
-                    self.assertEqual((), torch.nn.functional.multilabel_margin_loss(input, target, reduction='mean').shape)
-                    self.assertEqual((), torch.nn.functional.multilabel_margin_loss(input, target, reduction='sum').shape)
-                else:
-                    self.assertRaises(RuntimeError,
-                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='none'))
-                    self.assertRaises(RuntimeError,
-                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='mean'))
-                    self.assertRaises(RuntimeError,
-                                      lambda: torch.nn.functional.multilabel_margin_loss(input, target, reduction='sum'))
-
    # Test that `torch._check_tensor_all` raises errors in the correct cases
    def test_check_tensor_all(self, device):
        default_message = 'Expected cond to be True'
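The block removed above encoded the shape contract that the new OpInfo error inputs cover instead: input and target must both be 0-/1-dimensional, or both 2-dimensional. A rough sketch of the accepted combinations (shapes and values are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

x1 = torch.randn(4)                          # 1-dim input: 4 classes
t1 = torch.tensor([1, 3, -1, 0])             # 1-dim target; -1 terminates the list
print(F.multilabel_margin_loss(x1, t1, reduction='none').shape)  # torch.Size([])

x2 = torch.randn(2, 4)                       # 2-dim input: batch of 2
t2 = torch.tensor([[1, -1, 0, 0], [3, 2, -1, 0]])
print(F.multilabel_margin_loss(x2, t2, reduction='none').shape)  # torch.Size([2])

# Mixing a 2-dim input with a 0-/1-dim target (or vice versa) raises a RuntimeError.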
71 changes: 64 additions & 7 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1542,19 +1542,74 @@ def reference_inputs_like_fns(op, device, dtype, requires_grad, **kwargs):
        yield SampleInput(make_arg(shape, noncontiguous=True))
        yield SampleInput(make_arg(shape, noncontiguous=True).transpose(0, -1))

-# TODO: add reduction kwargs
def sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)

    inputs = (
-        ([], make_target([], low=0, high=1)),
-        ([S], make_target([S], low=0, high=S)),
-        ([M, S], make_target([M, S], low=0, high=S)),
+        ([], make_target([], low=0, high=1), {}),
+        ([S], make_target([S], low=0, high=S), {}),
+        ([M, S], make_target([M, S], low=0, high=S), {}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "none"}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "mean"}),
+        ([M, S], make_target([M, S], low=0, high=S), {"reduction": "sum"}),
    )

+    for shape, target, kwargs in inputs:
+        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
+
+
+def reference_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_multilabel_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
+    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
+    make_target_tensor = partial(torch.tensor, device=device, dtype=torch.long, requires_grad=False)
+
+    inputs = (
+        # random tests including -1 target labels
+        ([], make_target([], low=-1, high=1)),
+        ([S], make_target([S], low=-1, high=S)),
+        ([M, S], make_target([M, S], low=-1, high=S)),
+        # repeated target labels and -1 (labels after the first -1 are ignored)
+        ([], make_target_tensor(-1)),
+        ([7], make_target_tensor([2, 0, 6, -1, 4, -1, 6])),
+        ([4, 5], make_target_tensor([[4, -1, 0, -1, 2], [0, 0, 4, 1, 4], [-1, 3, -1, 1, 0], [4, 3, 2, 1, 0]])),
+    )
+    reductions = (None, "none", "mean", "sum")
+
+    for (shape, target), reduction in product(inputs, reductions):
+        kwargs = {}
+        if reduction is not None:
+            kwargs["reduction"] = reduction
+        yield SampleInput(_make_tensor(shape), args=(target,), kwargs=kwargs)
+
+
+def error_inputs_multilabel_margin_loss(op, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+    # invalid reduction
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5, 4),), kwargs={'reduction': 'abc'}),
+                     error_type=ValueError, error_regex='abc is not a valid value for reduction')
+    # invalid input
+    yield ErrorInput(SampleInput(make_input(5, 0), args=(make_input(5, 4),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[5, 0\]')
+    yield ErrorInput(SampleInput(make_input(0,), args=(make_input(0,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=r'Expected non-empty vector or matrix with optional 0-dim batch size, but got: \[0\]')
+    # invalid target
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(4,),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=(
+                         r'inconsistent target size: \[4\] for input of size: \[5, 4\]'
+                         if torch.device(device).type == 'cuda' else
+                         r'inconsistent size \[4\] for argument #2 \'target\''))
+    yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input((),),), kwargs={}),
+                     error_type=RuntimeError,
+                     error_regex=(
+                         r'inconsistent target size: \[\] for input of size: \[5, 4\]'
+                         if torch.device(device).type == 'cuda' else
+                         r'inconsistent size \[\] for argument #2 \'target\''))

-    for shape, target in inputs:
-        yield SampleInput(_make_tensor(shape), args=(target,))
-
def get_independent_tensor(tensor):
    return tensor.clone().requires_grad_(tensor.requires_grad)
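The comment above, "labels after the first -1 are ignored", is exactly the property the repeated-label reference inputs exercise. A quick illustrative check of that semantics (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(7)
full = torch.tensor([2, 0, 6, -1, 4, -1, 6])     # one of the reference targets
trunc = torch.tensor([2, 0, 6, -1, -1, -1, -1])  # everything after the first -1 blanked

# Both targets select the same class set {2, 0, 6}, so the losses should match.
assert torch.allclose(F.multilabel_margin_loss(x, full),
                      F.multilabel_margin_loss(x, trunc))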
@@ -12825,7 +12880,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
           dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16),
           supports_out=False,
           supports_gradgrad=False,
-           sample_inputs_func=sample_inputs_multilabel_margin_loss
+           sample_inputs_func=sample_inputs_multilabel_margin_loss,
+           reference_inputs_func=reference_inputs_multilabel_margin_loss,
+           error_inputs_func=error_inputs_multilabel_margin_loss,
           ),
    OpInfo('nn.functional.leaky_relu',
           aliases=None,
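With these functions registered on the OpInfo entry, the generic device test suites pick them up automatically: reference_inputs_func supplements sample_inputs_func in reference tests, and error_inputs_func drives the error checks. Roughly, the consumption pattern looks like this (a sketch following the torch/testing/_internal conventions, not code from this commit):

import re

def check_error_inputs(op, device):
    # op is an OpInfo; error_inputs() yields the ErrorInputs defined above.
    for ei in op.error_inputs(device):
        si = ei.sample_input
        try:
            op(si.input, *si.args, **si.kwargs)
        except ei.error_type as exc:
            assert re.search(ei.error_regex, str(exc)), str(exc)
        else:
            raise AssertionError('expected ' + ei.error_type.__name__)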