From 58d1cf7e39ec8e2eec5150e2d227e9a3da9f9bb6 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Wed, 3 Aug 2022 22:45:39 +0000 Subject: [PATCH] Fix issue 38095 TODOs in test_jit (#82629) Fix TODOs related to https://github.com/pytorch/pytorch/issues/38095 in test_jit.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/82629 Approved by: https://github.com/clee2000, https://github.com/malfet --- test/test_jit.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index b07b83cc40c28..4c8a2d7176a3b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5754,10 +5754,9 @@ def test_integral_shape_inference(a): return a * a ''') inputs = [torch.ones(10, 10, dtype=torch.long)] - outputs = torch.ones(10, 10) + outputs = torch.ones(10, 10, dtype=torch.long) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(cu.test_integral_shape_inference(*inputs), outputs) + self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs) @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") @@ -7327,9 +7326,7 @@ def func(): if inp == 'empty_list': # torchscript returns int tensor, python returns float tensor self.assertNotEqual(t1.dtype, t2.dtype) - - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(t1, t2) + self.assertEqual(t1, t2, exact_dtype=False) self.assertEqual(t1.device, t2.device) @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate") @@ -15354,8 +15351,7 @@ def forward(self, key): # TODO: re-enable module hook when Python printing of attributes is # supported m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) - # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095 - self.assertEqualIgnoreType(m("c"), torch.tensor([103])) + self.assertEqual(m("c"), torch.tensor([103.])) def test_module_none_attrs(self): class MyMod(torch.jit.ScriptModule):