From fba107f18ea7fffa9eefa741c08caf78b04b795c Mon Sep 17 00:00:00 2001
From: Zachary DeVito
Date: Tue, 27 Aug 2019 22:52:48 -0700
Subject: [PATCH] add serialization of interface (#25227)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25227

Adds cases to NamedType serialization so that interfaces are written.
Similar implementation to NamedTuples.

Test Plan: Imported from OSS

Differential Revision: D17066674

Pulled By: zdevito

fbshipit-source-id: fda5419260fad29e8c4ddb92de1d3447d621d982
---
 aten/src/ATen/core/jit_type.h            |  3 ++
 test/test_jit.py                         | 67 ++++++++++++------------
 torch/csrc/jit/import_source.cpp         | 49 +++++++++--------
 torch/csrc/jit/passes/python_print.cpp   | 66 +++++++++++++++++++----
 torch/csrc/jit/script/compilation_unit.h |  2 +-
 torch/csrc/jit/script/compiler.cpp       |  2 +-
 torch/csrc/jit/script/init.cpp           |  2 +-
 7 files changed, 122 insertions(+), 69 deletions(-)

diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index e8edcbf241..3f190522bf 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1491,6 +1491,9 @@ struct CAFFE2_API InterfaceType : public NamedType {
   // returns nullptr if not found.
   const FunctionSchema* getMethod(const std::string& name) const;
   void addMethod(FunctionSchema schema);
+  const std::vector<FunctionSchema>& methods() {
+    return *methods_;
+  }
   static const TypeKind Kind = TypeKind::InterfaceType;
   ~InterfaceType() override;
  private:
diff --git a/test/test_jit.py b/test/test_jit.py
index 803ce4b915..3e56e71626 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -18421,48 +18421,47 @@ def forward(self, a):
         self.assertEqual(3 * input, output)
 
     def test_interface(self):
-        with torch.jit._disable_emit_hooks():
-            @torch.jit.script
-            class Foo(object):
-                def __init__(self):
-                    pass
+        @torch.jit.script
+        class Foo(object):
+            def __init__(self):
+                pass
 
-                def one(self, x, y):
-                    return x + y
+            def one(self, x, y):
+                return x + y
 
-                def two(self, x):
-                    return 2 * x
+            def two(self, x):
+                return 2 * x
 
-            @torch.jit.script
-            class Bar(object):
-                def __init__(self):
-                    pass
+        @torch.jit.script
+        class Bar(object):
+            def __init__(self):
+                pass
 
-                def one(self, x, y):
-                    return x * y
+            def one(self, x, y):
+                return x * y
 
-                def two(self, x):
-                    return 2 / x
+            def two(self, x):
+                return 2 / x
 
-            @torch.jit.interface
-            class OneTwo(object):
-                def one(self, x, y):
-                    # type: (Tensor, Tensor) -> Tensor
-                    pass
+        @torch.jit.interface
+        class OneTwo(object):
+            def one(self, x, y):
+                # type: (Tensor, Tensor) -> Tensor
+                pass
 
-                def two(self, x):
-                    # type: (Tensor) -> Tensor
-                    pass
+            def two(self, x):
+                # type: (Tensor) -> Tensor
+                pass
 
-            def use_them(x):
-                a = Foo()
-                b = Bar()
-                c = torch.jit.annotate(List[OneTwo], [a, b])
-                for i in range(len(c)):
-                    x = c[i].one(x, x)
-                    x = c[i].two(x)
-                return x
-            self.checkScript(use_them, (torch.rand(3, 4),))
+        def use_them(x):
+            a = Foo()
+            b = Bar()
+            c = torch.jit.annotate(List[OneTwo], [a, b])
+            for i in range(len(c)):
+                x = c[i].one(x, x)
+                x = c[i].two(x)
+            return x
+        self.checkScript(use_them, (torch.rand(3, 4),))
 
 
     def test_overloaded_fn(self):
diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp
index 37a93071b1..f872b0be0f 100644
--- a/torch/csrc/jit/import_source.cpp
+++ b/torch/csrc/jit/import_source.cpp
@@ -193,7 +193,7 @@ struct SourceImporter {
     switch (kind) {
       case TK_CLASS_DEF: {
         auto parsed_treeref = p_.parseClass();
-        importClass(qualifier, ClassDef(parsed_treeref));
+        importNamedType(qualifier, ClassDef(parsed_treeref));
       } break;
       case TK_DEF: {
        auto parsed_treeref = p_.parseFunction(/*is_method=*/false);
@@ -229,24 +229,34 @@ struct SourceImporter {
     cu_->define(qualifier, definitions, resolvers, nullptr);
   }
 
-  void importClass(const std::string& qualifier, const ClassDef& class_def) {
-    bool is_module = false;
-    if (class_def.superclass().present()) {
-      const auto& superclass_name =
-          Var(class_def.superclass().get()).name().name();
-      if (superclass_name == "Module") {
-        is_module = true;
-      } else if (superclass_name == "NamedTuple") {
-        // NamedTuples have special rules (since they are TupleTypes and not ClassTypes)
-        return importNamedTuple(qualifier, class_def);
-      } else {
-        throw ErrorReport(class_def.range())
-            << "Torchscript does not support class inheritance.";
-      }
+  void importNamedType(
+      const std::string& qualifier,
+      const ClassDef& class_def) {
+    const auto qualified_name =
+        QualifiedName(QualifiedName(qualifier), class_def.name().name());
+    if (!class_def.superclass().present()) {
+      return importClass(qualified_name, class_def, /*is_module=*/false);
+    }
+    const auto& superclass_name =
+        Var(class_def.superclass().get()).name().name();
+    if (superclass_name == "Module") {
+      importClass(qualified_name, class_def, /*is_module=*/true);
+    } else if (superclass_name == "NamedTuple") {
+      // NamedTuples have special rules (since they are TupleTypes and not
+      // ClassTypes)
+      return importNamedTuple(qualified_name, class_def);
+    } else if (superclass_name == "Interface") {
+      cu_->define_interface(qualified_name, class_def, resolver_);
+    } else {
+      throw ErrorReport(class_def.range())
+          << "Torchscript does not support class inheritance.";
     }
+  }
 
-    const auto qualified_classname =
-        QualifiedName(QualifiedName(qualifier), class_def.name().name());
+  void importClass(
+      const QualifiedName& qualified_classname,
+      const ClassDef& class_def,
+      bool is_module) {
     auto class_type = ClassType::create(
         c10::QualifiedName(qualified_classname), cu_, is_module);
@@ -352,11 +362,8 @@ struct SourceImporter {
   }
 
   void importNamedTuple(
-      const std::string& qualifier,
+      const QualifiedName& qualified_name,
       const ClassDef& named_tuple_def) {
-    auto qualified_name =
-        c10::QualifiedName(qualifier + "." + named_tuple_def.name().name());
-
     ScriptTypeParser type_parser(resolver_);
     std::vector<std::string> field_names;
     std::vector<TypePtr> field_types;
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 68dab3bf06..27878f78bd 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -757,6 +757,8 @@ struct PythonPrintPass {
       if (tupleType->name()) {
        registerDependency(tupleType);
      }
+    } else if (const auto interfaceType = type->cast<InterfaceType>()) {
+      registerDependency(interfaceType);
     }
     for (const auto& containedType : type->containedTypes()) {
       registerClassDependencies(containedType);
     }
@@ -947,6 +949,20 @@ struct PythonPrintPass {
     }
   }
 
+  static bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
+    if (elem_type->kind() == OptionalType::Kind) {
+      // it is possible that we are constructing an optional list, but all
+      // elements are present
+      return false;
+    }
+    if (elem_type->kind() == InterfaceType::Kind) {
+      // since classes can be members of multiple interfaces, we cannot
+      // construct which interface the list holds from the members alone
+      return false;
+    }
+    return true;
+  }
+
   // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
   void printRHS(TaggedStringStream& stmt, Node* node) {
     switch (node->kind()) {
@@ -1037,9 +1053,7 @@ struct PythonPrintPass {
         if (node->inputs().size() == 0) {
           stmt << "annotate(" << node->output()->type()->python_str()
                << ", [])";
-        } else if (elem_type->cast<OptionalType>()) {
-          // if the element type is a optional type, we annotate the list so
-          // that we could correctly infer the type on import
+        } else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
           stmt << "annotate(" << node->output()->type()->python_str() << ",";
           printValueList(stmt, node->inputs(), "[", "]");
           stmt << ")";
@@ -1089,21 +1103,27 @@ struct PythonPrintPass {
       } break;
       case prim::CallMethod: {
         const auto& self = node->inputs().at(0);
-        const auto& selfType = self->type()->expect<ClassType>();
         const auto& methodName = node->s(attr::name);
-        const auto method = selfType->getMethod(node->s(attr::name));
-        registerDependency(selfType);
-
-        TORCH_INTERNAL_ASSERT(
-            method->qualname() ==
-            QualifiedName(selfType->name()->qualifiedName(), methodName));
-
         stmt << "(" << useOf(self) << ")"
              << "." << methodName << "(";
         for (size_t i = 1; i < node->inputs().size(); i++) {
           stmt << useOf(node->inputs()[i]) << ", ";
         }
         stmt << ")";
+
+        if (auto selfClass = self->type()->cast<ClassType>()) {
+          registerDependency(selfClass);
+          const auto method = selfClass->getMethod(node->s(attr::name));
+          TORCH_INTERNAL_ASSERT(
+              method->qualname() ==
+              QualifiedName(selfClass->name()->qualifiedName(), methodName));
+        } else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
+          registerDependency(selfInterface);
+        } else {
+          TORCH_INTERNAL_ASSERT(
+              false, "method call to unhandled type in serialization");
+        }
+
       } break;
       default: {
         Symbol kind = node->kind();
@@ -1351,6 +1371,30 @@ struct PythonPrintPass {
          body_ << attr.name() << " : " << attr.type()->python_str() << "\n";
        }
      }
+    } else if (auto interfaceType = type->cast<InterfaceType>()) {
+      body_ << "class " << interfaceType->name()->name();
+      body_ << "(Interface):\n";
+      {
+        auto guard = WithIndented();
+        for (const FunctionSchema& method : interfaceType->methods()) {
+          indent();
+          body_ << "def " << method.name() << "(self";
+          TORCH_INTERNAL_ASSERT(
+              method.arguments().size() > 0 &&
+              method.arguments().at(0).name() == "self");
+          for (const Argument& arg :
+               at::ArrayRef<Argument>(method.arguments()).slice(1)) {
+            auto type = arg.type();
+            registerClassDependencies(type);
+            body_ << ", " << arg.name() << ": " << type->python_str();
+          }
+          auto return_type = method.returns().at(0).type();
+          registerClassDependencies(return_type);
+          body_ << ") -> " << return_type->python_str() << ":\n";
+          indent();
+          body_ << "  pass\n";
+        }
+      }
     } else {
       TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
     }
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
index cb4e8f1309..06c3727e25 100644
--- a/torch/csrc/jit/script/compilation_unit.h
+++ b/torch/csrc/jit/script/compilation_unit.h
@@ -105,7 +105,7 @@ struct TORCH_API CompilationUnit {
       const Self* self);
 
   void define_interface(
-      const std::string& qualifiedName,
+      const c10::QualifiedName& qualifiedName,
       const ClassDef& classDef,
       ResolverPtr rcb);
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index 3f00e35d97..f46a63b4a2 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -3294,7 +3294,7 @@ void lambdaLiftFork(Node* fork_node) {
 }
 
 void CompilationUnit::define_interface(
-    const std::string& qualifiedName,
+    const c10::QualifiedName& qualifiedName,
     const ClassDef& classDef,
     ResolverPtr rcb) {
   ScriptTypeParser typeParser(rcb);
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 661faf2b9d..c8d9030e23 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -805,7 +805,7 @@ void initJitScriptBindings(PyObject* module) {
          const ClassDef& classDef,
          ResolutionCallback rcb) {
         get_python_cu()->define_interface(
-            qualifiedName, classDef, pythonResolver(rcb));
+            c10::QualifiedName(qualifiedName), classDef, pythonResolver(rcb));
       });
 
   m.def("parse_type_comment", [](const std::string& comment) {
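
Usage sketch (editorial illustration, not part of the patch): the change means a
TorchScript interface used inside a scripted module can now survive a save/load
round trip -- the printer emits a `class OneTwo(Interface):` stub and the
importer routes it back through `CompilationUnit::define_interface`. The names
below (Foo, OneTwo, UseInterface) are invented for the example, and it assumes
the recursive `torch.jit.script`/`torch.jit.save`/`torch.jit.load` APIs of that
release exercise the same serialization path the updated test now covers once
`_disable_emit_hooks` is removed.

    import io
    from typing import List

    import torch

    @torch.jit.script
    class Foo(object):
        def __init__(self):
            pass

        def one(self, x, y):
            # type: (Tensor, Tensor) -> Tensor
            return x + y

        def two(self, x):
            # type: (Tensor) -> Tensor
            return 2 * x

    @torch.jit.interface
    class OneTwo(object):
        def one(self, x, y):
            # type: (Tensor, Tensor) -> Tensor
            pass

        def two(self, x):
            # type: (Tensor) -> Tensor
            pass

    class UseInterface(torch.nn.Module):
        def forward(self, x):
            # The list is annotated with the interface type: element types
            # alone cannot say which interface a list holds (see
            # elementTypeCanBeInferredFromMembers in python_print.cpp).
            a = Foo()
            c = torch.jit.annotate(List[OneTwo], [a])
            for i in range(len(c)):
                x = c[i].one(x, x)
                x = c[i].two(x)
            return x

    scripted = torch.jit.script(UseInterface())
    buffer = io.BytesIO()
    torch.jit.save(scripted, buffer)   # printer writes the Interface stub
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)  # importer hits define_interface
    print(reloaded(torch.rand(3, 4)))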