diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 5472e3f080fc4..8583dfec1c774 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1324,7 +1324,7 @@ struct TORCH_API NoneType : public Type { return rhs.kind() == kind(); } std::string str() const override { - return "None"; + return "NoneType"; } bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override { if (rhs->kind() == OptionalType::Kind) { diff --git a/test/expect/TestJit.test_cu_escaped_number.expect b/test/expect/TestJit.test_cu_escaped_number.expect index ff492c20a4470..ef7bd90052101 100644 --- a/test/expect/TestJit.test_cu_escaped_number.expect +++ b/test/expect/TestJit.test_cu_escaped_number.expect @@ -1,3 +1,3 @@ -def foo(a: Tensor) -> None: +def foo(a: Tensor) -> NoneType: print("hi\016") return None diff --git a/test/expect/TestJit.test_pretty_printer-print_weird_test.expect b/test/expect/TestJit.test_pretty_printer-print_weird_test.expect index 57cc3ea70122d..b8f7e22d8612f 100644 --- a/test/expect/TestJit.test_pretty_printer-print_weird_test.expect +++ b/test/expect/TestJit.test_pretty_printer-print_weird_test.expect @@ -1,3 +1,3 @@ -def print_weird_test(y: Tensor) -> None: +def print_weird_test(y: Tensor) -> NoneType: print("hi\016") return None diff --git a/test/jit/test_hooks.py b/test/jit/test_hooks.py index 73c4e5e76f70d..79698a55044a1 100644 --- a/test/jit/test_hooks.py +++ b/test/jit/test_hooks.py @@ -330,7 +330,7 @@ def forward_hook_wrong_input3(self, input: Tuple[None], output: str): with self.assertRaisesRegex( RuntimeError, "has the wrong inner types for the input tuple" - r" argument. Received type: 'Tuple\[None\]'", + r" argument. Received type: 'Tuple\[NoneType\]'", ): torch.jit.script(m) diff --git a/test/test_jit.py b/test/test_jit.py index 0bac81ad48a08..e8b4e977c2499 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4869,7 +4869,7 @@ def test(a): print(typed_nones()) graph_str = str(test.graph) - self.assertTrue(graph_str.count("None = prim::Constant") == 1) + self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1) def test_constant_pooling_same_identity(self): def foo(): @@ -11061,7 +11061,7 @@ def foo(x : NoneType) -> NoneType: ''') foo_code = cu.find_function('foo').code - FileCheck().check(": None").check("-> None").run(foo_code) + FileCheck().check(": NoneType").check("-> NoneType").run(foo_code) def test_empty_tuple_str(self): empty_tuple_type = torch._C.TupleType([]) @@ -11069,6 +11069,12 @@ def test_empty_tuple_str(self): python_type = eval(empty_tuple_type.annotation_str, g) assert python_type is typing.Tuple[()] + def test_none_type_str(self): + none_type = torch._C.NoneType.get() + g = {'NoneType' : type(None)} + python_type = eval(none_type.annotation_str, g) + assert python_type is type(None) + def test_zip_enumerate_modulelist(self): class Sub(torch.nn.Module): def __init__(self): @@ -12692,7 +12698,7 @@ def foo_no_decl_always_throws(): # function that has no declared type but always throws set to None output_type = next(foo_no_decl_always_throws.graph.outputs()).type() - self.assertTrue(str(output_type) == "None") + self.assertTrue(str(output_type) == "NoneType") @torch.jit.script def foo_decl_always_throws(): @@ -13416,7 +13422,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) g = cu.tanh.graph - FileCheck().check_count("prim::Closure_0", 2).check("None = prim::Constant") \ + FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \ .check_next("return").run(g) code = dedent(''' @@ -13447,7 +13453,7 @@ def backward(grad_output): ''') cu = torch.jit.CompilationUnit(code) fc = FileCheck() - fc.check("prim::Closure").check("(Tensor, None) = prim::TupleConstruct") + fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct") # Loop then two if's added in exit transform fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) fc.run(cu.loop_in_closure.graph)