Expand the test of torch.bmm on CUDA (pytorch#47124)
Summary:
Basically pytorch#47070, enabled on all CI jobs with `ci-all`.

Pull Request resolved: pytorch#47124

Reviewed By: ejguan

Differential Revision: D24735130

Pulled By: ngimel

fbshipit-source-id: c2124562a9f9d1caf24686e5d8a1106c79366233
zasdfgbnm authored and facebook-github-bot committed Nov 5, 2020
1 parent 32c76db commit 030caa1
Showing 1 changed file with 50 additions and 12 deletions.
62 changes: 50 additions & 12 deletions test/test_torch.py
@@ -17997,21 +17997,59 @@ def test_strided_mm_bmm(self, device, dtype):
        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

-    @onlyCPU
-    @dtypes(*(torch.testing.get_all_complex_dtypes() + [torch.float, torch.double]))
+    @skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
+    @onlyOnCPUAndCUDA
+    @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
+    @tf32_on_and_off(0.05)
    def test_bmm(self, device, dtype):
        num_batches = 10
        M, N, O = 23, 8, 12
-        b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device)
-        b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device)
-        res = torch.bmm(b1, b2)
-        for i in range(num_batches):
-            r = torch.mm(b1[i], b2[i])
-            self.assertEqual(r, res[i])
-        if torch.cuda.is_available():
-            # check that mixed arguments are rejected
-            self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda()))
-            self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2))
+        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
+
+        cpu_supported_dtypes = torch.testing.floating_and_complex_types()
+        cuda_supported_dtypes = torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)
+        is_supported = (self.device_type == 'cpu' and dtype in cpu_supported_dtypes) or \
+            (self.device_type == 'cuda' and dtype in cuda_supported_dtypes)
+
+        if not is_supported:
+            b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
+            b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
+            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2))
+            return
+
+        def invert_perm(p):
+            d = {x: i for i, x in enumerate(p)}
+            return (d[0], d[1], d[2])
+
+        def generate_inputs():
+            for perm1, perm2 in product(permutations((0, 1, 2)), repeat=2):
+                b1 = torch.randn(num_batches, M, N, dtype=dtype, device=device)
+                b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device)
+                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
+                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
+                yield b1, b2
+            for b1, b2, b3, b4, b5, b6 in product((True, False), repeat=6):
+                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
+                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
+                b1 = torch.randn(shape1, dtype=dtype, device=device).expand(num_batches, M, N)
+                b2 = torch.randn(shape2, dtype=dtype, device=device).expand(num_batches, N, O)
+                yield b1, b2
+
+        for (b1, b2), perm3 in product(generate_inputs(), permutations((0, 1, 2))):
+            res1 = torch.bmm(b1, b2)
+            res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
+                .permute(perm3).contiguous().permute(invert_perm(perm3))
+            torch.bmm(b1, b2, out=res2)
+            expect = torch.from_numpy(
+                b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
+            self.assertEqual(expect, res1)
+            self.assertEqual(expect, res2)
+
+        if self.device_type == 'cuda':
+            # check that mixed arguments are rejected
+            self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
+            self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
+            self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))

    @onlyCUDA
    @wrapDeterministicFlagAPITest
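A note for readers: the `generate_inputs` helper above relies on the permute -> contiguous -> permute(invert_perm) idiom to produce operands that are logically identical but laid out differently in memory. The following standalone sketch (not part of the commit; tensor sizes are arbitrary, not the test's) shows what that idiom does:

from itertools import permutations

import torch


def invert_perm(p):
    # Same helper as in the test: for each axis, find where it ended up in p.
    d = {x: i for i, x in enumerate(p)}
    return (d[0], d[1], d[2])


base = torch.randn(2, 3, 4)
print("base:", tuple(base.shape), base.stride())

for perm in permutations((0, 1, 2)):
    # Reorder the axes, materialize that order in memory, then undo the
    # reordering: shape and values are unchanged, only the strides differ.
    t = base.permute(perm).contiguous().permute(invert_perm(perm))
    assert t.shape == base.shape and torch.equal(t, base)
    print(perm, "->", t.stride())

Feeding every such layout (plus a permuted `out=` buffer) through `torch.bmm` is what gives the expanded test its coverage of transposed and otherwise non-contiguous arguments.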
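The second branch of `generate_inputs` uses `expand` instead: a size-1 dimension is broadcast to full size as a stride-0 view, so `torch.bmm` receives full-shaped inputs that share memory across batches. A minimal sketch, again with arbitrary sizes rather than the test's:

import torch

num_batches, M, N, O = 4, 3, 5, 2  # illustrative sizes only

# One (1, M, N) matrix expanded across the batch dimension: no data is
# copied, the batch stride of the resulting view is simply 0.
b1 = torch.randn(1, M, N).expand(num_batches, M, N)
b2 = torch.randn(num_batches, N, O)
print("b1 strides:", b1.stride())

res = torch.bmm(b1, b2)

# Every batch multiplies by the same left-hand matrix, so each result slice
# must match a plain matrix product with that matrix.
for i in range(num_batches):
    assert torch.allclose(res[i], b1[0] @ b2[i])

In the test itself, all 64 combinations of expanded and non-expanded dimensions for the two operands are generated and checked against NumPy's matmul.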
