1
1
import torch .testing ._internal .common_utils as common
2
+ from torch .testing import make_tensor
2
3
from torch .testing ._internal .common_device_type import (
3
4
instantiate_device_type_tests ,
4
5
dtypes
@@ -23,7 +24,7 @@ def _run_test(self, shape, dtype, count=-1, first=0, offset=None, **kwargs):
23
24
if offset is None :
24
25
offset = first * get_dtype_size (dtype )
25
26
26
- numpy_original = common . make_tensor (shape , torch .device ("cpu" ), dtype ).numpy ()
27
+ numpy_original = make_tensor (shape , torch .device ("cpu" ), dtype ).numpy ()
27
28
original = memoryview (numpy_original )
28
29
# First call PyTorch's version in case of errors.
29
30
# If this call exits successfully, the NumPy version must also do so.
@@ -125,7 +126,7 @@ def test_invalid_positional_args(self, device, dtype):
125
126
126
127
@dtypes (* common .torch_to_numpy_dtype_dict .keys ())
127
128
def test_shared_buffer (self , device , dtype ):
128
- x = common . make_tensor ((1 ,), device , dtype )
129
+ x = make_tensor ((1 ,), device , dtype )
129
130
# Modify the whole tensor
130
131
arr , tensor = self ._run_test (SHAPE , dtype )
131
132
tensor [:] = x
@@ -158,7 +159,7 @@ def test_not_a_buffer(self, device, dtype):
158
159
159
160
@dtypes (* common .torch_to_numpy_dtype_dict .keys ())
160
161
def test_non_writable_buffer (self , device , dtype ):
161
- numpy_arr = common . make_tensor ((1 ,), device , dtype ).numpy ()
162
+ numpy_arr = make_tensor ((1 ,), device , dtype ).numpy ()
162
163
byte_arr = numpy_arr .tobytes ()
163
164
with self .assertWarnsOnceRegex (UserWarning ,
164
165
r"The given buffer is not writable." ):
0 commit comments