Use vmap with kwargs #70
Hi,

I am currently having the following use case: calling vmap on F.linear while passing the batched bias as a keyword argument (a sketch follows below). However, this raises an exception. I also tried a second approach, which also does not work.

I am wondering what the correct way is to specify in_dims for keyword arguments in **kwargs. Or is it the case that vmap's in_dims only accepts positional arguments? Thanks
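The code blocks from the original post were lost in the page capture. Below is a plausible minimal reconstruction of the first attempt, assuming functorch's vmap and batched tensors x, w, b (the variable names used in the comments); the shapes are illustrative and the second attempt is not recoverable:

```python
import torch
import torch.nn.functional as F
from functorch import vmap  # torch.func.vmap in newer PyTorch releases

x = torch.randn(8, 3, 5)  # batch of inputs
w = torch.randn(8, 4, 5)  # batch of weight matrices
b = torch.randn(8, 4)     # batch of biases

# Passing the batched bias as a keyword argument raises an exception,
# because in_dims only lines up with positional arguments.
out = vmap(F.linear, in_dims=(0, 0, 0))(x, w, bias=b)
```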
Comments

We probably do want some way for vmap to specify that it is mapping over kwargs, especially because PyTorch operators can have kwarg-only arguments. Your second approach, (…
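Since a kwargs-aware in_dims does not exist in the current API, a small shim can be layered on top of it. A minimal sketch, assuming functorch's vmap; vmap_with_kwargs and its parameter names are invented here, and it assumes a tuple in_dims and at least one mapped kwarg:

```python
import torch
import torch.nn.functional as F
from functorch import vmap  # torch.func.vmap in newer PyTorch releases


def vmap_with_kwargs(func, in_dims, kwarg_in_dims):
    """Map over keyword arguments by routing them through positional slots."""
    def wrapper(*args, **kwargs):
        names = sorted(kwarg_in_dims)  # mapped kwarg names, in a fixed order
        kw_values = tuple(kwargs.pop(name) for name in names)

        def positional_func(*pos_args):
            # Split the flat positional tuple back into args and mapped kwargs.
            split = len(pos_args) - len(names)
            inner_kwargs = dict(zip(names, pos_args[split:]))
            return func(*pos_args[:split], **inner_kwargs, **kwargs)

        dims = tuple(in_dims) + tuple(kwarg_in_dims[name] for name in names)
        return vmap(positional_func, in_dims=dims)(*args, *kw_values)

    return wrapper


x = torch.randn(8, 3, 5)
w = torch.randn(8, 4, 5)
b = torch.randn(8, 4)
out = vmap_with_kwargs(F.linear, in_dims=(0, 0), kwarg_in_dims={"bias": 0})(x, w, bias=b)
print(out.shape)  # torch.Size([8, 3, 4])
```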
Using bias as a positional argument solves the problem:

`print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, b).shape)`

And you are right... my usage does not seem to be supported by JAX either. BTW, I notice that, even if I am passing bias as a positional arg, when I step into the function call, … If that is the case, I think a tweak to the vmap API is indeed very necessary 🤔
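Expanded into a self-contained snippet (the shapes are assumptions, chosen so that F.linear maps a (3, 5) input and a (4, 5) weight to a (3, 4) output per sample):

```python
import torch
import torch.nn.functional as F
from functorch import vmap  # torch.func.vmap in newer PyTorch releases

x = torch.randn(8, 3, 5)  # batch of inputs
w = torch.randn(8, 4, 5)  # batch of weight matrices
b = torch.randn(8, 4)     # batch of biases

# With bias positional, the tuple form of in_dims covers all three arguments.
print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, b).shape)  # torch.Size([8, 3, 4])
```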
Does this happen when you use `__torch_function__`?
In Python it's possible to define a function with an argument that must be passed as a kwarg (and cannot be passed positionally). There are some examples here: https://python-3-for-scientists.readthedocs.io/en/latest/python3_advanced.html, and some PyTorch operators are written this way (see the sketch below). F.linear isn't, though.
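For illustration, a keyword-only parameter is declared after a bare `*` in the signature (a hypothetical function, not a PyTorch operator):

```python
import torch

# Everything after the bare '*' can only be passed by keyword.
def scaled_add(x, y, *, alpha=1.0):
    return x + alpha * y

scaled_add(torch.ones(2), torch.ones(2), alpha=0.5)   # OK
# scaled_add(torch.ones(2), torch.ones(2), 0.5)       # TypeError
```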
Yes! It happens when using `__torch_function__`.