Updated docs/test for dot and vdot (pytorch#47242)
Summary: Pull Request resolved: pytorch#47242

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D24733771

Pulled By: heitorschueroff

fbshipit-source-id: 92e3b0e28e0565918335fa85d52abe5db9eeff57
heitorschueroff authored and facebook-github-bot committed Nov 5, 2020
1 parent d8c3b2b commit a4ba018
Showing 5 changed files with 96 additions and 98 deletions.
71 changes: 70 additions & 1 deletion test/test_linalg.py
@@ -9,7 +9,7 @@
    (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN, make_tensor)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
     onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride)
from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
from torch.autograd import gradcheck

@@ -1139,6 +1139,75 @@ def test_tensorsolve_errors_and_warnings(self, device, dtype):
with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
torch.linalg.tensorsolve(a, b, out=out)

def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
def check(x, y):
# Compare with numpy
res = torch_fn(x, y)
ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
self.assertEqual(res.cpu(), ref)

# Test out variant
out = torch.empty_like(res)
torch_fn(x, y, out=out)
self.assertEqual(out, res)

# Empty
x = torch.tensor([], dtype=dtype, device=device)
y = torch.tensor([], dtype=dtype, device=device)
check(x, y)

# Contiguous
x = torch.randn(10, dtype=dtype, device=device)
y = torch.randn(10, dtype=dtype, device=device)
check(x, y)

# 0 strided
y = torch.randn(1, dtype=dtype, device=device).expand(10)
check(x, y)

# 2 strided
check(x[::2], y[::2])

@dtypes(torch.float, torch.cfloat)
@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
def test_dot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)

@dtypes(torch.float, torch.cfloat)
@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
def test_vdot_vs_numpy(self, device, dtype):
self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)

def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
def check(x, y, regex):
with self.assertRaisesRegex(RuntimeError, regex):
torch_fn(x, y)

if complex_dtypes:
x = torch.randn(1, dtype=torch.cfloat, device=device)
y = torch.randn(3, dtype=torch.cdouble, device=device)
else:
x = torch.randn(1, dtype=torch.float, device=device)
y = torch.randn(3, dtype=torch.double, device=device)

check(x, y, 'dot : expected both vectors to have same dtype')
check(x.reshape(1, 1), y, '1D tensors expected')
check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')

if self.device_type != 'cpu':
x_cpu = x.expand(3).cpu()
check(x_cpu, y.to(x.dtype), 'expected all tensors to be on the same device')

@onlyOnCPUAndCUDA
def test_vdot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.vdot)
self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)

@onlyOnCPUAndCUDA
def test_dot_invalid_args(self, device):
self._test_dot_vdot_invalid_args(device, torch.dot)
self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)

instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
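In plain terms, the new tests assert that `torch.dot` and `torch.vdot` match their NumPy counterparts, that the `out=` variant works, and that empty and strided inputs are handled. A minimal standalone sketch of those checks, outside the test harness (values and tolerances are illustrative; assumes only torch and numpy):

```python
import numpy as np
import torch

x = torch.randn(10, dtype=torch.cfloat)
y = torch.randn(10, dtype=torch.cfloat)

# vdot should agree with np.vdot (first argument conjugated).
res = torch.vdot(x, y)
ref = torch.from_numpy(np.array(np.vdot(x.numpy(), y.numpy())))
assert torch.allclose(res, ref, atol=1e-4)

# The out= variant writes into a preallocated 0-dim tensor.
out = torch.empty_like(res)
torch.vdot(x, y, out=out)
assert torch.equal(out, res)

# Strided (step-2) views are covered too.
res2 = torch.dot(x[::2], y[::2])
ref2 = torch.from_numpy(np.array(np.dot(x.numpy()[::2], y.numpy()[::2])))
assert torch.allclose(res2, ref2, atol=1e-4)
```

The `atol=1e-4` mirrors the `precisionOverride` used for `torch.cfloat` in the tests above.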
84 changes: 0 additions & 84 deletions test/test_torch.py
@@ -17082,90 +17082,6 @@ def test_matmul_45724(self, device):
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
            y_np = y.cpu().numpy()

            # `compare_with_numpy` takes care of moving `x` to correct device for calling np_fn.
            self.compare_with_numpy(lambda inp: torch_fn(inp, y), lambda inp: np_fn(inp, y_np), x)

        # Use this tensor for out variant tests.
        out = torch.randn((), dtype=dtype, device=device)

        def compare_out_variant(torch_fn, x, y):
            torch_fn(v1, v2, out=out)
            self.assertEqual(torch_fn(v1, v2), out)

        for _ in range(10):
            numel = random.randint(10, 1000)
            v1 = torch.randn(numel, dtype=dtype, device=device)
            v2 = torch.randn(numel, dtype=dtype, device=device)
            compare_with_numpy_bin_op(torch_fn, np_fn, v1, v2)
            compare_out_variant(torch_fn, v1, v2)

            # Test 0-strided
            v3 = torch.randn(1, dtype=dtype, device=device).expand(numel)
            compare_with_numpy_bin_op(torch_fn, np_fn, v1, v3)
            compare_out_variant(torch_fn, v1, v3)

            compare_with_numpy_bin_op(torch_fn, np_fn, v3, v1)
            compare_out_variant(torch_fn, v3, v1)

            # Test stride greater than 1
            v4 = torch.randn(numel, numel, dtype=dtype, device=device)[:, numel - 1]
            compare_with_numpy_bin_op(torch_fn, np_fn, v1, v4)
            compare_out_variant(torch_fn, v1, v4)

            compare_with_numpy_bin_op(torch_fn, np_fn, v4, v1)
            compare_out_variant(torch_fn, v4, v1)

    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_dot_vs_numpy(self, device, dtype):
        self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)

    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vdot_vs_numpy(self, device, dtype):
        self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)

    def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
        if complex_dtypes:
            x = torch.randn(1, dtype=torch.cfloat, device=device)
            y = torch.randn(3, dtype=torch.cdouble, device=device)
        else:
            x = torch.randn(1, dtype=torch.float, device=device)
            y = torch.randn(3, dtype=torch.double, device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    'dot : expected both vectors to have same dtype'):
            torch_fn(x, y)

        with self.assertRaisesRegex(RuntimeError,
                                    '1D tensors expected'):
            torch_fn(x.reshape(1, 1), y)

        with self.assertRaisesRegex(RuntimeError,
                                    'inconsistent tensor size'):
            torch_fn(x.expand(9), y.to(x.dtype))

        if self.device_type != 'cpu':
            x_cpu = x.expand(3).cpu()

            with self.assertRaisesRegex(RuntimeError,
                                        'expected all tensors to be on the same device'):
                torch_fn(x_cpu, y.to(x.dtype))

    @onlyOnCPUAndCUDA
    def test_vdot_invalid_args(self, device):
        self._test_dot_vdot_invalid_args(device, torch.vdot)
        self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)

    @onlyOnCPUAndCUDA
    def test_dot_invalid_args(self, device):
        self._test_dot_vdot_invalid_args(device, torch.dot)
        self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)

    @onlyCPU
    @slowTest
    @dtypes(torch.float)
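One detail worth noting from the removed helper: it built vectors with stride greater than 1 by taking the last column of a square matrix. A small illustrative sketch of what that construction yields (size and values made up):

```python
import torch

n = 4
m = torch.randn(n, n)
col = m[:, n - 1]           # last column: a 1D view with stride n
print(col.stride())         # (4,)
print(col.is_contiguous())  # False

# torch.dot accepts such strided views just like contiguous vectors.
print(torch.dot(col, torch.ones(n)))
```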
4 changes: 2 additions & 2 deletions torch/_tensor_docs.py
@@ -1223,7 +1223,7 @@ def add_docstr_all(method, docstr):

add_docstr_all('dot',
               r"""
dot(tensor2) -> Tensor
dot(other) -> Tensor

See :func:`torch.dot`
""")
@@ -4024,7 +4024,7 @@ def callable(a, b) -> number

add_docstr_all('vdot',
               r"""
dot(other) -> Tensor
vdot(other) -> Tensor

See :func:`torch.vdot`
""")
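For reference, the corrected method docstrings describe ordinary `Tensor.dot` / `Tensor.vdot` calls, e.g. (made-up values):

```python
import torch

a = torch.tensor([2., 3.])
b = torch.tensor([1., 4.])

print(a.dot(b))   # tensor(14.)
print(a.vdot(b))  # tensor(14.); identical to dot for real inputs
```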
31 changes: 22 additions & 9 deletions torch/_torch_docs.py
@@ -2610,11 +2610,21 @@ def merge_dicts(*dicts):

add_docstr(torch.dot,
           r"""
dot(input, tensor) -> Tensor
dot(input, other, *, out=None) -> Tensor

Computes the dot product (inner product) of two tensors.
Computes the dot product of two 1D tensors.

.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.

.. note::

    Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product
    of two 1D tensors with the same number of elements.

Args:
    input (Tensor): first tensor in the dot product, must be 1D.
    other (Tensor): second tensor in the dot product, must be 1D.

Keyword args:
    {out}

Example::
@@ -2626,15 +2636,18 @@ def merge_dicts(*dicts):
r"""
vdot(input, other, *, out=None) -> Tensor
Computes the dot product (inner product) of two tensors. The vdot(a, b) function
handles complex numbers differently than dot(a, b). If the first argument is complex,
the complex conjugate of the first argument is used for the calculation of the dot product.
Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
differently than dot(a, b). If the first argument is complex, the complex conjugate of the
first argument is used for the calculation of the dot product.
.. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
.. note::
Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
of two 1D tensors with the same number of elements.
Args:
input (Tensor): first tensor in the dot product. Its conjugate is used if it's complex.
other (Tensor): second tensor in the dot product.
input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
other (Tensor): second tensor in the dot product, must be 1D.
Keyword args:
{out}
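A short illustration of the semantics the rewritten docstrings spell out — `vdot` conjugates its first argument while `dot` does not, and both accept only same-length 1D tensors (values are illustrative):

```python
import torch

u = torch.tensor([1 + 2j, 3 - 1j])
v = torch.tensor([1 - 1j, 2 + 0j])

print(torch.dot(u, v))   # sum(u * v)        -> tensor(9.-1.j)
print(torch.vdot(u, v))  # sum(u.conj() * v) -> tensor(5.-1.j)

# Both functions reject anything that is not a pair of same-length 1D tensors:
# torch.dot(u.reshape(1, 2), v)  # RuntimeError: 1D tensors expected
```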
4 changes: 2 additions & 2 deletions torch/overrides.py
@@ -322,7 +322,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
    torch.dist: lambda input, other, p=2: -1,
    torch.div: lambda input, other, out=None: -1,
    torch.divide: lambda input, other, out=None: -1,
    torch.dot: lambda mat1, mat2: -1,
    torch.dot: lambda input, other, out=None: -1,
    torch.dropout: lambda input, p, train, inplace=False: -1,
    torch.dsmm: lambda input, mat2: -1,
    torch.hsmm: lambda mat1, mat2: -1,
@@ -681,7 +681,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
    torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
    torch.ravel: lambda input: -1,
    torch.real: lambda input, out=None: -1,
    torch.vdot: lambda mat1, mat2: -1,
    torch.vdot: lambda input, other, out=None: -1,
    torch.view_as_real: lambda input: -1,
    torch.view_as_complex: lambda input: -1,
    torch.reciprocal: lambda input, out=None: -1,
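These entries are the dummy lambdas `torch.overrides` exposes so tooling can validate `__torch_function__` override signatures against the real operator schemas; the fix brings `torch.dot` and `torch.vdot` in line with their actual `(input, other, *, out=None)` signatures. A small sketch of inspecting the updated entry (assumes the public `torch.overrides.get_testing_overrides` API):

```python
import inspect
import torch
from torch.overrides import get_testing_overrides

# The dummy lambda registered for torch.dot now mirrors the real schema
# (input, other, *, out=None) instead of the stale (mat1, mat2).
dummy = get_testing_overrides()[torch.dot]
print(inspect.signature(dummy))  # (input, other, out=None)
```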
