11
11
(instantiate_device_type_tests , dtypes , onlyCUDA , skipCUDAIfRocm , skipMeta , ops )
12
12
from torch .testing ._internal .common_methods_invocations import \
13
13
(foreach_unary_op_db , foreach_binary_op_db , foreach_pointwise_op_db , foreach_minmax_op_db )
14
+ from torch .testing ._internal .common_dtype import (
15
+ get_all_dtypes , get_all_int_dtypes , get_all_complex_dtypes , get_all_fp_dtypes ,
16
+ )
14
17
15
18
# Includes some values such that N * N won't be a multiple of 4,
16
19
# which should ensure we test the vectorized and non-vectorized
@@ -133,7 +136,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis
133
136
self ._binary_test (dtype , inplace_op , inplace_ref , inputs , is_fastpath , is_inplace = True )
134
137
if opinfo .supports_alpha_param :
135
138
alpha = None
136
- if dtype in torch . testing . get_all_int_dtypes ():
139
+ if dtype in get_all_int_dtypes ():
137
140
alpha = 3
138
141
elif dtype .is_complex :
139
142
alpha = complex (3 , 3 )
@@ -170,7 +173,7 @@ def _test_binary_op_tensorlists(self, device, dtype, opinfo, N, is_fastpath, dis
170
173
@ops (foreach_binary_op_db )
171
174
def test_binary_op_tensorlists_fastpath (self , device , dtype , op ):
172
175
for N in N_values :
173
- disable_fastpath = op .ref == torch .div and dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
176
+ disable_fastpath = op .ref == torch .div and dtype in get_all_int_dtypes () + [torch .bool ]
174
177
if op .ref == torch .add and dtype == torch .bool :
175
178
disable_fastpath = True
176
179
self ._test_binary_op_tensorlists (device , dtype , op , N , True , disable_fastpath )
@@ -192,17 +195,17 @@ def _test_binary_op_scalar(self, device, dtype, opinfo, N, scalar, is_fastpath,
192
195
@ops (foreach_binary_op_db )
193
196
def test_binary_op_scalar_fastpath (self , device , dtype , op ):
194
197
for N , scalar in itertools .product (N_values , Scalars ):
195
- disable_fastpath = op .ref == torch .div and dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
198
+ disable_fastpath = op .ref == torch .div and dtype in get_all_int_dtypes () + [torch .bool ]
196
199
if isinstance (scalar , int ):
197
200
disable_fastpath |= dtype == torch .bool
198
201
if isinstance (scalar , float ):
199
- disable_fastpath |= dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
202
+ disable_fastpath |= dtype in get_all_int_dtypes () + [torch .bool ]
200
203
if isinstance (scalar , bool ):
201
204
disable_fastpath |= dtype == torch .bool
202
205
if op .ref in (torch .add , torch .mul ):
203
206
disable_fastpath = False
204
207
if isinstance (scalar , complex ):
205
- disable_fastpath |= dtype not in torch . testing . get_all_complex_dtypes ()
208
+ disable_fastpath |= dtype not in get_all_complex_dtypes ()
206
209
self ._test_binary_op_scalar (device , dtype , op , N , scalar , True , disable_fastpath )
207
210
208
211
@ops (foreach_binary_op_db )
@@ -232,16 +235,16 @@ def _test_binary_op_scalarlist(self, device, dtype, opinfo, N, scalarlist, is_fa
232
235
def test_binary_op_scalarlist_fastpath (self , device , dtype , op ):
233
236
for N in N_values :
234
237
for type_str , scalarlist in getScalarLists (N ):
235
- bool_int_div = op .ref == torch .div and dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
238
+ bool_int_div = op .ref == torch .div and dtype in get_all_int_dtypes () + [torch .bool ]
236
239
disable_fastpath = bool_int_div
237
240
if type_str == "int" :
238
241
disable_fastpath |= dtype == torch .bool
239
242
if type_str == "float" :
240
- disable_fastpath |= dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
243
+ disable_fastpath |= dtype in get_all_int_dtypes () + [torch .bool ]
241
244
if type_str == "complex" :
242
- disable_fastpath |= dtype not in torch . testing . get_all_complex_dtypes ()
245
+ disable_fastpath |= dtype not in get_all_complex_dtypes ()
243
246
if type_str == "mixed" :
244
- disable_fastpath |= True and dtype not in torch . testing . get_all_complex_dtypes ()
247
+ disable_fastpath |= True and dtype not in get_all_complex_dtypes ()
245
248
self ._test_binary_op_scalarlist (device , dtype , op , N , scalarlist , True , disable_fastpath )
246
249
247
250
@ops (foreach_binary_op_db )
@@ -298,7 +301,7 @@ def _test_pointwise_op(self, device, dtype, opinfo, N, is_fastpath, disable_fast
298
301
@skipMeta
299
302
@ops (foreach_pointwise_op_db )
300
303
def test_pointwise_op_fastpath (self , device , dtype , op ):
301
- disable_fastpath = dtype in torch . testing . get_all_int_dtypes () + [torch .bool ]
304
+ disable_fastpath = dtype in get_all_int_dtypes () + [torch .bool ]
302
305
# for N, scalar in itertools.product(N_values, Scalars):
303
306
for N in N_values :
304
307
self ._test_pointwise_op (device , dtype , op , N , True , disable_fastpath )
@@ -356,7 +359,7 @@ def _test_unary(self, device, dtype, opinfo, N, is_fastpath):
356
359
op , ref , inplace_op , inplace_ref = self ._get_funcs (opinfo , 1 )
357
360
inputs = opinfo .sample_inputs (device , dtype , N , noncontiguous = not is_fastpath ),
358
361
# note(mkozuki): Complex inputs for `_foreach_abs` go through slowpath.
359
- if opinfo .name == "_foreach_abs" and dtype in torch . testing . get_all_complex_dtypes ():
362
+ if opinfo .name == "_foreach_abs" and dtype in get_all_complex_dtypes ():
360
363
is_fastpath = False
361
364
self ._regular_unary_test (dtype , op , ref , inputs , is_fastpath )
362
365
self ._inplace_unary_test (dtype , inplace_op , inplace_ref , inputs , is_fastpath )
@@ -367,7 +370,7 @@ def test_unary_fastpath(self, device, dtype, op):
367
370
for N in N_values :
368
371
self ._test_unary (device , dtype , op , N , is_fastpath = True )
369
372
370
- @dtypes (* torch . testing . get_all_dtypes ())
373
+ @dtypes (* get_all_dtypes ())
371
374
@ops (foreach_unary_op_db )
372
375
def test_unary_slowpath (self , device , dtype , op ):
373
376
for N in N_values :
@@ -378,14 +381,14 @@ def _minmax_test(self, opinfo, inputs, is_fastpath, n_expected_cudaLaunchKernels
378
381
self .assertEqual (ref (inputs ), op (inputs , self .is_cuda , is_fastpath ))
379
382
380
383
# note(mkozuki): in-place of foreach_minimum and foreach_maximum aren't implemented.
381
- # @dtypes(*torch.testing. get_all_dtypes(include_bfloat16=False, include_complex=False))
384
+ # @dtypes(*get_all_dtypes(include_bfloat16=False, include_complex=False))
382
385
@ops (foreach_minmax_op_db )
383
386
def test_minmax_fastpath (self , device , dtype , op ):
384
387
for N in N_values :
385
388
inputs = tuple (op .sample_inputs (device , dtype , N ) for _ in range (2 ))
386
389
self ._minmax_test (op , inputs , True , N if dtype == torch .bool else 1 )
387
390
388
- @dtypes (* torch . testing . get_all_dtypes (include_half = True , include_bfloat16 = True , include_complex = False ))
391
+ @dtypes (* get_all_dtypes (include_half = True , include_bfloat16 = True , include_complex = False ))
389
392
@ops (foreach_minmax_op_db )
390
393
def test_minmax_slowpath (self , device , dtype , op ):
391
394
for N in N_values :
@@ -394,7 +397,7 @@ def test_minmax_slowpath(self, device, dtype, op):
394
397
395
398
# note(mkozuki): ForeachFuncInfo's of both `_foreach_maximum` and `_foreach_minimum` include integer types.
396
399
# so, manually limit dtypes to fp types for inf&nan tests.
397
- @dtypes (* torch . testing . get_all_fp_dtypes (include_bfloat16 = True , include_half = True ))
400
+ @dtypes (* get_all_fp_dtypes (include_bfloat16 = True , include_half = True ))
398
401
@ops (foreach_minmax_op_db )
399
402
def test_minmax_float_inf_nan (self , device , dtype , op ):
400
403
inputs = (
@@ -413,7 +416,7 @@ def test_minmax_float_inf_nan(self, device, dtype, op):
413
416
)
414
417
self ._minmax_test (op , inputs , True , 1 )
415
418
416
- @dtypes (* torch . testing . get_all_dtypes ())
419
+ @dtypes (* get_all_dtypes ())
417
420
def test_add_scalar_with_empty_list_and_empty_tensor (self , device , dtype ):
418
421
# TODO: enable empty list case
419
422
for tensors in [[torch .randn ([0 ])]]:
@@ -423,7 +426,7 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
423
426
torch ._foreach_add_ (tensors , 1 )
424
427
self .assertEqual (res , tensors )
425
428
426
- @dtypes (* torch . testing . get_all_dtypes ())
429
+ @dtypes (* get_all_dtypes ())
427
430
@ops (foreach_binary_op_db )
428
431
def test_binary_op_scalar_with_overlapping_tensors (self , device , dtype , op ):
429
432
foreach_op , ref = op .method_variant , op .ref
@@ -457,7 +460,7 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
457
460
runtime_error = e
458
461
self .assertIsNone (runtime_error )
459
462
460
- @dtypes (* torch . testing . get_all_dtypes ())
463
+ @dtypes (* get_all_dtypes ())
461
464
@ops (foreach_binary_op_db )
462
465
def test_binary_op_list_error_cases (self , device , dtype , op ):
463
466
foreach_op , foreach_op_ , ref , ref_ = op .method_variant , op .inplace_variant , op .ref , op .ref_inplace
@@ -513,7 +516,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
513
516
return
514
517
with self .assertRaisesRegex (RuntimeError , "Expected all tensors to be on the same device" ):
515
518
foreach_op ([tensor1 ], [tensor2 ])
516
- if dtype in torch . testing . get_all_int_dtypes () + [torch .bool ] and foreach_op == torch ._foreach_div :
519
+ if dtype in get_all_int_dtypes () + [torch .bool ] and foreach_op == torch ._foreach_div :
517
520
with self .assertRaisesRegex (RuntimeError , "result type" ):
518
521
foreach_op_ ([tensor1 ], [tensor2 ])
519
522
else :
@@ -522,7 +525,7 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
522
525
523
526
@skipMeta
524
527
@unittest .skipIf (not torch .cuda .is_available (), "CUDA not found" )
525
- @dtypes (* torch . testing . get_all_dtypes ())
528
+ @dtypes (* get_all_dtypes ())
526
529
@ops (foreach_binary_op_db )
527
530
def test_binary_op_list_slow_path (self , device , dtype , op ):
528
531
# note(mkozuki): why `n_expected_cudaLaunchKernels=0`?
@@ -615,7 +618,7 @@ def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
615
618
self .assertEqual (actual , tensors1 )
616
619
617
620
@onlyCUDA
618
- @dtypes (* torch . testing . get_all_fp_dtypes (include_half = False , include_bfloat16 = False ))
621
+ @dtypes (* get_all_fp_dtypes (include_half = False , include_bfloat16 = False ))
619
622
@ops (foreach_pointwise_op_db )
620
623
def test_pointwise_op_tensors_on_different_devices (self , device , dtype , op ):
621
624
# tensors1: ['cuda', 'cpu]
0 commit comments