Decompose CompositeImplicitAutograd ops at the FuncTorchBatched key #56
Yes, this is trouble. I have two parallel thoughts here:
@zou3519 Well, the vmap key has higher precedence than CompositeImplicitAutograd, so yes, that will just work.
If functionalization could take care of this then that would be great. @ailzhang does functionalization handle something like the following?
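A hypothetical illustration of the kind of pattern in question (an invented example, not the snippet from the original comment): a composite-style op whose decomposition builds its result through in-place writes on an intermediate.

```python
import torch

# Hypothetical illustration: a composite-style op whose decomposition
# constructs its output via in-place writes on an intermediate tensor.
def composite_style_op(x):
    out = torch.zeros_like(x)
    out.add_(x)          # in-place write into the intermediate
    out.clamp_(min=0)    # another in-place step in the decomposition
    return out
```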
@ezyang one alternative along the lines of "providing both the non-mutating and mutating versions" could be if we have the ability to define our own set of primitives with respect to autograd, and then register an override for the vmap key for each of those primitives. I'm not sure it's possible to "define a new primitive with respect to autograd" out of tree, though: autograd functions exist, but I'm not sure they're sufficient.
After some experimenting... it looks like if I want to make a new primitive, it is doable, but unfortunately there's a lot of boilerplate involved (e.g. setting up the autograd::Function and registering all of those overrides).
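As a rough, hypothetical Python-level analogue of "defining a new primitive with respect to autograd" (the boilerplate discussed above is the C++ autograd::Function plus per-dispatch-key registrations, which this sketch does not cover), one could wrap a decomposition in a torch.autograd.Function so autograd treats it as a single unit. The op and its backward formula below are stand-ins, not ops from this issue:

```python
import torch

# Hypothetical sketch: make autograd treat a decomposition as one "primitive"
# by giving it its own forward/backward, instead of differentiating through
# the decomposition's constituent ops.
class SiluPrimitive(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        return grad_out * (s + x * s * (1 - s))

def silu_primitive(x):
    return SiluPrimitive.apply(x)
```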
My conception of functionalization is that it is a functional transformation, much like grad/vmap are, which takes traces that have mutations and transforms them into traces without mutation.
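A hand-written before/after sketch of that idea (hypothetical; this is not actual functionalization output, just the shape of the transformation):

```python
import torch

# "Before": a trace that mutates an intermediate in place.
def trace_with_mutation(x):
    y = torch.empty_like(x)
    y.copy_(x)
    y.add_(1)            # mutation on the intermediate
    return y

# "After": the same computation expressed without mutation.
def functionalized_trace(x):
    y = x.clone()
    y = y + 1
    return y
```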
Yes, this is possible. Today we have CPU and we have AutogradCPU; it is possible that given Batched, we should have AutogradBatched (this is a little weird, because Batched isn't a backend, but I'm guessing we probably could make it work). Then you would override the definition of the op at the AutogradBatched key.
We had a meeting on Thursday to discuss this. The main points:
Summary: See pytorch/functorch#56. Pull Request resolved: #63616. Reviewed By: zou3519. Differential Revision: D30438316. Pulled By: Chillee. fbshipit-source-id: e84446d9f68b87daa0cfff75b3b8a972f36ec85a
Done
We decided to revert this for now. Essentially, the "failure mode" for decomposing an op is much worse than for not decomposing it. If we decompose an op and it's then slow or throws an error, the user sees a warning/error that refers to the ops inside the decomposition rather than the op they actually called. So, for now, we think it's better to err on the side of not decomposing an op unless we explicitly choose to do so.
Background
@ezyang suggested trying this to minimize the number of operators we have to override. More concretely, instead of registering all 2000 operators to FuncTorchBatched, we only have to register (insert number here) operators that are not composite w.r.t. autograd.
To be concrete, the suggestion was to add FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32
The experiment
I added FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32, recompiled PyTorch and functorch, and then ran the test suite. This leads to a fun number of failures (see here) that have the same root cause!
The problem is that some CompositeImplicitAutograd ops decompose to in-place operations that are not compatible with vmap (note here).
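A minimal, hypothetical Python-level reproduction of that failure mode (not one of the actual failing ops from the test run): under vmap, an in-place write from a batched tensor into an unbatched intermediate cannot preserve the intermediate's shape, so vmap rejects it.

```python
import torch
from functorch import vmap   # newer PyTorch exposes this as torch.func.vmap

def decomposition_like(x):
    # Pattern resembling an in-place decomposition: write a (batched) value
    # into a freshly created buffer that has no batch dimension.
    out = torch.zeros(x.shape[-1])
    out.add_(x)   # under vmap, x is batched but out is not, so the in-place
    return out    # add cannot be performed without growing out

# Raises a vmap incompatibility error, because the in-place add_ would have
# to change the shape of the unbatched tensor:
# vmap(decomposition_like)(torch.ones(5, 3))
```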
Can we solve these problems by just registering an override for the vmap key for those operations?