Use vmap with kwargs #70

Open · xidulu opened this issue Jul 12, 2021 · 4 comments

xidulu commented Jul 12, 2021

Hi

I currently have the following use case:

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(2, 10)
w = torch.randn(2, 5, 10)
b = torch.randn(2, 5)
print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, bias=b).shape)

However, this raises an exception:

ValueError: vmap(linear, in_dims=(0, 0, 0), ...)(<inputs>): in_dims is not compatible with the structure of `inputs`. in_dims has structure TreeSpec(tuple, None, [*, *, *]) but inputs has structure TreeSpec(tuple, None, [*, *]).

I also tried

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(2, 10)
w = torch.randn(2, 5, 10)
b = torch.randn(2, 5)
print(vmap(F.linear, in_dims=(0, 0, {'bias': 0}))(x, w, bias=b).shape)

which also does not work.

What is the correct way to specify in_dims for keyword arguments passed via **kwargs? Or is it the case that vmap's in_dims only accepts positional arguments?

Thanks

zou3519 (Contributor) commented Jul 12, 2021

vmap's in_dims only accepts positional args; this follows the behavior of jax.vmap. For F.linear, is it possible to work around this by passing bias as a positional arg for now?

We probably do want some way for vmap to specify that it is mapping over kwargs, especially because PyTorch operators can have kwarg-only arguments. Your second approach, vmap(F.linear, in_dims=(0, 0, {'bias': 0}))(x, w, bias=b), seems pretty reasonable as an API.
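
A minimal sketch of that positional workaround (the wrapper name linear_with_bias is just illustrative; the shapes reuse the snippet above): the bias kwarg is bound inside the wrapper, so vmap only ever sees positional arguments.

import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(2, 10)
w = torch.randn(2, 5, 10)
b = torch.randn(2, 5)

# vmap maps over three positional args; the kwarg is used only inside the wrapper.
def linear_with_bias(inp, weight, bias):
    return F.linear(inp, weight, bias=bias)

print(vmap(linear_with_bias, in_dims=(0, 0, 0))(x, w, b).shape)  # torch.Size([2, 5])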

xidulu (Author) commented Jul 12, 2021

@zou3519

Passing bias as a positional argument solves the problem:

print(vmap(F.linear, in_dims=(0, 0, 0))(x, w, b).shape)

And you are right... my usage does not seem to be supported by jax either.

BTW, I notice that even if I pass bias as a positional arg, when I step into the function call, bias is still interpreted as a keyword argument (I can see it inside **kwargs).
Is that what you mean by PyTorch operators can have kwarg-only arguments?

If that is the case, I think a tweak to the vmap API is indeed necessary 🤔

zou3519 (Contributor) commented Jul 12, 2021

BTW, I notice that even if I pass bias as a positional arg, when I step into the function call, bias is still interpreted as a keyword argument (I can see it inside **kwargs).

Does this happen when you use __torch_function__? If so, the reason is that bias is being treated as a keyword arg here: https://github.com/pytorch/pytorch/blob/d46689a2017cc046abdc938247048952df4f6de7/torch/nn/functional.py#L1846. I can look into why it actually is this way; I always thought it was weird.
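
For reference, a small probe sketch (the Probe subclass is made up here, and whether bias lands in kwargs depends on the PyTorch version) that prints how F.linear's arguments arrive at __torch_function__:

import torch
import torch.nn.functional as F

class Probe(torch.Tensor):
    # Log how F.linear's arguments arrive at __torch_function__.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            print("positional:", len(args), "keyword:", list(kwargs.keys()))
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(2, 10).as_subclass(Probe)
w = torch.randn(5, 10)
b = torch.randn(5)
F.linear(x, w, b)  # bias was passed positionally, but may still show up under "keyword" above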

Is that what you mean by PyTorch operators can have kwarg-only arguments?

In Python it's possible to define a function with an argument that must be passed as a kwarg (and cannot be passed positionally). There are some examples here: https://python-3-for-scientists.readthedocs.io/en/latest/python3_advanced.html, and some PyTorch operators are written this way. F.linear isn't, though.
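
A minimal illustration of a keyword-only parameter (the scale function is just a toy example, not a PyTorch operator):

# Parameters after the bare * can only be passed by name.
def scale(x, *, factor=2.0):
    return x * factor

scale(3.0, factor=1.5)   # OK
# scale(3.0, 1.5)        # TypeError: scale() takes 1 positional argument but 2 were given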

xidulu (Author) commented Jul 12, 2021

Yes! It happens when using __torch_function__.
And that is indeed a little bit weird.
