Skip to content

Commit

Permalink
Add OpInfo for torch.mean (pytorch#55525)
Browse files · Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#55525

Reviewed By: agolynski

Differential Revision: D27796651

Pulled By: heitorschueroff

fbshipit-source-id: 6473d854f090ff62c856b404870f226f46569449
  • Loading branch information
heitorschueroff authored and facebook-github-bot committed Apr 16, 2021
1 parent 119b3ec commit 48c6f0c
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3751,6 +3751,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
dtypesIfCPU=all_types_and(torch.float16, torch.bool),
supports_out=False,
sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True)),
# TODO(@heitorschueroff) Add test for dtype kwarg
OpInfo('mean',
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
assert_autodiffed=True,
sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True),
                # Need to skip out test because one of the overloads for mean does not support it
# TODO(@heitorschueroff) fix this when implementing ReductionInfo
skips=(SkipInfo('TestCommon', 'test_out'),)),
OpInfo('maximum',
op=torch.maximum,
dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
Expand Down Expand Up @@ -4839,13 +4847,6 @@ def method_tests():
('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all', (True,)),
('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs', (True,)),
('lerp', (S, 1, S), ((S, S), (S, 1, 1, S)), 'tensor_broadcast_all', (True,)),
('mean', (S, S, S), NO_ARGS, '', (True,)),
('mean', (S, S, S), (1,), 'dim', (True,), [0]),
('mean', (S, S, S), (1, True,), 'keepdim_dim', (True,), [0]),
('mean', (), NO_ARGS, 'scalar', (True,)),
('mean', (), (0,), 'scalar_dim', (True,), [0]),
('mean', (), (0, True,), 'scalar_keepdim_dim', (True,), [0]),
('mean', (S, S, S), (), 'dtype', (True,), (), (), ident, {'dtype': torch.float64}),
('kthvalue', (S, S, S), (2,)),
('kthvalue', (S, S, S), (2, 1,), 'dim', (), [1]),
('kthvalue', (S, S, S), (2, 1,), 'dim_alert_nondeterministic', (), [1],
Expand Down

0 comments on commit 48c6f0c

Please sign in to comment.