
Commit 26b7ff5

pmeier authored and facebook-github-bot committed
deprecate dtype getters from torch.testing namespace (pytorch#63554)
Summary: Pull Request resolved: pytorch#63554

Following pytorch#61840 (comment), this deprecates all the dtype getters publicly exposed in the `torch.testing` namespace. The reason for this is twofold:

1. If someone is not familiar with the C++ dispatch macros PyTorch uses, the names are misleading. For example, `torch.testing.floating_types()` will only give you `float32` and `float64`, skipping `float16` and `bfloat16`.
2. The dtype getters provide very minimal functionality that can be easily emulated by downstream libraries.

We thought about [providing a replacement](https://gist.github.com/pmeier/3dfd2e105842ad0de4505068a1a0270a), but ultimately decided against it. The major problem is BC: by keeping it, either the namespace gets messy again after a new dtype is added, or we need to somehow version the return values of the getters.

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D30662206

Pulled By: mruberry

fbshipit-source-id: a2bdb10ab02ae665df1b5b76e8afa9af043bbf56
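To make reason 1 concrete: the getter names follow PyTorch's C++ dispatch macros (e.g. `AT_DISPATCH_FLOATING_TYPES`, which covers only `float` and `double`), which is why `torch.testing.floating_types()` returns only `float32` and `float64`. The sketch below is not part of this commit; the helper name and keyword arguments are hypothetical, and it only illustrates how little code a downstream library needs to emulate such a getter itself:

    import torch

    # Hypothetical downstream replacement for the deprecated getter;
    # the keyword arguments are illustrative, not the deprecated signature.
    def my_floating_types(include_half=False, include_bfloat16=False):
        # Matches the C++ macro naming: only float32/float64 by default.
        dtypes = [torch.float32, torch.float64]
        if include_half:
            dtypes.append(torch.float16)
        if include_bfloat16:
            dtypes.append(torch.bfloat16)
        return tuple(dtypes)

    assert torch.float16 not in my_floating_types()

Keeping such tuples downstream also sidesteps the BC problem mentioned above: each project pins exactly the dtypes it tests, so adding a new dtype to PyTorch cannot silently change a test matrix.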
1 parent f767cf6 commit 26b7ff5

28 files changed: +560 -488 lines
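Every file in this diff applies the same mechanical migration: the getters are no longer reached through the public `torch.testing` namespace but are imported from the internal `torch.testing._internal.common_dtype` module. A minimal before/after sketch of the pattern, using one of the getters that appears in the diffs below:

    # Before (deprecated): getter accessed via the public namespace
    import torch
    complex_dtypes = torch.testing.get_all_complex_dtypes()

    # After: the same getter, imported from the internal helper module
    from torch.testing._internal.common_dtype import get_all_complex_dtypes
    complex_dtypes = get_all_complex_dtypes()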

test/test_autograd.py

+2 -1

@@ -42,6 +42,7 @@
     onlyCPU, onlyCUDA, onlyOnCPUAndCUDA, dtypes, dtypesIfCUDA,
     deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan,
     skipCUDAIf, skipMeta)
+from torch.testing._internal.common_dtype import get_all_dtypes
 
 import pickle
 
@@ -8474,7 +8475,7 @@ def test_copy_(self, device):
         # At the time of writing this test, copy_ is not generated from native_functions.yaml
         # there was a bug that bfloat16 was not recognized as floating.
         x = torch.randn(10, device=device, requires_grad=True)
-        floating_dt = [dt for dt in torch.testing.get_all_dtypes() if dt.is_floating_point]
+        floating_dt = [dt for dt in get_all_dtypes() if dt.is_floating_point]
         for dt in floating_dt:
             y = torch.empty(10, device=device, dtype=dt)
             y.copy_(x)

test/test_binary_ufuncs.py

+47 -43 (large diffs are not rendered by default)

test/test_complex.py

+2 -1

@@ -1,11 +1,12 @@
 import torch
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing._internal.common_dtype import get_all_complex_dtypes
 
 devices = (torch.device('cpu'), torch.device('cuda:0'))
 
 class TestComplexTensor(TestCase):
-    @dtypes(*torch.testing.get_all_complex_dtypes())
+    @dtypes(*get_all_complex_dtypes())
     def test_to_list(self, device, dtype):
         # test that the complex float tensor has expected values and
         # there's no garbage value in the resultant list

test/test_foreach.py

+24 -21

@@ -11,6 +11,9 @@
     (instantiate_device_type_tests, dtypes, onlyCUDA, skipCUDAIfRocm, skipMeta, ops)
 from torch.testing._internal.common_methods_invocations import \
     (foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db, foreach_minmax_op_db)
+from torch.testing._internal.common_dtype import (
+    get_all_dtypes, get_all_int_dtypes, get_all_complex_dtypes, get_all_fp_dtypes,
+)
 
 # Includes some values such that N * N won't be a multiple of 4,
 # which should ensure we test the vectorized and non-vectorized
@@ -133,7 +136,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis
         self._binary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath, is_inplace=True)
         if opinfo.supports_alpha_param:
             alpha = None
-            if dtype in torch.testing.get_all_int_dtypes():
+            if dtype in get_all_int_dtypes():
                 alpha = 3
             elif dtype.is_complex:
                 alpha = complex(3, 3)
@@ -170,7 +173,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis
     @ops(foreach_binary_op_db)
     def test_binary_op_tensorlists_fastpath(self, device, dtype, op):
         for N in N_values:
-            disable_fastpath = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
             if op.ref == torch.add and dtype == torch.bool:
                 disable_fastpath = True
             self._test_binary_op_tensorlists(device, dtype, op, N, True, disable_fastpath)
@@ -192,17 +195,17 @@ def _test_binary_op_scalar(self, device, dtype, opinfo, N, scalar, is_fastpath,
     @ops(foreach_binary_op_db)
     def test_binary_op_scalar_fastpath(self, device, dtype, op):
         for N, scalar in itertools.product(N_values, Scalars):
-            disable_fastpath = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+            disable_fastpath = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
             if isinstance(scalar, int):
                 disable_fastpath |= dtype == torch.bool
             if isinstance(scalar, float):
-                disable_fastpath |= dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+                disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
             if isinstance(scalar, bool):
                 disable_fastpath |= dtype == torch.bool
                 if op.ref in (torch.add, torch.mul):
                     disable_fastpath = False
             if isinstance(scalar, complex):
-                disable_fastpath |= dtype not in torch.testing.get_all_complex_dtypes()
+                disable_fastpath |= dtype not in get_all_complex_dtypes()
             self._test_binary_op_scalar(device, dtype, op, N, scalar, True, disable_fastpath)
 
     @ops(foreach_binary_op_db)
@@ -232,16 +235,16 @@ def _test_binary_op_scalarlist(self, device, dtype, opinfo, N, scalarlist, is_fa
     def test_binary_op_scalarlist_fastpath(self, device, dtype, op):
         for N in N_values:
             for type_str, scalarlist in getScalarLists(N):
-                bool_int_div = op.ref == torch.div and dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+                bool_int_div = op.ref == torch.div and dtype in get_all_int_dtypes() + [torch.bool]
                 disable_fastpath = bool_int_div
                 if type_str == "int":
                     disable_fastpath |= dtype == torch.bool
                 if type_str == "float":
-                    disable_fastpath |= dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+                    disable_fastpath |= dtype in get_all_int_dtypes() + [torch.bool]
                 if type_str == "complex":
-                    disable_fastpath |= dtype not in torch.testing.get_all_complex_dtypes()
+                    disable_fastpath |= dtype not in get_all_complex_dtypes()
                 if type_str == "mixed":
-                    disable_fastpath |= True and dtype not in torch.testing.get_all_complex_dtypes()
+                    disable_fastpath |= True and dtype not in get_all_complex_dtypes()
                 self._test_binary_op_scalarlist(device, dtype, op, N, scalarlist, True, disable_fastpath)
 
     @ops(foreach_binary_op_db)
@@ -298,7 +301,7 @@ def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fast
     @skipMeta
     @ops(foreach_pointwise_op_db)
     def test_pointwise_op_fastpath(self, device, dtype, op):
-        disable_fastpath = dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
+        disable_fastpath = dtype in get_all_int_dtypes() + [torch.bool]
         # for N, scalar in itertools.product(N_values, Scalars):
         for N in N_values:
             self._test_pointwise_op(device, dtype, op, N, True, disable_fastpath)
@@ -356,7 +359,7 @@ def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
         op, ref, inplace_op, inplace_ref = self._get_funcs(opinfo, 1)
         inputs = opinfo.sample_inputs(device, dtype, N, noncontiguous=not is_fastpath),
         # note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
-        if opinfo.name == "_foreach_abs" and dtype in torch.testing.get_all_complex_dtypes():
+        if opinfo.name == "_foreach_abs" and dtype in get_all_complex_dtypes():
             is_fastpath = False
         self._regular_unary_test(dtype, op, ref, inputs, is_fastpath)
         self._inplace_unary_test(dtype, inplace_op, inplace_ref, inputs, is_fastpath)
@@ -367,7 +370,7 @@ def test_unary_fastpath(self, device, dtype, op):
         for N in N_values:
             self._test_unary(device, dtype, op, N, is_fastpath=True)
 
-    @dtypes(*torch.testing.get_all_dtypes())
+    @dtypes(*get_all_dtypes())
     @ops(foreach_unary_op_db)
     def test_unary_slowpath(self, device, dtype, op):
         for N in N_values:
@@ -378,14 +381,14 @@ def _minmax_test(self, opinfo, inputs, is_fastpath, n_expected_cudaLaunchKernels
         self.assertEqual(ref(inputs), op(inputs, self.is_cuda, is_fastpath))
 
     # note(mkozuki): in-place of foreach_minimum and foreach_maximum aren't implemented.
-    # @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False))
+    # @dtypes(*get_all_dtypes(include_bfloat16=False, include_complex=False))
     @ops(foreach_minmax_op_db)
     def test_minmax_fastpath(self, device, dtype, op):
         for N in N_values:
             inputs = tuple(op.sample_inputs(device, dtype, N) for _ in range(2))
             self._minmax_test(op, inputs, True, N if dtype == torch.bool else 1)
 
-    @dtypes(*torch.testing.get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
+    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
     @ops(foreach_minmax_op_db)
     def test_minmax_slowpath(self, device, dtype, op):
         for N in N_values:
@@ -394,7 +397,7 @@ def test_minmax_slowpath(self, device, dtype, op):
 
     # note(mkozuki): ForeachFuncInfo's of both `_foreach_maximum` and `_foreach_minimum` include integer types.
     # so, manually limit dtypes to fp types for inf&nan tests.
-    @dtypes(*torch.testing.get_all_fp_dtypes(include_bfloat16=True, include_half=True))
+    @dtypes(*get_all_fp_dtypes(include_bfloat16=True, include_half=True))
     @ops(foreach_minmax_op_db)
     def test_minmax_float_inf_nan(self, device, dtype, op):
         inputs = (
@@ -413,7 +416,7 @@ def test_minmax_float_inf_nan(self, device, dtype, op):
         )
         self._minmax_test(op, inputs, True, 1)
 
-    @dtypes(*torch.testing.get_all_dtypes())
+    @dtypes(*get_all_dtypes())
     def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
         # TODO: enable empty list case
         for tensors in [[torch.randn([0])]]:
@@ -423,7 +426,7 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
             torch._foreach_add_(tensors, 1)
             self.assertEqual(res, tensors)
 
-    @dtypes(*torch.testing.get_all_dtypes())
+    @dtypes(*get_all_dtypes())
     @ops(foreach_binary_op_db)
     def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
         foreach_op, ref = op.method_variant, op.ref
@@ -457,7 +460,7 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
                 runtime_error = e
         self.assertIsNone(runtime_error)
 
-    @dtypes(*torch.testing.get_all_dtypes())
+    @dtypes(*get_all_dtypes())
     @ops(foreach_binary_op_db)
     def test_binary_op_list_error_cases(self, device, dtype, op):
         foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
@@ -513,7 +516,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
             return
         with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
             foreach_op([tensor1], [tensor2])
-        if dtype in torch.testing.get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div:
+        if dtype in get_all_int_dtypes() + [torch.bool] and foreach_op == torch._foreach_div:
             with self.assertRaisesRegex(RuntimeError, "result type"):
                 foreach_op_([tensor1], [tensor2])
         else:
@@ -522,7 +525,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
 
     @skipMeta
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
-    @dtypes(*torch.testing.get_all_dtypes())
+    @dtypes(*get_all_dtypes())
     @ops(foreach_binary_op_db)
     def test_binary_op_list_slow_path(self, device, dtype, op):
         # note(mkozuki): why `n_expected_cudaLaunchKernels=0`?
@@ -615,7 +618,7 @@ def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
         self.assertEqual(actual, tensors1)
 
     @onlyCUDA
-    @dtypes(*torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))
+    @dtypes(*get_all_fp_dtypes(include_half=False, include_bfloat16=False))
     @ops(foreach_pointwise_op_db)
     def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
         # tensors1: ['cuda', 'cpu]
