Stop ignoring errors in cuda nn module tests. (pytorch#44783)
Summary: Pull Request resolved: pytorch#44783

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23731778

Pulled By: gchanan

fbshipit-source-id: 32df903a9e36bbf3f66645ee2d77efa5ed6ee429
gchanan authored and facebook-github-bot committed Sep 18, 2020
1 parent df39c40 commit 6d178f6
Showing 1 changed file: torch/testing/_internal/common_nn.py, with 71 additions and 79 deletions.
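
The deletions below are the entire point of the commit: the body of test_cuda used to run inside a try/except that silently discarded NotImplementedError, plus one specific AttributeError raised when CUDA scatter_ was missing, so a module whose CUDA path raised one of those was reported as passing. A minimal sketch of the old and new behavior, with a hypothetical run_cuda_checks() standing in for the real forward/backward comparisons:

    import unittest

    def run_cuda_checks():
        # Stand-in for the real comparisons; a module with no CUDA kernel
        # would raise NotImplementedError somewhere in here.
        raise NotImplementedError('CUDA kernel not implemented')

    class Example(unittest.TestCase):
        def test_cuda_old_style(self):
            try:
                run_cuda_checks()
            except NotImplementedError:
                pass  # the failure is swallowed and the test "passes"

        def test_cuda_new_style(self):
            run_cuda_checks()  # the failure propagates and is reported

    if __name__ == '__main__':
        unittest.main()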
@@ -4775,89 +4775,81 @@ def test_noncontig(self, test_case, module, input):
     def test_cuda(self, test_case):
         if not TEST_CUDA or not self.should_test_cuda:
             raise unittest.SkipTest('Excluded from CUDA tests')
-        try:
-            cpu_input = self._get_input()
-            type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
-            gpu_input = to_gpu(cpu_input, type_map=type_map)
-
-            cpu_module = self.constructor(*self.constructor_args)
-            gpu_module = self.constructor(*self.constructor_args).float().cuda()
-            cpu_param = test_case._get_parameters(cpu_module)
-            gpu_param = test_case._get_parameters(gpu_module)
-            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
-                gpu_p.data.copy_(cpu_p)
-
-            test_case._zero_grad_input(cpu_input)
-            test_case._zero_grad_input(gpu_input)
-            test_case._zero_grad_parameters(cpu_module)
-            test_case._zero_grad_parameters(gpu_module)
-            cpu_output = test_case._forward(cpu_module, cpu_input)
-            gpu_output = test_case._forward(gpu_module, gpu_input)
-            # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
-            test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=self.precision, rtol=0)
-
-            # Run backwards on CPU and GPU and compare results
-            for _ in range(5):
-                cpu_gradOutput = cpu_output.clone().normal_()
-                gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
-                cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)
-                gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput)
-                test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0)
-                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
-                    test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
-
-            # Run double-backwards on CPU and GPU and compare results
-            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
-                cpu_output = cpu_module(cpu_input)
-                gpu_output = gpu_module(gpu_input)
-
-                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
-                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
-                gpu_gradOutput.requires_grad = True
-
-                cpu_gradInputs = torch.autograd.grad(
-                    cpu_output,
-                    (cpu_input,) + tuple(cpu_module.parameters()),
-                    cpu_gradOutput,
-                    create_graph=True)
-                gpu_gradInputs = torch.autograd.grad(
-                    gpu_output,
-                    (gpu_input,) + tuple(gpu_module.parameters()),
-                    gpu_gradOutput,
-                    create_graph=True)
-
-                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
-                    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
-                    test_case.assertEqualIgnoreType(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0)
-
-                # We mix output into the second backwards computation so that
-                # torch.autograd.grad doesn't complain that some inputs
-                # are unreachable (which can happen if you differentiate
-                # only on the gradient.
-                cpu_gg = torch.autograd.grad(
-                    cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
-                    (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
-                    retain_graph=True)
-                gpu_gg = torch.autograd.grad(
-                    gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
-                    (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
-                    retain_graph=True)
-                # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
-                test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0)
-                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
-                    # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
-                    test_case.assertEqualIgnoreType(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
-
-            self.test_noncontig(test_case, gpu_module, gpu_input)
-        except NotImplementedError:
-            pass
-        # TODO: remove this after CUDA scatter_ is implemented
-        except AttributeError as e:
-            if len(e.args) == 1 and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]:
-                pass
-            else:
-                raise
+        cpu_input = self._get_input()
+        type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
+        gpu_input = to_gpu(cpu_input, type_map=type_map)
+
+        cpu_module = self.constructor(*self.constructor_args)
+        gpu_module = self.constructor(*self.constructor_args).float().cuda()
+        cpu_param = test_case._get_parameters(cpu_module)
+        gpu_param = test_case._get_parameters(gpu_module)
+        for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
+            gpu_p.data.copy_(cpu_p)
+
+        test_case._zero_grad_input(cpu_input)
+        test_case._zero_grad_input(gpu_input)
+        test_case._zero_grad_parameters(cpu_module)
+        test_case._zero_grad_parameters(gpu_module)
+        cpu_output = test_case._forward(cpu_module, cpu_input)
+        gpu_output = test_case._forward(gpu_module, gpu_input)
+        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+        test_case.assertEqualIgnoreType(cpu_output, gpu_output, atol=self.precision, rtol=0)
+
+        # Run backwards on CPU and GPU and compare results
+        for _ in range(5):
+            cpu_gradOutput = cpu_output.clone().normal_()
+            gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
+            cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)
+            gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput)
+            test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0)
+            for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
+                test_case.assertEqual(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
+
+        # Run double-backwards on CPU and GPU and compare results
+        if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
+            cpu_output = cpu_module(cpu_input)
+            gpu_output = gpu_module(gpu_input)
+
+            cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
+            gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
+            gpu_gradOutput.requires_grad = True
+
+            cpu_gradInputs = torch.autograd.grad(
+                cpu_output,
+                (cpu_input,) + tuple(cpu_module.parameters()),
+                cpu_gradOutput,
+                create_graph=True)
+            gpu_gradInputs = torch.autograd.grad(
+                gpu_output,
+                (gpu_input,) + tuple(gpu_module.parameters()),
+                gpu_gradOutput,
+                create_graph=True)
+
+            for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
+                # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+                test_case.assertEqualIgnoreType(cpu_d_i, gpu_d_i, atol=self.precision, rtol=0)
+
+            # We mix output into the second backwards computation so that
+            # torch.autograd.grad doesn't complain that some inputs
+            # are unreachable (which can happen if you differentiate
+            # only on the gradient.
+            cpu_gg = torch.autograd.grad(
+                cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
+                (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
+                retain_graph=True)
+            gpu_gg = torch.autograd.grad(
+                gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
+                (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
+                retain_graph=True)
+            # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+            test_case.assertEqualIgnoreType(cpu_gradInput, gpu_gradInput, atol=self.precision, rtol=0)
+            for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
+                # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+                test_case.assertEqualIgnoreType(cpu_d_p, gpu_d_p, atol=self.precision, rtol=0)
+
+        self.test_noncontig(test_case, gpu_module, gpu_input)


 class InputVariableMixin(object):
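The setup at the top of the hunk constructs the same module twice, keeps the CPU copy in double precision, moves the other copy to CUDA in float, and mirrors the CPU parameters into the GPU clone so both modules compute the same function. A standalone sketch of that pattern, assuming a CUDA device is available and using nn.Linear purely as an illustrative module:

    import torch
    import torch.nn as nn

    cpu_module = nn.Linear(4, 2).double()
    gpu_module = nn.Linear(4, 2).float().cuda()

    # Mirror the CPU parameters into the GPU module so that, up to float32
    # rounding, both modules compute identical outputs.
    with torch.no_grad():
        for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
            gpu_p.copy_(cpu_p)

    x = torch.randn(8, 4, dtype=torch.double)
    cpu_out = cpu_module(x)
    gpu_out = gpu_module(x.float().cuda())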

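Because the CPU side runs in double and the GPU side in float, the two outputs never share a dtype, so the comparisons go through assertEqualIgnoreType, which the TODO(#38095) lines mark for eventual replacement. A rough, hypothetical equivalent of such a dtype-insensitive check (not the actual helper from torch.testing._internal):

    import torch

    def assert_close_ignore_type(actual, expected, atol, rtol=0):
        # Bring both tensors to a common dtype on the CPU before comparing;
        # only the values matter, not the dtype or the device.
        common = torch.promote_types(actual.dtype, expected.dtype)
        torch.testing.assert_allclose(actual.cpu().to(common),
                                      expected.cpu().to(common),
                                      atol=atol, rtol=rtol)

    a = torch.randn(4, dtype=torch.double)
    assert_close_ignore_type(a, a.float(), atol=1e-5)
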
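The in-diff comment about mixing output into the second backwards computation points at a real torch.autograd.grad constraint: every tensor passed as an input must be reachable from the quantity being differentiated, otherwise grad() raises an error unless allow_unused=True is set. Summing the forward output into the double-backward target keeps the module input, the grad output, and all parameters reachable. A self-contained illustration of the trick, independent of the harness:

    import torch

    x = torch.randn(3, requires_grad=True)
    y = (x * x).sum()

    # First-order gradient, kept differentiable via create_graph=True.
    (gx,) = torch.autograd.grad(y, (x,), create_graph=True)

    # Mixing y back into the target guarantees x stays reachable in the
    # graph being differentiated, mirroring what the test above does.
    (gg,) = torch.autograd.grad(y + gx.sum(), (x,), retain_graph=True)
    print(gg)  # 2*x + 2, i.e. the derivative of sum(x**2 + 2*x)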