add a new alias key for functional to view op decompositions
Pull Request resolved: pytorch#79615

Approved by: https://github.com/zou3519
bdhirsh authored and pytorchmergebot committed Jun 15, 2022
1 parent 6114b0f commit adf8060
Showing 16 changed files with 188 additions and 125 deletions.
2 changes: 2 additions & 0 deletions BUCK.oss
@@ -149,6 +149,8 @@ ATEN_EXPORTED_HEADERS = {
"CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]",
"CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]",
"CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]",
"CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]",
"CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
"CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
"CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
"FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
3 changes: 3 additions & 0 deletions BUILD.bazel
@@ -47,12 +47,15 @@ generated_cpu_cpp = [
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CPUFunctions_inl.h",
"aten/src/ATen/CompositeExplicitAutogradFunctions.h",
"aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
"aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeViewCopyKernels.cpp",
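The BUCK/Bazel entries above pull in the headers that torchgen produces for the new dispatch key. Below is a minimal, hedged sketch of how such a generated header is typically consumed, assuming the codegen mirrors the existing CompositeExplicitAutogradFunctions.h pattern (a per-dispatch-key namespace of direct entry points); the namespace and the specific call are assumptions, not taken from this diff.

```cpp
// Hedged sketch, not from this commit: assumes the generated header exposes a
// per-dispatch-key namespace analogous to at::compositeexplicitautograd::.
// Whether a particular op (select_backward is the example used in the README
// change below) is declared there depends on its entry in native_functions.yaml.
#include <ATen/ATen.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>

at::Tensor grad_of_select(const at::Tensor& grad_output,
                          at::IntArrayRef input_sizes,
                          int64_t dim, int64_t index) {
  // Calls the default kernel registered under the new alias key directly,
  // bypassing dispatch (same idea as the at::cpu:: / at::meta:: headers).
  return at::compositeexplicitautogradnonfunctional::select_backward(
      grad_output, input_sizes, dim, index);
}
```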
33 changes: 23 additions & 10 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -228,10 +228,13 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// For any dispatch key, it'll pick a kernel using the following order:
// (1) Use kernel if it's directly registered to this key
// (2) Handle runtime keys that have kernels available from alias keys
// (2.1) Use kernel from DispatchKey::CompositeExplicitAutograd if available.
// (2.1) Use kernel from DispatchKey::CompositeExplicitAutogradNonFunctional if available.
// This is used to register a kernel that works for all backends in inference, except "functional" backends
// like LazyTensor/XLA. But it requires separate registration for Autograd keys to support training.
// (2.2) Use kernel from DispatchKey::CompositeExplicitAutograd if available.
// This is used to register a kernel that works for all backends in inference. But it requires
// separate registration for Autograd keys to support training.
// (2.2) Use kernel from DispatchKey::CompositeImplicitAutograd if available.
// (2.3) Use kernel from DispatchKey::CompositeImplicitAutograd if available.
// For autograd keys, we only use kernel from CompositeImplicitAutograd when there's no direct registration
// to its corresponding backend key or CompositeExplicitAutograd. See Note [CompositeExplicitAutograd and CompositeImplicitAutograd].
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel() if there's registration to any of
@@ -240,13 +243,13 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// A CompositeExplicitAutograd kernel prevents CompositeImplicitAutograd kernel being used for Autograd keys, but it doesn't
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
// in this case.
// (2.3) Use kernel from DispatchKey::Autograd if available
// (2.4) Use kernel from DispatchKey::Autograd if available
// The implementation of (2.2) relies on the invariant that for a given backend,
// `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the
// backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_]
// (3) Use fallthrough kernel that are registered as fallback.
// Alias Key Precedence:
// CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd
// CompositeExplicitAutogradNonFunctional > CompositeExplicitAutograd > CompositeImplicitAutograd > Autograd
// Note [CompositeExplicitAutograd and CompositeImplicitAutograd]
// When there are registrations to CompositeExplicitAutograd, CompositeImplicitAutograd, and Autograd, from (2.2) we know CompositeExplicitAutograd
// and Autograd kernels will be picked up and CompositeImplicitAutograd is overridden.
@@ -258,7 +261,15 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
return {*direct_registration, "kernel"};
}

// 2.1 Use CompositeExplicitAutograd kernel if available.
// 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available.
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.
if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutogradNonFunctional)) {
if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutogradNonFunctional)) {
return {*default_backend_registration, "default backend kernel"};
}
}

// 2.2 Use CompositeExplicitAutograd kernel if available.
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.
if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeExplicitAutograd)) {
if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd)) {
@@ -273,7 +284,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// See Note [No Alias Keys in DispatchKeySet]
hasKernelForDispatchKey(DispatchKey::CompositeExplicitAutograd);

// 2.2. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd
// 2.3. Use CompositeImplicitAutograd kernel if available. For autograd keys, we only use kernel from CompositeImplicitAutograd
// when there's no direct registration to its corresponding backend key or CompositeExplicitAutograd.
// For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
// to any of its backends.
@@ -289,7 +300,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
}
}

// 2.3. For autograd backend keys, use kernel from DispatchKey::Autograd if available
// 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) {
if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
return {*autograd_registration, "autograd kernel"};
@@ -339,9 +350,11 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
for (auto k : c10::getRuntimeDispatchKeySet(dispatch_key)) {
updateDispatchTableEntry_(dispatcher, k);
}
// Registration to CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
// Registration to CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd and CompositeImplicitAutograd should be populated to Undefined.
// We cannot do this above since Undefined cannot be represented in DispatchKeySet.
if (dispatch_key == DispatchKey::CompositeImplicitAutograd || dispatch_key == DispatchKey::CompositeExplicitAutograd) {
if (dispatch_key == DispatchKey::CompositeImplicitAutograd
|| dispatch_key == DispatchKey::CompositeExplicitAutograd
|| dispatch_key == DispatchKey::CompositeExplicitAutogradNonFunctional) {
updateDispatchTableEntry_(dispatcher, DispatchKey::Undefined);
}
// Note [Refresh Runtime Autograd entries in dispatchTable_]
@@ -375,7 +388,7 @@ void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher)
// no dispatch keys are available we just slide into the undefined handler which would then raise
// the error message.
// In the old world of catchAll, the only way to "register" a kernel to Undefined is by registering it to
// catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd
// catchAll. After catchAllKernel_ is removed, Undefined now can get a kernel from either CompositeExplicitAutograd,
// or CompositeImplicitAutograd alias key so that we don't break the support. Ideally isIncludedInAlias(Undefined, CompositeImplicitAutograd)
// should return true, but it returns false because Undefined cannot be represented in a DispatchKeySet.
updateDispatchTable_(dispatcher, DispatchKey::Undefined);
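To make the precedence order in the comments above concrete, here is a minimal, hedged sketch (not part of this commit) that registers a custom operator to both the new alias key and a specific backend key. Per step (1), the direct CPU registration wins on CPU; per step (2.1), other non-functional backends without a direct kernel fall back to the CompositeExplicitAutogradNonFunctional kernel. The operator name and both kernels are hypothetical.

```cpp
// Hedged sketch of the precedence rules above; the "myops::twice" operator and
// both kernels are hypothetical and not part of this commit.
#include <torch/library.h>
#include <ATen/ATen.h>

// Default inference kernel: picked up via step (2.1) by any backend that has
// no direct registration and is not a "functional" backend (LazyTensor/XLA).
at::Tensor twice_generic(const at::Tensor& self) {
  return self + self;
}

// Direct CPU registration: wins over every alias key via step (1).
at::Tensor twice_cpu(const at::Tensor& self) {
  return self.mul(2);
}

TORCH_LIBRARY(myops, m) {
  m.def("twice(Tensor self) -> Tensor");
}

TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutogradNonFunctional, m) {
  m.impl("twice", twice_generic);
}

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("twice", twice_cpu);
}
```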
14 changes: 13 additions & 1 deletion aten/src/ATen/native/README.md
@@ -292,7 +292,7 @@ to reuse the same function name in both cases.

Available backend options can be found by searching `dispatch_keys` in
[codegen](https://github.com/pytorch/pytorch/blob/master/torchgen/gen.py).
There are also two special "generic" backends:
There are also three special "generic" backends:

- `CompositeExplicitAutograd` (previously known as `DefaultBackend`):
implementations of kernels that work for all backends, but require an
Expand All @@ -305,6 +305,18 @@ There are also two special "generic" backends:
DispatchStub should NOT be registered as CompositeExplicitAutograd, as
DispatchStub only works for `CPU, CUDA`)

- `CompositeExplicitAutogradNonFunctional`:
Similar to CompositeExplicitAutograd, but this key should be used if:
(1) Your kernel is written for a non-aliasing operator.
(2) *and* it calls internally into an aliasing operator.
An example of this is select_backward, which is non-aliasing, but decomposes into select.
We would like to distinguish between "ordinary" CompositeExplicitAutograd kernels
and these kernels, because some backends would not like
to decompose a non-aliasing op into an aliasing op.
LazyTensor + XLA are the two current examples of this - since they operate on a functional IR,
they would prefer to directly implement a non-aliasing operator with their own kernel,
instead of using a decomposition that results in more aliasing operators.

- `CompositeImplicitAutograd` (previously known as `Math`): implementations of
kernels that work for all backends, and also can implicitly support autograd,
because all of the operations it calls support autograd. Direct use of
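As a rough illustration of the README guidance above, the sketch below shows the kind of kernel the new key is meant for: a non-aliasing op whose reference implementation decomposes into the aliasing view op select plus an in-place copy. The operator name, schema, and body are hypothetical, not taken from this commit.

```cpp
// Hedged sketch of a select_backward-style kernel: non-aliasing at the op
// boundary, but internally decomposed into the aliasing view op select() plus
// an in-place copy. Registering it under CompositeExplicitAutogradNonFunctional
// keeps the decomposition away from LazyTensor/XLA, which prefer their own
// non-decomposing kernels.
#include <torch/library.h>
#include <ATen/ATen.h>

at::Tensor select_backward_sketch(const at::Tensor& grad_output,
                                  at::IntArrayRef input_sizes,
                                  int64_t dim, int64_t index) {
  auto grad_input = at::zeros(input_sizes, grad_output.options());
  grad_input.select(dim, index).copy_(grad_output);  // view + in-place write
  return grad_input;
}

TORCH_LIBRARY(mylib, m) {
  m.def("select_backward_sketch(Tensor grad_output, int[] input_sizes, int dim, int index) -> Tensor");
}

TORCH_LIBRARY_IMPL(mylib, CompositeExplicitAutogradNonFunctional, m) {
  m.impl("select_backward_sketch", select_backward_sketch);
}
```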
