Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add argument number dispatch mechanism for std::function casting #5285

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

rath3t
Copy link

@rath3t rath3t commented Aug 3, 2024

Description

The proposed PR partly fixing function overloads where the arguments are std::functions.
It fixes #3035.

The problem appears, if we have at least two functions:

PYBIND11_EMBEDDED_MODULE(func_module, m) {
    m.def("func", [](const std::function<int(int, int)>& f) { return f(2, 3); })
       .def("func", [](const std::function<int(int)>& f) { return f(2); });
}

and we want to call them in Python

import func_module
def f(a):
    return a*2

func_module.func(f)

The function f is cast to the first overload, which takes two arguments.
Then the call fails with

f() takes 1 positional argument but 2 were given

Calling it from c++ with

auto f = std::function([](int x) { return 2 * x; });
auto m = py::module_::import("func_module");
m.attr("func")(f);

fails with

)m.attr(" func ")(f).cast<int>() failed with: \nTypeError: (): incompatible function arguments. The following argument types are supported:
    1. (arg0: int) -> int

Invoked with: 2, 3

This is fixed with this PR.

In Compiler Explorer I created the failing example, see if you want https://godbolt.org/z/c4Y9bTxTE

Suggested changelog entry:

Number of arguments are now checked for casting `std::function`.

@rath3t rath3t force-pushed the fix/stdFunctioncast branch from 87b9bb2 to 7b9d61f Compare August 4, 2024 05:32
@rwgk
Copy link
Collaborator

rwgk commented Aug 7, 2024

I'm wondering about the tradeoff: extra-code-complexity vs. usability-benefit

For example:

    m.def("func", [](const std::function<int(int, int)>& f) { return f(2, 3); })
       .def("func", [](const std::function<int(int)>& f) { return f(2); });

This could simply become something like:

    m.def("handle_point", [](const std::function<int(int, int)>& f) { return f(2, 3); })
       .def("handle_scalar", [](const std::function<int(int)>& f) { return f(2); });

Meaningful names will make the client code more obvious and readable.

When is it important that the names are identical?

@rath3t
Copy link
Author

rath3t commented Aug 10, 2024

I think you are right and this is one workaround. We used it in our code but we changed the code to just accept a py::object and inspect it inside C++ using the inspect module. I didn't like this boilerplate everywhere in my library.

I have more complicated types as arguments, where it does not make much sense to change the function name, e.g.
in my use case I have a C++ API, which I cannot change and it would be nice to have the same function names for the Python bindings to not confuse users.
For C++, IMHO this kind of function overloading is so common, the fact that it is not supported, confuses Python bindings writers, which come from the C++ world.
For me this use case is not different to just having a Python tuple converted to an std::tuple, where this dispatch works, doesn't it?

To stay in the example I don't want to write

m.def("handle_point2D", [](const std::function<int(int, int)>& f) { return f(2, 3); })

m.def("handle_point3D", [](const std::function<int(int, int, int)>& f) { return f(2, 3, 4); })

but I just want to have a single function handle_point, that works for 2D and 3D.
This dispatch could also be done within Python by providing a wrapper fun tion but I think it is clear what I mean.

So I think the dispatch should be added to really dispatch this case or at least to show an exception that really states the problem, since for me it also took a while to figure out, why it always tries to call the wrong overload.

If this is not added it would be nice to explicitly state it in the documentation, that this is not supported.

@rwgk
Copy link
Collaborator

rwgk commented Aug 11, 2024

but I just want to have a single function handle_point, that works for 2D and 3D.

Understood, and if the implementation was easier, I'd say fine, why not.

I'm worried about relying on __code__ and __call__ attribute lookups (runtime overhead; also might be fragile in the future).

I haven't looked in great detail, but I'm pretty sure the error handling is incomplete. This will add to the existing code.

So it comes down to judgement, cost vs benefit. I think it'll be "too much code" for "a relative niche feature".

But I'm not the only one here: @henryiii @EthanSteinberg @Skylion007 What's your opinion?

@EthanSteinberg
Copy link
Collaborator

I think this argument number dispatch is a good feature that will make it easier to write pythonic bindings (as python code generally leans in the direction of functions doing type dispatching).

I am a bit worried about runtime overhead though. At minimum, we probably only want to do this check if there is more than one overload.

Let me give this a proper review

Copy link
Collaborator

@EthanSteinberg EthanSteinberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work! Just a couple of minor things that I think need to get fixed.

@@ -75,6 +84,36 @@ struct type_caster<std::function<Return(Args...)>> {
// PYPY segfaults here when passing builtin function like sum.
// Raising an fail exception here works to prevent the segfault, but only on gcc.
// See PR #1413 for full details
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To try to minimize the runtime overhead, I would consider adding a rec->next check here to only trigger this if there is a another possible overload to try.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to understand your comment but I'm afraid I don't know what you are suggesting. Can you please elaborate?

// Check number of arguments of Python function
auto getArgCount = [&](PyObject *obj) {
// This is faster then doing import inspect and inspect.signature(obj).parameters
auto *t = PyObject_GetAttrString(obj, "__code__");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause a memory leak as GetAttrString returns a new reference, but you never Py_DECREF?

Why not use pybind11's object wrappers to make sure you never leak memory? (https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv46object)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are right I used PyObject_GetAttrString and .attr where i don't have to check if the attribute.

};
long argCount = -1;

if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_cast makes this code look more confusing than it should be. Either just have an if statement directly on HasAttrString or explicitly do if (has_attr_string == 1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd try to only use PyObject_GetAttrString() here and not PyObject_HasAttrString() at all (unless you really never actually need the attribute).

I looked at this in depth in some other context (a couple years ago probably, but I'd bet it's still the same) and found that PyObject_HasAttrString() uses PyObject_GetAttrString() and just throws the object away.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General: when you feel you're done, I'd review this code do see what can be factored out to a non-templated inline function (or functions). That might help the compiler a little, but is also nice for humans reading the code.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I used PyObject_GetAttrString and .attr, where I'm sure that the attribute is existing, since there I didn't want to write try ... catch around .attr for performance reasons.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I found out how to use getattr. I hope this is the way to go.

@@ -103,6 +103,20 @@ def test_cpp_callable_cleanup():
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]


def test_cpp_correct_overload_resolution():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add a class test (that uses def call) here as well?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

while (rec != nullptr) {
const int correctingSelfArgument = rec->is_method ? 1 : 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self_offset preferred as variable name.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

while (rec != nullptr) {
const int correctingSelfArgument = rec->is_method ? 1 : 0;
if (rec->nargs - correctingSelfArgument != sizeof...(Args)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rec->nargs != sizeof...(Args) + 1

to avoid any doubts/tooling noise around unsigned 0 - 1. (probably doesn't matter, but super easy general defensive approach)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it reads:

if (rec->nargs != sizeof...(Args) + self_offset)

I assume you indicated like that?

@rath3t rath3t force-pushed the fix/stdFunctioncast branch from 5f66fd1 to febc0b1 Compare August 20, 2024 09:42
@rath3t
Copy link
Author

rath3t commented Aug 20, 2024

Thanks for your comments. I tried to refactor accordingly.

@rath3t rath3t force-pushed the fix/stdFunctioncast branch 3 times, most recently from 2187294 to 9768d5f Compare August 20, 2024 09:54
@rath3t rath3t force-pushed the fix/stdFunctioncast branch 11 times, most recently from ddc2aaa to ec84c33 Compare August 21, 2024 09:57
@rath3t
Copy link
Author

rath3t commented Aug 21, 2024

Ok actually I'm wittnessing some problems in different CI configurations. I will let you know when this is ready again

@rath3t rath3t force-pushed the fix/stdFunctioncast branch from 8a3896e to 4212339 Compare August 21, 2024 10:04
@rath3t rath3t force-pushed the fix/stdFunctioncast branch from 375db6a to e0be5db Compare August 21, 2024 10:16
@rath3t
Copy link
Author

rath3t commented Aug 21, 2024

Ok now this seems ready again.

@rwgk
Copy link
Collaborator

rwgk commented Aug 21, 2024

(Thanks for all the work. I'll look asap, might take me a couple days.)

@rath3t
Copy link
Author

rath3t commented Dec 18, 2024

Can I help with something here to continue with this PR? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[QUESTION] disambiguating lambda arguments in function overloads
3 participants