Better reshape with autograd support (pytorch#82754) (pytorch#84154)
The original author is @YifanShenSZ and the original PR is pytorch#82754.
# Summary:
The previous reshape PR ([pytorch#80981](https://github.com/pytorch/pytorch/pull/80981)) works for the forward pass, but the backward needs improvement: it has to handle reshape's "sometimes a view, sometimes a copy" behavior.

This pull request fixes it by:
1. adding a new alias dispatch key, `CompositeImplicitAutogradNestedTensor`, which works as the nested-tensor counterpart of `CompositeImplicitAutograd`
2. registering `reshape_nested` to `reshape` under `CompositeImplicitAutogradNestedTensor`
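The "sometimes view, sometimes copy" behavior can be illustrated with a toy model (plain Python with hypothetical names, not PyTorch internals): a contiguous tensor can be reshaped by aliasing its buffer, while a non-contiguous one (e.g. a transpose) must be copied — and autograd has to record the two cases differently.

```python
# Toy model of reshape's "sometimes view, sometimes copy" behavior.
# Hypothetical names and simplified semantics; not PyTorch's implementation.
import itertools

def default_strides(shape):
    """Row-major (contiguous) strides for a shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)
        acc *= dim
    return strides

class Tensor:
    def __init__(self, buffer, shape, strides=None):
        self.buffer = buffer  # flat storage, possibly shared between tensors
        self.shape = shape
        self.strides = strides if strides is not None else default_strides(shape)

    def is_contiguous(self):
        return self.strides == default_strides(self.shape)

    def reshape(self, new_shape):
        if self.is_contiguous():
            # View: alias the same buffer; writes are visible through both tensors.
            return Tensor(self.buffer, list(new_shape))
        # Copy: read elements in logical order into a fresh buffer.
        flat = [self.buffer[sum(i * s for i, s in zip(idx, self.strides))]
                for idx in itertools.product(*map(range, self.shape))]
        return Tensor(flat, list(new_shape))
```

Reshaping a contiguous 2x3 tensor aliases storage, while reshaping its transpose (strides `[1, 3]`) materializes a copy — which is exactly why the backward cannot assume one behavior or the other.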

Side changes:
* add contiguous memory format support to `clone_nested`
* add `view_nested`
* add `reshape_as_nested`

Fixes issue [pytorch#83041](https://github.com/pytorch/pytorch/issues/83041).

Pull Request resolved: pytorch#82754

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

**Static Docs Preview: executorch**

[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D39023822/V13/executorch/)

Reviewed By: albanD

Differential Revision: D39023822

Pulled By: drisspg

Pull Request resolved: pytorch#84154
Approved by: https://github.com/bdhirsh, https://github.com/albanD
YifanShenSZ authored and pytorchmergebot committed Sep 1, 2022
1 parent 9bcad06 commit 673b35c
Showing 17 changed files with 326 additions and 122 deletions.
3 changes: 3 additions & 0 deletions BUILD.bazel
@@ -51,6 +51,7 @@ generated_cpu_cpp = [
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
@@ -66,6 +67,8 @@ generated_cpu_cpp = [
"aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
"aten/src/ATen/CompositeViewCopyKernels.cpp",
"aten/src/ATen/FunctionalInverses.h",
"aten/src/ATen/Functions.h",
15 changes: 15 additions & 0 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -307,6 +307,21 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
// to any of its backends.
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.

// If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
// then we register it to nested-tensor kernel rather than
// regular-tensor CompositeImplicitAutograd kernel.
// We have no intention to change the behavior of Undefined,
// so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
// to let the original CompositeImplicitAutograd handle Undefined
if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (!has_backend_kernel) {
return {*nested_registration, "nested kernel"};
}
}
}

if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
if (dispatch_key == DispatchKey::AutogradOther
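The precedence this new branch establishes — a kernel registered directly for the backend wins, then the nested-tensor composite kernel, then the plain `CompositeImplicitAutograd` fallback, with `Undefined` always routed to the latter — can be sketched as a toy resolver (plain Python; the alias-key membership sets below are illustrative, not PyTorch's real tables):

```python
# Toy sketch of the kernel-precedence logic in computeDispatchTableEntry.
# Membership sets and names are illustrative assumptions, not the real dispatcher.

NESTED_ALIAS = {"NestedTensorCPU", "NestedTensorCUDA"}
COMPOSITE_ALIAS = {"CPU", "CUDA", "AutogradOther",
                   "NestedTensorCPU", "NestedTensorCUDA"}

def resolve_kernel(dispatch_key, kernels):
    """kernels maps a dispatch key or alias key -> kernel name."""
    # 1. A kernel registered directly for the backend always wins.
    if dispatch_key in kernels:
        return kernels[dispatch_key]
    # 2. Nested-tensor composite kernel, but never for Undefined.
    if dispatch_key != "Undefined" and dispatch_key in NESTED_ALIAS:
        if "CompositeImplicitAutogradNestedTensor" in kernels:
            return kernels["CompositeImplicitAutogradNestedTensor"]
    # 3. Plain composite fallback; also handles Undefined.
    if dispatch_key == "Undefined" or dispatch_key in COMPOSITE_ALIAS:
        if "CompositeImplicitAutograd" in kernels:
            return kernels["CompositeImplicitAutograd"]
    return None
```

With both composite kernels registered for `reshape`, a NestedTensorCPU call resolves to the nested kernel while a plain CPU call falls through to the regular composite one.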
11 changes: 0 additions & 11 deletions aten/src/ATen/native/TensorShape.cpp
Expand Up @@ -1256,17 +1256,6 @@ Tensor alias_with_sizes_and_strides(
}

Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
// reshape has special autograd logic since it sometimes returns a view but sometimes does not
// we have to intercept here instead of using dispatcher
// otherwise we will see "autograd still running" kind of error in inference mode:
// * if we create a tensor in inference mode scope,
// then pass it to a inference mode decorated function,
// everything is fine
// * but if we create the input tensor not with inference mode,
// then errors like "Cannot set version_counter for inference tensor" arise
if (self.is_nested()) {
return at::_reshape_nested(self, proposed_shape);
}
if (self.is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
15 changes: 6 additions & 9 deletions aten/src/ATen/native/native_functions.yaml
@@ -4200,16 +4200,9 @@
variants: function, method
device_check: NoCheck
device_guard: False

- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested
autogen: _reshape_nested.out

- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward
autogen: _reshape_nested_backward.out
CompositeImplicitAutograd: reshape
CompositeImplicitAutogradNestedTensor: reshape_nested

# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
# They are not user-facing, hence the leading underscore. Please don't use it
@@ -4233,6 +4226,9 @@
variants: method
device_check: NoCheck
device_guard: False
dispatch:
CompositeImplicitAutograd: reshape_as
CompositeImplicitAutogradNestedTensor: reshape_as_nested

- func: round(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@@ -6889,6 +6885,7 @@
Meta: view_meta
ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
MkldnnCPU: mkldnn_view
NestedTensorCPU, NestedTensorCUDA: view_nested

# Warning: If you want to change the name or overload name of this
# operator, you might also want to change the `isBlockListedSchema`
17 changes: 0 additions & 17 deletions aten/src/ATen/native/nested/NestedTensorBackward.cpp
@@ -66,23 +66,6 @@ std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
auto self_ptr = get_nested_tensor_impl(self);
// TODO: this is to reproduce self_ptr->opt_sizes_
// if an accessor is provided in the future, can replace this
std::vector<int64_t> sizes;
for (int64_t i = 0; i < self_ptr->dim(); i++) {
c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
if (opt_size.has_value()) {
sizes.push_back(*opt_size);
}
else {
sizes.push_back(-1);
}
}
return grad.reshape(sizes);
}
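The removed `_reshape_nested_backward` above rebuilds a reshape target from the input's per-dimension sizes, substituting -1 wherever `opt_size` has no value (a ragged nested dimension) so that `reshape` infers that dimension. A minimal Python rendering of that loop (assumed semantics, hypothetical name):

```python
def backward_reshape_sizes(opt_sizes):
    """Rebuild reshape target sizes from optional per-dimension sizes.

    A None entry marks a ragged (irregular) nested dimension; it is
    replaced by -1 so reshape can infer it, mirroring the loop in the
    removed _reshape_nested_backward.
    """
    return [s if s is not None else -1 for s in opt_sizes]
```

For example, a nested tensor whose middle dimension is ragged yields sizes like `[2, None, 3]`, and the gradient is reshaped to `[2, -1, 3]`.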

Tensor nested_softmax_backward(
const Tensor& grad,
const Tensor& output,