Skip to content

Commit

Permalink
[mypyc] Add support for dynamically registering singledispatch functi…
Browse files Browse the repository at this point in the history
…ons (python#10968)

Instead of generating a regular native function for singledispatch
functions, generate a callable class. That gives us a place to put the
registry dict (instead of storing the registry dict in a static
variable) and also will allow us to support dynamically registering
functions by adding a register method to that callable class.
  • Loading branch information
pranavrajpal authored Aug 11, 2021
1 parent 0bde4b9 commit 0dcd2c6
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 73 deletions.
75 changes: 51 additions & 24 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@

from mypyc.ir.ops import (
BasicBlock, Value, Register, Return, SetAttr, Integer, GetAttr, Branch, InitStatic,
LoadAddress, LoadLiteral, Unbox, Unreachable, LoadStatic,
LoadAddress, LoadLiteral, Unbox, Unreachable,
)
from mypyc.ir.rtypes import (
object_rprimitive, RInstance, object_pointer_rprimitive, dict_rprimitive, int_rprimitive,
bool_rprimitive,
)
from mypyc.ir.func_ir import (
FuncIR, FuncSignature, RuntimeArg, FuncDecl, FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FUNC_NORMAL
)
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.primitives.generic_ops import py_setattr_op, next_raw_op, iter_op
from mypyc.primitives.misc_ops import (
check_stop_op, yield_from_except_op, coro_op, send_op
check_stop_op, yield_from_except_op, coro_op, send_op, register_function
)
from mypyc.primitives.dict_ops import dict_set_item_op
from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op
from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name
from mypyc.sametype import is_same_method_signature
from mypyc.irbuild.util import is_constant
Expand Down Expand Up @@ -88,9 +89,8 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
dec.func.name,
builder.mapper.fdef_to_sig(dec.func)
)

if dec.func in builder.nested_fitems:
assert func_reg is not None
decorated_func: Optional[Value] = None
if func_reg:
decorated_func = load_decorated_func(builder, dec.func, func_reg)
builder.assign(get_func_target(builder, dec.func), decorated_func, dec.func.line)
func_reg = decorated_func
Expand All @@ -106,6 +106,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
orig_func = builder.load_global_str(helper_name, dec.line)
decorated_func = load_decorated_func(builder, dec.func, orig_func)

if decorated_func is not None:
# Set the callable object representing the decorated function as a global.
builder.call_c(dict_set_item_op,
[builder.load_globals_dict(),
Expand Down Expand Up @@ -333,8 +334,7 @@ def c() -> None:
# create the dispatch function
assert isinstance(fitem, FuncDef)
dispatch_name = decorator_helper_name(name) if is_decorated else name
dispatch_func_ir = gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig)
return dispatch_func_ir, None
return gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig)

return (func_ir, func_reg)

Expand Down Expand Up @@ -847,7 +847,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line
)
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
builder.add(Return(coerced))

registry = load_singledispatch_registry(builder, fitem, line)

Expand Down Expand Up @@ -897,7 +897,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
impl_to_use, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names
)
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.nonlocal_control[-1].gen_return(builder, coerced, line)
builder.add(Return(coerced))


def gen_dispatch_func_ir(
Expand All @@ -906,22 +906,51 @@ def gen_dispatch_func_ir(
main_func_name: str,
dispatch_name: str,
sig: FuncSignature,
) -> FuncIR:
) -> Tuple[FuncIR, Value]:
"""Create a dispatch function (a function that checks the first argument type and dispatches
to the correct implementation)
"""
builder.enter()
builder.enter(FuncInfo(fitem, dispatch_name))
setup_callable_class(builder)
builder.fn_info.callable_class.ir.attributes['registry'] = dict_rprimitive
builder.fn_info.callable_class.ir.has_dict = True
generate_singledispatch_callable_class_ctor(builder)

generate_singledispatch_dispatch_function(builder, main_func_name, fitem)
args, _, blocks, _, fn_info = builder.leave()
func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig)
dispatch_func_ir = FuncIR(func_decl, args, blocks)
return dispatch_func_ir
dispatch_func_ir = add_call_to_callable_class(builder, args, blocks, sig, fn_info)
add_get_to_callable_class(builder, fn_info)
add_register_method_to_callable_class(builder, fn_info)
func_reg = instantiate_callable_class(builder, fn_info)

return dispatch_func_ir, func_reg


def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None:
"""Create an __init__ that sets registry to an empty dict"""
line = -1
class_ir = builder.fn_info.callable_class.ir
builder.enter_method(class_ir, '__init__', bool_rprimitive)
empty_dict = builder.call_c(dict_new_op, [], line)
builder.add(SetAttr(builder.self(), 'registry', empty_dict, line))
# the generated C code seems to expect that __init__ returns a char, so just return 1
builder.add(Return(Integer(1, bool_rprimitive, line), line))
builder.leave_method()


def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
line = -1
builder.enter_method(fn_info.callable_class.ir, 'register', object_rprimitive)
cls_arg = builder.add_argument('cls', object_rprimitive)
func_arg = builder.add_argument('func', object_rprimitive, ArgKind.ARG_OPT)
ret_val = builder.call_c(register_function, [builder.self(), cls_arg, func_arg], line)
builder.add(Return(ret_val, line))
builder.leave_method()


def load_singledispatch_registry(builder: IRBuilder, fitem: FuncDef, line: int) -> Value:
name = get_registry_identifier(fitem)
module_name = fitem.fullname.rsplit('.', maxsplit=1)[0]
return builder.add(LoadStatic(dict_rprimitive, name, module_name, line=line))
dispatch_func_obj = load_func(builder, fitem.name, fitem.fullname, line)
return builder.builder.get_attr(dispatch_func_obj, 'registry', dict_rprimitive, line)


def singledispatch_main_func_name(orig_name: str) -> str:
Expand Down Expand Up @@ -952,12 +981,10 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
loaded_object_type = builder.load_module_attr_by_fullname('builtins.object', line)
registry_dict = builder.builder.make_dict([(loaded_object_type, main_func_obj)], line)

name = get_registry_identifier(fitem)
# HACK: reuse the final_names list to make sure that the registry dict gets declared as a
# static variable in the backend, even though this isn't a final variable
builder.final_names.append((name, dict_rprimitive))
init_static = InitStatic(registry_dict, name, builder.module_name, line=line)
builder.add(init_static)
dispatch_func_obj = builder.load_global_str(fitem.name, line)
builder.call_c(
py_setattr_op, [dispatch_func_obj, builder.load_str('registry'), registry_dict], line
)

for singledispatch_func, types in to_register.items():
# TODO: avoid recomputing the native IDs for all the functions every time we find a new
Expand Down
4 changes: 4 additions & 0 deletions mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ def visit_decorator(self, dec: Decorator) -> None:
else:
if refers_to_fullname(d, 'functools.singledispatch'):
decorators_to_remove.append(i)
# make sure that we still treat the function as a singledispatch function
# even if we don't find any registered implementations (which might happen
# if all registered implementations are registered dynamically)
self.singledispatch_impls.setdefault(dec.func, [])
last_non_register = i

if decorators_to_remove:
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,8 @@ PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *o
PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
PyObject *import_name, PyObject *as_name);

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);
#ifdef __cplusplus
}
#endif
Expand Down
70 changes: 70 additions & 0 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,73 @@ CPy_CallReverseOpMethod(PyObject *left,
Py_DECREF(m);
return result;
}

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func,
PyObject *cls,
PyObject *func) {
PyObject *registry = PyObject_GetAttrString(singledispatch_func, "registry");
PyObject *register_func = NULL;
PyObject *typing = NULL;
PyObject *get_type_hints = NULL;
PyObject *type_hints = NULL;

if (registry == NULL) goto fail;
if (func == NULL) {
// one argument case
if (PyType_Check(cls)) {
// passed a class
// bind cls to the first argument so that register gets called again with both the
// class and the function
register_func = PyObject_GetAttrString(singledispatch_func, "register");
if (register_func == NULL) goto fail;
return PyMethod_New(register_func, cls);
}
// passed a function
PyObject *annotations = PyFunction_GetAnnotations(cls);
const char *invalid_first_arg_msg =
"Invalid first argument to `register()`: %R. "
"Use either `@register(some_class)` or plain `@register` "
"on an annotated function.";

if (annotations == NULL) {
PyErr_Format(PyExc_TypeError, invalid_first_arg_msg, cls);
goto fail;
}

Py_INCREF(annotations);

func = cls;
typing = PyImport_ImportModule("typing");
if (typing == NULL) goto fail;
get_type_hints = PyObject_GetAttrString(typing, "get_type_hints");

type_hints = PyObject_CallOneArg(get_type_hints, func);
PyObject *argname;
Py_ssize_t pos = 0;
if (!PyDict_Next(type_hints, &pos, &argname, &cls)) {
// the functools implementation raises the same type error if annotations is an empty dict
PyErr_Format(PyExc_TypeError, invalid_first_arg_msg, cls);
goto fail;
}
if (!PyType_Check(cls)) {
const char *invalid_annotation_msg = "Invalid annotation for %R. %R is not a class.";
PyErr_Format(PyExc_TypeError, invalid_annotation_msg, argname, cls);
goto fail;
}
}
if (PyDict_SetItem(registry, cls, func) == -1) {
goto fail;
}

Py_INCREF(func);
return func;

fail:
Py_XDECREF(registry);
Py_XDECREF(register_func);
Py_XDECREF(typing);
Py_XDECREF(get_type_hints);
Py_XDECREF(type_hints);
return NULL;

}
9 changes: 9 additions & 0 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,12 @@
return_type=c_int_rprimitive,
c_function_name='CPySequence_CheckUnpackCount',
error_kind=ERR_NEG_INT)


# register an implementation for a singledispatch function
register_function = custom_op(
arg_types=[object_rprimitive, object_rprimitive, object_rprimitive],
return_type=object_rprimitive,
c_function_name='CPySingledispatch_RegisterFunction',
error_kind=ERR_MAGIC,
)
30 changes: 19 additions & 11 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -2972,10 +2972,14 @@ def __mypyc_c_decorator_helper__():
r8 :: str
r9, r10 :: object
r11 :: bool
r12 :: str
r13 :: object
r14 :: str
r15, r16, r17, r18 :: object
r12 :: dict
r13 :: str
r14 :: int32
r15 :: bit
r16 :: str
r17 :: object
r18 :: str
r19, r20, r21, r22 :: object
L0:
r0 = __mypyc_c_decorator_helper___env()
r1 = __mypyc_d_decorator_helper_____mypyc_c_decorator_helper___obj()
Expand All @@ -2989,13 +2993,17 @@ L0:
r9 = CPyDict_GetItem(r7, r8)
r10 = PyObject_CallFunctionObjArgs(r9, r6, 0)
r0.d = r10; r11 = is_error
r12 = 'c'
r13 = builtins :: module
r14 = 'print'
r15 = CPyObject_GetAttr(r13, r14)
r16 = PyObject_CallFunctionObjArgs(r15, r12, 0)
r17 = r0.d
r18 = PyObject_CallFunctionObjArgs(r17, 0)
r12 = __main__.globals :: static
r13 = 'd'
r14 = CPyDict_SetItem(r12, r13, r10)
r15 = r14 >= 0 :: signed
r16 = 'c'
r17 = builtins :: module
r18 = 'print'
r19 = CPyObject_GetAttr(r17, r18)
r20 = PyObject_CallFunctionObjArgs(r19, r16, 0)
r21 = r0.d
r22 = PyObject_CallFunctionObjArgs(r21, 0)
return 1
def __top_level__():
r0, r1 :: object
Expand Down
Loading

0 comments on commit 0dcd2c6

Please sign in to comment.