Decompose CompositeImplicitAutograd ops at the FuncTorchBatched key #56

Open
zou3519 opened this issue Jun 15, 2021 · 7 comments
zou3519 (Contributor) commented Jun 15, 2021

Background

@ezyang suggested trying this to minimize the number of operators we have to override. More concretely, instead of registering all 2000 operators to FuncTorchBatched, we would only have to register (insert number here) operators that are not composite w.r.t. autograd.

To be concrete, the suggestion was to add FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32

The experiment

I added FuncTorchBatched to https://github.com/pytorch/pytorch/blob/8dd0570b34c7c378ae9729c21267546cba07fdc9/c10/core/DispatchKeySet.cpp#L28-L32, recompiled PyTorch and functorch, and then ran the test suite. This led to a fun number of failures (see here), all with the same root cause!

The problem is that some CompositeImplicitAutograd ops decompose to in-place operations that are not compatible with vmap (note here).
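To illustrate the incompatible pattern (a minimal, hypothetical sketch rather than one of the actual failing decompositions): an in-place write from a vmapped value into an unbatched tensor has no well-defined per-example semantics, so vmap rejects it.

import torch
from functorch import vmap  # functorch-era API; torch.func.vmap today

out = torch.zeros(4)

def f(x):
    # In-place write of a batched value into an unbatched buffer:
    # there is no per-example `out` to write into, so vmap errors out.
    out.copy_(x)
    return out

vmap(f)(torch.randn(3, 4))  # raises a vmap incompatibility error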

Can we solve these problems by just registering an override for the vmap key for those operations?

  • That would solve the vmap(blah) case, but I'm not sure it helps with vmap(grad(blah)): the grad transform runs first and will always decompose blah before vmap ever sees it.
ezyang (Contributor) commented Jun 16, 2021

> The problem is that some CompositeImplicitAutograd ops decompose to in-place operations that are not compatible with vmap (note here).

Yes, this is trouble. I have two parallel thoughts here:

  • We should ensure implicit autograd composites don't ever use mutation (but at the cost of efficiency?)
  • Maybe we can provide both the non-mutating and mutating versions? Perhaps using @ailzhang's functionalization pass?

> Can we solve these problems by just registering an override for the vmap key for those operations?

@zou3519 Well, the VMap key has higher precedence than CompositeImplicitAutograd, so yes, that will just work.

zou3519 (Contributor, Author) commented Jun 16, 2021

If functionalization could take care of this, that would be great. @ailzhang does functionalization handle something like the following?

x = torch.empty_like(y)  # allocate new storage
x.copy_(y)               # then mutate it in place
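(For what it's worth: a pass like this would rewrite the mutation into an out-of-place equivalent. A minimal sketch using torch.func.functionalize and make_fx, APIs that postdate this thread:)

import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(y):
    x = torch.empty_like(y)
    x.copy_(y)
    return x

# The functionalized trace contains only out-of-place ops; the copy_
# is replaced by a functional equivalent.
print(make_fx(functionalize(f))(torch.randn(3)).code)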

@ezyang one alternative along the lines of "providing both the non-mutating and mutating versions" would be the ability to define our own set of primitives with respect to autograd.
For example, .contiguous() eventually calls .copy_() -- .copy_ is the primitive with respect to autograd.

Registering an override for the vmap key for contiguous doesn't actually work: when someone does vmap(grad(blah)), the grad transform's dispatch breaks .contiguous() up into its constituents, so vmap sees the .copy_ and is sad (that's what is going on in #55).

I'm not sure it's possible to "define a new primitive with respect to autograd" out of tree, though: autograd functions exist, but I'm not sure they're sufficient.

zou3519 (Contributor, Author) commented Jun 16, 2021

After some experimenting... if I want to make a new primitive called functorch::to, setting up an autograd::Function for it and registering overrides for the Autograd, CPU, and CUDA keys seems to work:

// Assumes functorch::to was first declared with a schema, e.g. via
// TORCH_LIBRARY(functorch, m) { m.def("to(...) -> Tensor"); }
TORCH_LIBRARY_IMPL(functorch, Autograd, m) {
  // to_autograd invokes an autograd::Function
  m.impl("to", to_autograd);
}
TORCH_LIBRARY_IMPL(functorch, CPU, m) {
  // to_kernel just calls at::to
  m.impl("to", to_kernel);
}
TORCH_LIBRARY_IMPL(functorch, CUDA, m) {
  m.impl("to", to_kernel);
}

Unfortunately, there's a lot of boilerplate here (e.g., setting up the autograd::Function and registering all of those overrides).

ezyang (Contributor) commented Jun 17, 2021

My conception of functionalization is that it is a functional transformation, much like grad/vmap are: it takes traces that have mutations and transforms them into traces without mutation. So in the vmap(grad(...)) case, what you would actually do is vmap(functionalize(grad(...))). (Don't ask me about UX; I don't think you want users to have to insert the functionalize pass explicitly, so we'd have to figure out something about automatically inserting this pass when necessary.)
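A sketch of that composition using today's torch.func (these APIs postdate this comment; treat it as illustrative, not a recommendation):

import torch
from torch.func import functionalize, grad, vmap

def f(x):
    # .contiguous() bottoms out in a copy_ under the grad transform,
    # per the discussion above.
    return x.transpose(0, 1).contiguous().sum()

# functionalize sits between vmap and grad, removing the mutation
# before the batching rules ever see it.
result = vmap(functionalize(grad(f)))(torch.randn(4, 2, 3))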

> one alternative along the lines of "providing both the non-mutating and mutating versions" would be the ability to define our own set of primitives with respect to autograd.

Yes, this is possible. Today we have CPU and we have AutogradCPU; it is possible that given Batched, we should have AutogradBatched (this is a little weird, because Batched isn't a backend, but I'm guessing we probably could make it work). Then you would override the definition of contiguous directly in AutogradBatched to get the better behavior. I'm not sure why you'd want to implement a functorch::to though...

ezyang (Contributor) commented Jun 21, 2021

We had a meeting on Thursday to discuss this. The main points:

  1. I don't want to remove inplace ops from autograd formulas (deoptimizing them)
  2. I am OK with bunching autograd operators into bigger units (like contiguous as itself, not a call to copy_)
  3. Functionalization should be a tool in the toolkit, in general
  4. It should be possible to have multiple composites and pick the one that contextually makes sense (see the sketch below)
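For point 4, a hypothetical sketch of what a context-specific composite could look like with the torch.library Python API (which postdates this thread; the kernel and key choice are illustrative only, not the decided design):

import torch

# Register a mutation-free decomposition of contiguous at the
# FuncTorchBatched key, leaving the default CompositeImplicitAutograd
# decomposition untouched for everyone else.
lib = torch.library.Library("aten", "IMPL")

def contiguous_batched(self, memory_format=torch.contiguous_format):
    # Out-of-place formulation: clone instead of empty_like + copy_.
    return self.clone(memory_format=memory_format)

lib.impl("contiguous", contiguous_batched, "FuncTorchBatched")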

facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Aug 22, 2021
Summary:
See pytorch/functorch#56

Pull Request resolved: #63616

Reviewed By: zou3519

Differential Revision: D30438316

Pulled By: Chillee

fbshipit-source-id: e84446d9f68b87daa0cfff75b3b8a972f36ec85a
zou3519 (Contributor, Author) commented Oct 13, 2021

Done

zou3519 closed this as completed Oct 13, 2021
Chillee (Contributor) commented Oct 22, 2021

We decided to revert this for now.

Essentially, the failure mode for decomposing an op is much worse than the failure mode for not decomposing it. If we decompose an op and the decomposition is then slow or throws an error, the user sees a warning/error like "aten::op_user_doesnt_use can't be vmapped" and has no idea where it came from.

So, for now, we think it's better to err on the side of not decomposing an op unless we explicitly do so.
