From 0dcd2c636e28d27473c084e5148e7dfb86886831 Mon Sep 17 00:00:00 2001 From: pranavrajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Tue, 10 Aug 2021 20:01:03 -0700 Subject: [PATCH] [mypyc] Add support for dynamically registering singledispatch functions (#10968) Instead of generating a regular native function for singledispatch functions, generate a callable class. That gives us a place to put the registry dict (instead of storing the registry dict in a static variable) and also will allow us to support dynamically registering functions by adding a register method to that callable class. --- mypyc/irbuild/function.py | 75 +++++++++----- mypyc/irbuild/prepare.py | 4 + mypyc/lib-rt/CPy.h | 2 + mypyc/lib-rt/misc_ops.c | 70 +++++++++++++ mypyc/primitives/misc_ops.py | 9 ++ mypyc/test-data/irbuild-basic.test | 30 ++++-- mypyc/test-data/irbuild-singledispatch.test | 107 +++++++++++++------- mypyc/test-data/run-singledispatch.test | 58 ++++++++++- 8 files changed, 282 insertions(+), 73 deletions(-) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index b2d194d5fab0..0a040ba10712 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -22,10 +22,11 @@ from mypyc.ir.ops import ( BasicBlock, Value, Register, Return, SetAttr, Integer, GetAttr, Branch, InitStatic, - LoadAddress, LoadLiteral, Unbox, Unreachable, LoadStatic, + LoadAddress, LoadLiteral, Unbox, Unreachable, ) from mypyc.ir.rtypes import ( object_rprimitive, RInstance, object_pointer_rprimitive, dict_rprimitive, int_rprimitive, + bool_rprimitive, ) from mypyc.ir.func_ir import ( FuncIR, FuncSignature, RuntimeArg, FuncDecl, FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FUNC_NORMAL @@ -33,9 +34,9 @@ from mypyc.ir.class_ir import ClassIR, NonExtClassInfo from mypyc.primitives.generic_ops import py_setattr_op, next_raw_op, iter_op from mypyc.primitives.misc_ops import ( - check_stop_op, yield_from_except_op, coro_op, send_op + check_stop_op, yield_from_except_op, coro_op, send_op, register_function ) -from mypyc.primitives.dict_ops import dict_set_item_op +from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op from mypyc.common import SELF_NAME, LAMBDA_NAME, decorator_helper_name from mypyc.sametype import is_same_method_signature from mypyc.irbuild.util import is_constant @@ -88,9 +89,8 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: dec.func.name, builder.mapper.fdef_to_sig(dec.func) ) - - if dec.func in builder.nested_fitems: - assert func_reg is not None + decorated_func: Optional[Value] = None + if func_reg: decorated_func = load_decorated_func(builder, dec.func, func_reg) builder.assign(get_func_target(builder, dec.func), decorated_func, dec.func.line) func_reg = decorated_func @@ -106,6 +106,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: orig_func = builder.load_global_str(helper_name, dec.line) decorated_func = load_decorated_func(builder, dec.func, orig_func) + if decorated_func is not None: # Set the callable object representing the decorated function as a global. builder.call_c(dict_set_item_op, [builder.load_globals_dict(), @@ -333,8 +334,7 @@ def c() -> None: # create the dispatch function assert isinstance(fitem, FuncDef) dispatch_name = decorator_helper_name(name) if is_decorated else name - dispatch_func_ir = gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig) - return dispatch_func_ir, None + return gen_dispatch_func_ir(builder, fitem, fn_info.name, dispatch_name, sig) return (func_ir, func_reg) @@ -847,7 +847,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: func_decl, arg_info.args, arg_info.arg_kinds, arg_info.arg_names, line ) coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) - builder.nonlocal_control[-1].gen_return(builder, coerced, line) + builder.add(Return(coerced)) registry = load_singledispatch_registry(builder, fitem, line) @@ -897,7 +897,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: impl_to_use, arg_info.args, line, arg_info.arg_kinds, arg_info.arg_names ) coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) - builder.nonlocal_control[-1].gen_return(builder, coerced, line) + builder.add(Return(coerced)) def gen_dispatch_func_ir( @@ -906,22 +906,51 @@ def gen_dispatch_func_ir( main_func_name: str, dispatch_name: str, sig: FuncSignature, -) -> FuncIR: +) -> Tuple[FuncIR, Value]: """Create a dispatch function (a function that checks the first argument type and dispatches to the correct implementation) """ - builder.enter() + builder.enter(FuncInfo(fitem, dispatch_name)) + setup_callable_class(builder) + builder.fn_info.callable_class.ir.attributes['registry'] = dict_rprimitive + builder.fn_info.callable_class.ir.has_dict = True + generate_singledispatch_callable_class_ctor(builder) + generate_singledispatch_dispatch_function(builder, main_func_name, fitem) args, _, blocks, _, fn_info = builder.leave() - func_decl = FuncDecl(dispatch_name, None, builder.module_name, sig) - dispatch_func_ir = FuncIR(func_decl, args, blocks) - return dispatch_func_ir + dispatch_func_ir = add_call_to_callable_class(builder, args, blocks, sig, fn_info) + add_get_to_callable_class(builder, fn_info) + add_register_method_to_callable_class(builder, fn_info) + func_reg = instantiate_callable_class(builder, fn_info) + + return dispatch_func_ir, func_reg + + +def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None: + """Create an __init__ that sets registry to an empty dict""" + line = -1 + class_ir = builder.fn_info.callable_class.ir + builder.enter_method(class_ir, '__init__', bool_rprimitive) + empty_dict = builder.call_c(dict_new_op, [], line) + builder.add(SetAttr(builder.self(), 'registry', empty_dict, line)) + # the generated C code seems to expect that __init__ returns a char, so just return 1 + builder.add(Return(Integer(1, bool_rprimitive, line), line)) + builder.leave_method() + + +def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None: + line = -1 + builder.enter_method(fn_info.callable_class.ir, 'register', object_rprimitive) + cls_arg = builder.add_argument('cls', object_rprimitive) + func_arg = builder.add_argument('func', object_rprimitive, ArgKind.ARG_OPT) + ret_val = builder.call_c(register_function, [builder.self(), cls_arg, func_arg], line) + builder.add(Return(ret_val, line)) + builder.leave_method() def load_singledispatch_registry(builder: IRBuilder, fitem: FuncDef, line: int) -> Value: - name = get_registry_identifier(fitem) - module_name = fitem.fullname.rsplit('.', maxsplit=1)[0] - return builder.add(LoadStatic(dict_rprimitive, name, module_name, line=line)) + dispatch_func_obj = load_func(builder, fitem.name, fitem.fullname, line) + return builder.builder.get_attr(dispatch_func_obj, 'registry', dict_rprimitive, line) def singledispatch_main_func_name(orig_name: str) -> str: @@ -952,12 +981,10 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: loaded_object_type = builder.load_module_attr_by_fullname('builtins.object', line) registry_dict = builder.builder.make_dict([(loaded_object_type, main_func_obj)], line) - name = get_registry_identifier(fitem) - # HACK: reuse the final_names list to make sure that the registry dict gets declared as a - # static variable in the backend, even though this isn't a final variable - builder.final_names.append((name, dict_rprimitive)) - init_static = InitStatic(registry_dict, name, builder.module_name, line=line) - builder.add(init_static) + dispatch_func_obj = builder.load_global_str(fitem.name, line) + builder.call_c( + py_setattr_op, [dispatch_func_obj, builder.load_str('registry'), registry_dict], line + ) for singledispatch_func, types in to_register.items(): # TODO: avoid recomputing the native IDs for all the functions every time we find a new diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index 5b6eb6b458b3..d34bc6348c75 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -368,6 +368,10 @@ def visit_decorator(self, dec: Decorator) -> None: else: if refers_to_fullname(d, 'functools.singledispatch'): decorators_to_remove.append(i) + # make sure that we still treat the function as a singledispatch function + # even if we don't find any registered implementations (which might happen + # if all registered implementations are registered dynamically) + self.singledispatch_impls.setdefault(dec.func, []) last_non_register = i if decorators_to_remove: diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 18a1ce76f26a..d54d60ef2a11 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -558,6 +558,8 @@ PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *o PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name, PyObject *import_name, PyObject *as_name); +PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls, + PyObject *func); #ifdef __cplusplus } #endif diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 7dfc6bb0f909..f301fa874211 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -708,3 +708,73 @@ CPy_CallReverseOpMethod(PyObject *left, Py_DECREF(m); return result; } + +PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, + PyObject *cls, + PyObject *func) { + PyObject *registry = PyObject_GetAttrString(singledispatch_func, "registry"); + PyObject *register_func = NULL; + PyObject *typing = NULL; + PyObject *get_type_hints = NULL; + PyObject *type_hints = NULL; + + if (registry == NULL) goto fail; + if (func == NULL) { + // one argument case + if (PyType_Check(cls)) { + // passed a class + // bind cls to the first argument so that register gets called again with both the + // class and the function + register_func = PyObject_GetAttrString(singledispatch_func, "register"); + if (register_func == NULL) goto fail; + return PyMethod_New(register_func, cls); + } + // passed a function + PyObject *annotations = PyFunction_GetAnnotations(cls); + const char *invalid_first_arg_msg = + "Invalid first argument to `register()`: %R. " + "Use either `@register(some_class)` or plain `@register` " + "on an annotated function."; + + if (annotations == NULL) { + PyErr_Format(PyExc_TypeError, invalid_first_arg_msg, cls); + goto fail; + } + + Py_INCREF(annotations); + + func = cls; + typing = PyImport_ImportModule("typing"); + if (typing == NULL) goto fail; + get_type_hints = PyObject_GetAttrString(typing, "get_type_hints"); + + type_hints = PyObject_CallOneArg(get_type_hints, func); + PyObject *argname; + Py_ssize_t pos = 0; + if (!PyDict_Next(type_hints, &pos, &argname, &cls)) { + // the functools implementation raises the same type error if annotations is an empty dict + PyErr_Format(PyExc_TypeError, invalid_first_arg_msg, cls); + goto fail; + } + if (!PyType_Check(cls)) { + const char *invalid_annotation_msg = "Invalid annotation for %R. %R is not a class."; + PyErr_Format(PyExc_TypeError, invalid_annotation_msg, argname, cls); + goto fail; + } + } + if (PyDict_SetItem(registry, cls, func) == -1) { + goto fail; + } + + Py_INCREF(func); + return func; + +fail: + Py_XDECREF(registry); + Py_XDECREF(register_func); + Py_XDECREF(typing); + Py_XDECREF(get_type_hints); + Py_XDECREF(type_hints); + return NULL; + +} diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 133781dc8eac..cfdbb8a0f78d 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -214,3 +214,12 @@ return_type=c_int_rprimitive, c_function_name='CPySequence_CheckUnpackCount', error_kind=ERR_NEG_INT) + + +# register an implementation for a singledispatch function +register_function = custom_op( + arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name='CPySingledispatch_RegisterFunction', + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 070bf7e333c8..57608c0fd0db 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -2972,10 +2972,14 @@ def __mypyc_c_decorator_helper__(): r8 :: str r9, r10 :: object r11 :: bool - r12 :: str - r13 :: object - r14 :: str - r15, r16, r17, r18 :: object + r12 :: dict + r13 :: str + r14 :: int32 + r15 :: bit + r16 :: str + r17 :: object + r18 :: str + r19, r20, r21, r22 :: object L0: r0 = __mypyc_c_decorator_helper___env() r1 = __mypyc_d_decorator_helper_____mypyc_c_decorator_helper___obj() @@ -2989,13 +2993,17 @@ L0: r9 = CPyDict_GetItem(r7, r8) r10 = PyObject_CallFunctionObjArgs(r9, r6, 0) r0.d = r10; r11 = is_error - r12 = 'c' - r13 = builtins :: module - r14 = 'print' - r15 = CPyObject_GetAttr(r13, r14) - r16 = PyObject_CallFunctionObjArgs(r15, r12, 0) - r17 = r0.d - r18 = PyObject_CallFunctionObjArgs(r17, 0) + r12 = __main__.globals :: static + r13 = 'd' + r14 = CPyDict_SetItem(r12, r13, r10) + r15 = r14 >= 0 :: signed + r16 = 'c' + r17 = builtins :: module + r18 = 'print' + r19 = CPyObject_GetAttr(r17, r18) + r20 = PyObject_CallFunctionObjArgs(r19, r16, 0) + r21 = r0.d + r22 = PyObject_CallFunctionObjArgs(r21, 0) return 1 def __top_level__(): r0, r1 :: object diff --git a/mypyc/test-data/irbuild-singledispatch.test b/mypyc/test-data/irbuild-singledispatch.test index 1f1ebdc1f232..cfbef229cc46 100644 --- a/mypyc/test-data/irbuild-singledispatch.test +++ b/mypyc/test-data/irbuild-singledispatch.test @@ -12,52 +12,87 @@ def __mypyc_singledispatch_main_function_f__(arg): arg :: object L0: return 0 -def f(arg): +def f_obj.__init__(__mypyc_self__): + __mypyc_self__ :: __main__.f_obj + r0 :: dict + r1 :: bool +L0: + r0 = PyDict_New() + __mypyc_self__.registry = r0; r1 = is_error + return 1 +def f_obj.__get__(__mypyc_self__, instance, owner): + __mypyc_self__, instance, owner, r0 :: object + r1 :: bit + r2 :: object +L0: + r0 = load_address _Py_NoneStruct + r1 = instance == r0 + if r1 goto L1 else goto L2 :: bool +L1: + return __mypyc_self__ +L2: + r2 = PyMethod_New(__mypyc_self__, instance) + return r2 +def f_obj.register(__mypyc_self__, cls, func): + __mypyc_self__ :: __main__.f_obj + cls, func, r0 :: object +L0: + r0 = CPySingledispatch_RegisterFunction(__mypyc_self__, cls, func) + return r0 +def f_obj.__call__(__mypyc_self__, arg): + __mypyc_self__ :: __main__.f_obj arg :: object r0 :: dict - r1 :: object - r2 :: str - r3 :: object - r4 :: ptr - r5, r6, r7 :: object + r1 :: str + r2 :: object + r3 :: str + r4, r5 :: object + r6 :: str + r7 :: object r8 :: ptr - r9 :: object - r10 :: bit - r11 :: int - r12 :: bit - r13 :: int - r14 :: bool - r15 :: object - r16 :: bool + r9, r10, r11 :: object + r12 :: ptr + r13 :: object + r14 :: bit + r15 :: int + r16 :: bit + r17 :: int + r18 :: bool + r19 :: object + r20 :: bool L0: - r0 = __main__.__mypyc_singledispatch_registry___main__.f__ :: static - r1 = functools :: module - r2 = '_find_impl' - r3 = CPyObject_GetAttr(r1, r2) - r4 = get_element_ptr arg ob_type :: PyObject - r5 = load_mem r4 :: builtins.object* - keep_alive arg - r6 = PyObject_CallFunctionObjArgs(r3, r5, r0, 0) - r7 = load_address PyLong_Type - r8 = get_element_ptr r6 ob_type :: PyObject + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 'registry' + r4 = CPyObject_GetAttr(r2, r3) + r5 = functools :: module + r6 = '_find_impl' + r7 = CPyObject_GetAttr(r5, r6) + r8 = get_element_ptr arg ob_type :: PyObject r9 = load_mem r8 :: builtins.object* - keep_alive r6 - r10 = r9 == r7 - if r10 goto L1 else goto L4 :: bool + keep_alive arg + r10 = PyObject_CallFunctionObjArgs(r7, r9, r4, 0) + r11 = load_address PyLong_Type + r12 = get_element_ptr r10 ob_type :: PyObject + r13 = load_mem r12 :: builtins.object* + keep_alive r10 + r14 = r13 == r11 + if r14 goto L1 else goto L4 :: bool L1: - r11 = unbox(int, r6) - r12 = r11 == 0 - if r12 goto L2 else goto L3 :: bool + r15 = unbox(int, r10) + r16 = r15 == 0 + if r16 goto L2 else goto L3 :: bool L2: - r13 = unbox(int, arg) - r14 = g(r13) - return r14 + r17 = unbox(int, arg) + r18 = g(r17) + return r18 L3: unreachable L4: - r15 = PyObject_CallFunctionObjArgs(r6, arg, 0) - r16 = unbox(bool, r15) - return r16 + r19 = PyObject_CallFunctionObjArgs(r10, arg, 0) + r20 = unbox(bool, r19) + return r20 def g(arg): arg :: int L0: diff --git a/mypyc/test-data/run-singledispatch.test b/mypyc/test-data/run-singledispatch.test index 7d3a1f07316f..de4bc2f649f5 100644 --- a/mypyc/test-data/run-singledispatch.test +++ b/mypyc/test-data/run-singledispatch.test @@ -114,7 +114,7 @@ def test_singledispatch() -> None: assert fun(1) assert not fun('a') -[case testUseRegisterAsAFunction-xfail] +[case testUseRegisterAsAFunction] from functools import singledispatch @singledispatch @@ -145,7 +145,7 @@ def test_singledispatch() -> None: assert fun_specialized('a') # TODO: turn this into a mypy error -[case testNoneIsntATypeWhenUsedAsArgumentToRegister-xfail] +[case testNoneIsntATypeWhenUsedAsArgumentToRegister] from functools import singledispatch @singledispatch @@ -601,3 +601,57 @@ def _(arg: B) -> str: assert f(A()) == 'a' assert f(B()) == 'b' assert f(1) == 'default' + + +[case testDynamicallyRegisteringFunctionFromInterpretedCode] +from functools import singledispatch + +class A: pass +class B(A): pass +class C(B): pass +class D(C): pass + +@singledispatch +def f(arg) -> str: + return "default" + +@f.register +def _(arg: B) -> str: + return 'b' + +[file register_impl.py] +from native import f, A, B, C + +@f.register(A) +def a(arg) -> str: + return 'a' + +@f.register +def c(arg: C) -> str: + return 'c' + +[file driver.py] +from native import f, A, B, C +from register_impl import a, c +# We need a custom driver here because register_impl has to be run before we test this (so that the +# additional implementations are registered) +assert f(C()) == 'c' +assert f(A()) == 'a' +assert f(B()) == 'b' +assert a(C()) == 'a' +assert c(A()) == 'c' + +[case testMalformedDynamicRegisterCall] +from functools import singledispatch + +@singledispatch +def f(arg) -> None: + pass +[file register.py] +from native import f +from testutil import assertRaises + +with assertRaises(TypeError, 'Invalid first argument to `register()`'): + @f.register + def _(): + pass