Skip to content

Commit

Permalink
[JIT][write path] Make NoneType annotation_str emit `NoneType` instead of `None` (pytorch#54746)
Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: pytorch#54746

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D27350331

Pulled By: jamesr66a

fbshipit-source-id: 3f44d6589c29f39378432d0b6b281d96bb4829e7
  • Loading branch information
James Reed authored and facebook-github-bot committed Apr 13, 2021
1 parent a3c06e6 commit 68e0796
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ struct TORCH_API NoneType : public Type {
return rhs.kind() == kind();
}
std::string str() const override {
return "None";
return "NoneType";
}
bool isSubtypeOfExt(const TypePtr& rhs, std::ostream *why_not) const override {
if (rhs->kind() == OptionalType::Kind) {
Expand Down
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_cu_escaped_number.expect
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
def foo(a: Tensor) -> None:
def foo(a: Tensor) -> NoneType:
print("hi\016")
return None
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
def print_weird_test(y: Tensor) -> None:
def print_weird_test(y: Tensor) -> NoneType:
print("hi\016")
return None
2 changes: 1 addition & 1 deletion test/jit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
with self.assertRaisesRegex(
RuntimeError,
"has the wrong inner types for the input tuple"
r" argument. Received type: 'Tuple\[None\]'",
r" argument. Received type: 'Tuple\[NoneType\]'",
):
torch.jit.script(m)

Expand Down
16 changes: 11 additions & 5 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4869,7 +4869,7 @@ def test(a):
print(typed_nones())

graph_str = str(test.graph)
self.assertTrue(graph_str.count("None = prim::Constant") == 1)
self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1)

def test_constant_pooling_same_identity(self):
def foo():
Expand Down Expand Up @@ -11061,14 +11061,20 @@ def foo(x : NoneType) -> NoneType:
''')

foo_code = cu.find_function('foo').code
FileCheck().check(": None").check("-> None").run(foo_code)
FileCheck().check(": NoneType").check("-> NoneType").run(foo_code)

def test_empty_tuple_str(self):
empty_tuple_type = torch._C.TupleType([])
g = {'Tuple' : typing.Tuple}
python_type = eval(empty_tuple_type.annotation_str, g)
assert python_type is typing.Tuple[()]

def test_none_type_str(self):
none_type = torch._C.NoneType.get()
g = {'NoneType' : type(None)}
python_type = eval(none_type.annotation_str, g)
assert python_type is type(None)

def test_zip_enumerate_modulelist(self):
class Sub(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -12692,7 +12698,7 @@ def foo_no_decl_always_throws():

# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "None")
self.assertTrue(str(output_type) == "NoneType")

@torch.jit.script
def foo_decl_always_throws():
Expand Down Expand Up @@ -13416,7 +13422,7 @@ def backward(grad_output):
''')
cu = torch.jit.CompilationUnit(code)
g = cu.tanh.graph
FileCheck().check_count("prim::Closure_0", 2).check("None = prim::Constant") \
FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \
.check_next("return").run(g)

code = dedent('''
Expand Down Expand Up @@ -13447,7 +13453,7 @@ def backward(grad_output):
''')
cu = torch.jit.CompilationUnit(code)
fc = FileCheck()
fc.check("prim::Closure").check("(Tensor, None) = prim::TupleConstruct")
fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct")
# Loop then two if's added in exit transform
fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2)
fc.run(cu.loop_in_closure.graph)
Expand Down

0 comments on commit 68e0796

Please sign in to comment.