add support for overloading functions (pytorch#23886)
Summary:
Pull Request resolved: pytorch#23886

This is part of a series of PRs that will allow us to support adding [padding to conv](pytorch#22484) and reduce the friction of adding method overloads, an issue raised in pytorch#23266.

This PR adds support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading).

The usage is:
```
@torch.jit.overload
def add(x: int, y: int) -> int: ...
@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```
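
Once the overloads and the implementation are declared, calls to the overloaded function from scripted code dispatch on the argument types at the call site. A minimal sketch of that usage, assuming the `add` overloads above are defined in the same module (`use_add` is an illustrative name, not part of this PR):
```
@torch.jit.script
def use_add(a: int, b: float):
    # add(a, a) resolves to the (int, int) -> int overload;
    # add(b, b) resolves to the (float, float) -> float overload
    return add(a, a), add(b, b)
```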

Follow-up PRs:

- Add the same API for methods
- A couple of cleanups for functions:
     - don't require default parameters to be specified on the overload as well
     - potentially error if an invocation could match multiple overloads; currently the first match is chosen, which is also what mypy does (see the sketch after this list)
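
A hypothetical illustration of that ambiguity (the names here are invented for the example, following the decorator API shown above): both overloads can only be satisfied by implicitly converting one `int` argument to `float`, and today the first declared overload is selected rather than raising an error.
```
@torch.jit.overload
def combine(x: int, y: float) -> float: ...
@torch.jit.overload
def combine(x: float, y: int) -> float: ...

def combine(x, y):
    return x + y

# combine(1, 2) could match either overload via an implicit int -> float
# conversion of one argument; the first declared overload is chosen.
```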

Test Plan: Imported from OSS

Differential Revision: D16694863

Pulled By: eellison

fbshipit-source-id: f94f2933bc1c97fa58f31846acfe962b0630068c
Elias Ellison authored and facebook-github-bot committed Aug 8, 2019
1 parent 9ecc33d commit 451fc51
Showing 7 changed files with 310 additions and 34 deletions.
109 changes: 109 additions & 0 deletions test/test_jit.py
@@ -12989,6 +12989,115 @@ def test_non_primitive_types(x):
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))

def test_function_overloads(self):
# TODO: pyflakes currently does not compose the @overload annotation with other
# decorators. This is fixed on master but not in version 2.1.1.
# On the next version update, remove the noqa and add the @typing.overload annotation.

@torch.jit._overload # noqa: F811
def test_simple(x1): # noqa: F811
# type: (int) -> int
pass

@torch.jit._overload # noqa: F811
def test_simple(x1): # noqa: F811
# type: (float) -> float
pass

def test_simple(x1): # noqa: F811
return x1 + 5

def invoke_function():
return test_simple(1.0), test_simple(.5)

self.checkScript(invoke_function, ())

# testing that the functions are cached
compiled_fns_1 = torch.jit._get_overloads(test_simple)
compiled_fns_2 = torch.jit._get_overloads(test_simple)
for a, b in zip(compiled_fns_1, compiled_fns_2):
self.assertIs(a, b)

# currently the default values have to be specified in the
# overload as well - TODO: take them from the implementation and apply
# them where the type is valid.
@torch.jit._overload # noqa: F811
def identity(x1): # noqa: F811
# type: (str) -> str
pass

@torch.jit._overload # noqa: F811
def identity(x1=1.0): # noqa: F811
# type: (float) -> float
pass

def identity(x1=1.0): # noqa: F811
return x1

def invoke():
return identity(), identity(.5), identity("hi")

self.checkScript(invoke, ())

def schema_match_failure():
return identity((1, 2))

thrown = False
try:
torch.jit.script(schema_match_failure)
except Exception as e:
thrown = True
self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e))
self.assertTrue(thrown)

with self.assertRaisesRegex(Exception, "cannot be directly compiled"):
torch.jit.script(identity)

@torch.jit._overload # noqa: F811
def impl_compile_failure(x, y): # noqa: F811
# type: (str, str) -> (str)
pass

@torch.jit._overload # noqa: F811
def impl_compile_failure(x, y): # noqa: F811
# type: (int, int) -> (int)
pass

def impl_compile_failure(x, y): # noqa: F811
return x - y

def test():
impl_compile_failure("one", "two")


with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
torch.jit.script(test)

def test_function_overloading_isinstance(self):
@torch.jit._overload # noqa: F811
def my_conv(x, y): # noqa: F811
# type: (float, str) -> (float)
pass

@torch.jit._overload # noqa: F811
def my_conv(x, y=2.0): # noqa: F811
# type: (float, float) -> (float)
pass

def my_conv(x, y=2.0): # noqa: F811
if isinstance(y, str):
if y == "hi":
return 4.0 - x
else:
return 5.0 - x
else:
return 2.0 + x

def test_uses():
return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0)

self.checkScript(test_uses, ())

@unittest.skipIf(True, "Removing weak script")
def test_overloading(self):
@torch._jit_internal.weak_module
73 changes: 56 additions & 17 deletions torch/csrc/jit/script/init.cpp
@@ -173,6 +173,25 @@ std::shared_ptr<PythonResolver> pythonResolver(
return std::make_shared<PythonResolver>(
rcb, std::move(classname), std::move(classType));
}

void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
const auto& new_params = new_decl.params();
const auto& old_params = old_decl.params();

// TODO: the same number of parameters is not strictly necessary.
TORCH_INTERNAL_ASSERT(
new_params.size() == old_params.size(),
"Overload must have same number of parameters\n",
new_decl.range(),
old_decl.range());
for (size_t i = 0; i < new_decl.params().size(); ++i) {
TORCH_INTERNAL_ASSERT(
new_params[i].ident().name() == old_params[i].ident().name(),
"Overload parameters must have the same names\n",
new_params[i].ident(),
old_params[i].ident());
}
}
} // namespace

FunctionSchema getSchemaWithNameAndDefaults(
Expand Down Expand Up @@ -215,6 +234,27 @@ FunctionSchema getSchemaWithNameAndDefaults(
schema.is_varret());
}

static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
const FunctionDefaults& defaults,
ResolutionCallback rcb) {
auto cu = get_python_cu();
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
{def},
{pythonResolver(std::move(rcb))},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}

struct VISIBILITY_HIDDEN ModuleSelf : public Self {
ModuleSelf(const Module& m, py::object& py_m)
: Self(), module_(m), pyModule_(py_m) {}
@@ -423,8 +463,8 @@ void initJitScriptBindings(PyObject* module) {
.def(
"_register_attribute",
[](Module& self, std::string name, TypePtr type, py::object value) {
auto unshaped = unshapedType(type);
self.register_attribute(name, unshaped, toIValue(value, type));
self.register_attribute(
name, type, toIValue(std::move(value), type));
})
.def("_register_module", &Module::register_module)
.def("_register_buffer", &Module::register_buffer)
@@ -684,24 +724,23 @@ void initJitScriptBindings(PyObject* module) {
[](const std::string& qualname,
const Def& def,
ResolutionCallback rcb,
FunctionDefaults defaults) {
const FunctionDefaults& defaults) {
C10_LOG_API_USAGE_ONCE("torch.script.compile");
const auto name = c10::QualifiedName(qualname);
TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
auto cu = get_python_cu();
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
{def},
{pythonResolver(std::move(rcb))},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
return script_compile_function(name, def, defaults, std::move(rcb));
});
m.def(
"_jit_script_compile_overload",
[](const std::string& qualname,
const Decl& overload_decl,
const Def& implementation_def,
ResolutionCallback rcb,
const FunctionDefaults& defaults) {
const auto name = c10::QualifiedName(qualname);
checkOverloadDecl(overload_decl, implementation_def.decl());
auto new_def = implementation_def.withDecl(overload_decl);
return script_compile_function(name, new_def, defaults, std::move(rcb));
});

m.def(
36 changes: 36 additions & 0 deletions torch/csrc/jit/script/python_sugared_value.cpp
@@ -239,6 +239,35 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
<< err.str();
}

std::shared_ptr<SugaredValue> OverloadedFunctionValue::call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> inputs_,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
std::stringstream failure_messages;
for (bool allow_conversions : {false, true}) {
// clear previous error messages
failure_messages.str("");
for (const auto& compiled_overload : compiled_overloads_) {
const auto matched_schema = tryMatchSchema(
compiled_overload.function_->getSchema(),
loc,
*caller.graph(),
c10::nullopt,
inputs_,
attributes,
&failure_messages,
allow_conversions);
if (matched_schema) {
return FunctionValue(compiled_overload)
.call(loc, caller, inputs_, attributes, n_binders);
}
}
}
throw ErrorReport(loc) << failure_messages.str();
}

std::shared_ptr<SugaredValue> ModuleValue::attr(
const SourceRange& loc,
Function& m,
@@ -544,6 +573,13 @@ std::shared_ptr<SugaredValue> toSugaredValue(

py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
if (py::cast<bool>(isFunction)) {
auto overloads =
py::module::import("torch.jit").attr("_get_overloads")(obj);
if (!overloads.is_none()) {
auto compiled_fns = py::cast<std::vector<StrongFunctionPtr>>(overloads);
return std::make_shared<OverloadedFunctionValue>(std::move(compiled_fns));
}

auto compiled_fn =
py::module::import("torch.jit").attr("_try_compile_fn")(obj, loc);
if (auto callee = as_function(compiled_fn)) {
19 changes: 19 additions & 0 deletions torch/csrc/jit/script/python_sugared_value.h
@@ -121,6 +121,25 @@ struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue {
std::vector<std::string> method_names_;
};

struct VISIBILITY_HIDDEN OverloadedFunctionValue : public SugaredValue {
OverloadedFunctionValue(std::vector<StrongFunctionPtr> compiled_overloads)
: compiled_overloads_(std::move(compiled_overloads)) {}

std::string kind() const override {
return "overloaded function";
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& caller,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override;

private:
std::vector<StrongFunctionPtr> compiled_overloads_;
};

// defines how modules/methods behave inside the script subset.
// for now this does not have any interaction with python.
// in the future, we will add the ability to resolve `self.foo` to python
12 changes: 7 additions & 5 deletions torch/csrc/jit/script/python_tree_views.cpp
@@ -144,11 +144,13 @@ void initTreeViewBindings(PyObject* module) {
py::class_<Stmt, TreeView>(m, "Stmt") // NOLINT(bugprone-unused-raii)
.def(py::init([](const TreeView& thing) { return Stmt(thing.get()); }));
py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
py::class_<Def, TreeView>(m, "Def").def(
py::init([](const Ident& name, Decl decl, std::vector<Stmt> body) {
const auto& r = name.range();
return Def::create(r, name, decl, wrap_list(r, std::move(body)));
}));
py::class_<Def, TreeView>(m, "Def")
.def(py::init(
[](const Ident& name, const Decl& decl, std::vector<Stmt> body) {
const auto& r = name.range();
return Def::create(r, name, decl, wrap_list(r, std::move(body)));
}))
.def("decl", [](const Def& def) { return def.decl(); });
py::class_<ClassDef, TreeView>(m, "ClassDef")
.def(py::init([](const Ident& name, std::vector<Stmt> body) {
const auto& r = name.range();
3 changes: 3 additions & 0 deletions torch/csrc/jit/script/tree_views.h
@@ -398,6 +398,9 @@ struct Def : public TreeView {
auto new_ident = Ident::create(name().range(), std::move(new_name));
return create(range(), new_ident, decl(), statements());
}
Def withDecl(Decl decl) const {
return create(range(), name(), decl, statements());
}
Ident name() const {
return Ident(subtree(0));
}