Skip to content

Commit

Permalink
add serialization of interface (#25227)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#25227

Adds cases to NamedType serialization to so that interfaces are written.
Similar implementation to NamedTuples

Test Plan: Imported from OSS

Differential Revision: D17066674

Pulled By: zdevito

fbshipit-source-id: fda5419260fad29e8c4ddb92de1d3447d621d982
  • Loading branch information
zdevito authored and facebook-github-bot committed Aug 28, 2019
1 parent a01358f commit fba107f
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 69 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,9 @@ struct CAFFE2_API InterfaceType : public NamedType {
// returns nullptr if not found.
const FunctionSchema* getMethod(const std::string& name) const;
void addMethod(FunctionSchema schema);
const std::vector<FunctionSchema>& methods() {
return *methods_;
}
static const TypeKind Kind = TypeKind::InterfaceType;
~InterfaceType() override;
private:
Expand Down
67 changes: 33 additions & 34 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18421,48 +18421,47 @@ def forward(self, a):
self.assertEqual(3 * input, output)

def test_interface(self):
with torch.jit._disable_emit_hooks():
@torch.jit.script
class Foo(object):
def __init__(self):
pass
@torch.jit.script
class Foo(object):
def __init__(self):
pass

def one(self, x, y):
return x + y
def one(self, x, y):
return x + y

def two(self, x):
return 2 * x
def two(self, x):
return 2 * x

@torch.jit.script
class Bar(object):
def __init__(self):
pass
@torch.jit.script
class Bar(object):
def __init__(self):
pass

def one(self, x, y):
return x * y
def one(self, x, y):
return x * y

def two(self, x):
return 2 / x
def two(self, x):
return 2 / x

@torch.jit.interface
class OneTwo(object):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass
@torch.jit.interface
class OneTwo(object):
def one(self, x, y):
# type: (Tensor, Tensor) -> Tensor
pass

def two(self, x):
# type: (Tensor) -> Tensor
pass
def two(self, x):
# type: (Tensor) -> Tensor
pass

def use_them(x):
a = Foo()
b = Bar()
c = torch.jit.annotate(List[OneTwo], [a, b])
for i in range(len(c)):
x = c[i].one(x, x)
x = c[i].two(x)
return x
self.checkScript(use_them, (torch.rand(3, 4),))
def use_them(x):
a = Foo()
b = Bar()
c = torch.jit.annotate(List[OneTwo], [a, b])
for i in range(len(c)):
x = c[i].one(x, x)
x = c[i].two(x)
return x
self.checkScript(use_them, (torch.rand(3, 4),))


def test_overloaded_fn(self):
Expand Down
49 changes: 28 additions & 21 deletions torch/csrc/jit/import_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ struct SourceImporter {
switch (kind) {
case TK_CLASS_DEF: {
auto parsed_treeref = p_.parseClass();
importClass(qualifier, ClassDef(parsed_treeref));
importNamedType(qualifier, ClassDef(parsed_treeref));
} break;
case TK_DEF: {
auto parsed_treeref = p_.parseFunction(/*is_method=*/false);
Expand Down Expand Up @@ -229,24 +229,34 @@ struct SourceImporter {
cu_->define(qualifier, definitions, resolvers, nullptr);
}

void importClass(const std::string& qualifier, const ClassDef& class_def) {
bool is_module = false;
if (class_def.superclass().present()) {
const auto& superclass_name =
Var(class_def.superclass().get()).name().name();
if (superclass_name == "Module") {
is_module = true;
} else if (superclass_name == "NamedTuple") {
// NamedTuples have special rules (since they are TupleTypes and not ClassTypes)
return importNamedTuple(qualifier, class_def);
} else {
throw ErrorReport(class_def.range())
<< "Torchscript does not support class inheritance.";
}
void importNamedType(
const std::string& qualifier,
const ClassDef& class_def) {
const auto qualified_name =
QualifiedName(QualifiedName(qualifier), class_def.name().name());
if (!class_def.superclass().present()) {
return importClass(qualified_name, class_def, /*is_module=*/false);
}
const auto& superclass_name =
Var(class_def.superclass().get()).name().name();
if (superclass_name == "Module") {
importClass(qualified_name, class_def, /*is_module=*/true);
} else if (superclass_name == "NamedTuple") {
// NamedTuples have special rules (since they are TupleTypes and not
// ClassTypes)
return importNamedTuple(qualified_name, class_def);
} else if (superclass_name == "Interface") {
cu_->define_interface(qualified_name, class_def, resolver_);
} else {
throw ErrorReport(class_def.range())
<< "Torchscript does not support class inheritance.";
}
}

const auto qualified_classname =
QualifiedName(QualifiedName(qualifier), class_def.name().name());
void importClass(
const QualifiedName& qualified_classname,
const ClassDef& class_def,
bool is_module) {
auto class_type = ClassType::create(
c10::QualifiedName(qualified_classname), cu_, is_module);

Expand Down Expand Up @@ -352,11 +362,8 @@ struct SourceImporter {
}

void importNamedTuple(
const std::string& qualifier,
const QualifiedName& qualified_name,
const ClassDef& named_tuple_def) {
auto qualified_name =
c10::QualifiedName(qualifier + "." + named_tuple_def.name().name());

ScriptTypeParser type_parser(resolver_);
std::vector<std::string> field_names;
std::vector<TypePtr> field_types;
Expand Down
66 changes: 55 additions & 11 deletions torch/csrc/jit/passes/python_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,8 @@ struct PythonPrintPass {
if (tupleType->name()) {
registerDependency(tupleType);
}
} else if (const auto interfaceType = type->cast<InterfaceType>()) {
registerDependency(interfaceType);
}
for (const auto& containedType : type->containedTypes()) {
registerClassDependencies(containedType);
Expand Down Expand Up @@ -947,6 +949,20 @@ struct PythonPrintPass {
}
}

static bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
if (elem_type->kind() == OptionalType::Kind) {
// it is possible that we are constructing an optional list, but all
// elements are present
return false;
}
if (elem_type->kind() == InterfaceType::Kind) {
// since classes can be members of multiple interfaces, we cannot
// construct which interface the list holds from the members alone
return false;
}
return true;
}

// Prints the RHS value of a Node, e.g. `aten.add(x, y)`
void printRHS(TaggedStringStream& stmt, Node* node) {
switch (node->kind()) {
Expand Down Expand Up @@ -1037,9 +1053,7 @@ struct PythonPrintPass {
if (node->inputs().size() == 0) {
stmt << "annotate(" << node->output()->type()->python_str()
<< ", [])";
} else if (elem_type->cast<OptionalType>()) {
// if the element type is a optional type, we annotate the list so
// that we could correctly infer the type on import
} else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
stmt << "annotate(" << node->output()->type()->python_str() << ",";
printValueList(stmt, node->inputs(), "[", "]");
stmt << ")";
Expand Down Expand Up @@ -1089,21 +1103,27 @@ struct PythonPrintPass {
} break;
case prim::CallMethod: {
const auto& self = node->inputs().at(0);
const auto& selfType = self->type()->expect<ClassType>();
const auto& methodName = node->s(attr::name);
const auto method = selfType->getMethod(node->s(attr::name));
registerDependency(selfType);

TORCH_INTERNAL_ASSERT(
method->qualname() ==
QualifiedName(selfType->name()->qualifiedName(), methodName));

stmt << "(" << useOf(self) << ")"
<< "." << methodName << "(";
for (size_t i = 1; i < node->inputs().size(); i++) {
stmt << useOf(node->inputs()[i]) << ", ";
}
stmt << ")";

if (auto selfClass = self->type()->cast<ClassType>()) {
registerDependency(selfClass);
const auto method = selfClass->getMethod(node->s(attr::name));
TORCH_INTERNAL_ASSERT(
method->qualname() ==
QualifiedName(selfClass->name()->qualifiedName(), methodName));
} else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
registerDependency(selfInterface);
} else {
TORCH_INTERNAL_ASSERT(
false, "method call to unhandled type in serialization");
}

} break;
default: {
Symbol kind = node->kind();
Expand Down Expand Up @@ -1351,6 +1371,30 @@ struct PythonPrintPass {
body_ << attr.name() << " : " << attr.type()->python_str() << "\n";
}
}
} else if (auto interfaceType = type->cast<InterfaceType>()) {
body_ << "class " << interfaceType->name()->name();
body_ << "(Interface):\n";
{
auto guard = WithIndented();
for (const FunctionSchema& method : interfaceType->methods()) {
indent();
body_ << "def " << method.name() << "(self";
TORCH_INTERNAL_ASSERT(
method.arguments().size() > 0 &&
method.arguments().at(0).name() == "self");
for (const Argument& arg :
at::ArrayRef<Argument>(method.arguments()).slice(1)) {
auto type = arg.type();
registerClassDependencies(type);
body_ << ", " << arg.name() << ": " << type->python_str();
}
auto return_type = method.returns().at(0).type();
registerClassDependencies(return_type);
body_ << ") -> " << return_type->python_str() << ":\n";
indent();
body_ << " pass\n";
}
}
} else {
TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/compilation_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct TORCH_API CompilationUnit {
const Self* self);

void define_interface(
const std::string& qualifiedName,
const c10::QualifiedName& qualifiedName,
const ClassDef& classDef,
ResolverPtr rcb);

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3294,7 +3294,7 @@ void lambdaLiftFork(Node* fork_node) {
}

void CompilationUnit::define_interface(
const std::string& qualifiedName,
const c10::QualifiedName& qualifiedName,
const ClassDef& classDef,
ResolverPtr rcb) {
ScriptTypeParser typeParser(rcb);
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ void initJitScriptBindings(PyObject* module) {
const ClassDef& classDef,
ResolutionCallback rcb) {
get_python_cu()->define_interface(
qualifiedName, classDef, pythonResolver(rcb));
c10::QualifiedName(qualifiedName), classDef, pythonResolver(rcb));
});

m.def("parse_type_comment", [](const std::string& comment) {
Expand Down

0 comments on commit fba107f

Please sign in to comment.