code-generate non-aliasing {view}_copy kernels (pytorch#73442)
Summary: Pull Request resolved: pytorch#73442

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D35016025

Pulled By: bdhirsh

fbshipit-source-id: 2a7f303ec76f5913b744c7822a531d55a57589c9
(cherry picked from commit 3abe13c)
bdhirsh authored and pytorchmergebot committed Apr 11, 2022
1 parent dfcb703 commit 23b8414
Showing 13 changed files with 909 additions and 85 deletions.
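
As a quick illustration of what "non-aliasing" means here (a minimal sketch; it uses torch.view_copy, which the new test added in this commit also exercises, and a data_ptr() comparison as one way to observe aliasing):

import torch

t = torch.arange(4.)

v = t.view(2, 2)                  # regular view op: shares storage with t
vc = torch.view_copy(t, (2, 2))   # generated *_copy variant: independent tensor

assert v.data_ptr() == t.data_ptr()    # the view aliases t
assert vc.data_ptr() != t.data_ptr()   # the copy variant does not
assert torch.equal(v, vc)              # but both hold the same values
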
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -50,6 +50,7 @@ generated_cpu_cpp = [
"aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeViewCopyKernels.cpp",
"aten/src/ATen/FunctionalInverses.h",
"aten/src/ATen/Functions.h",
"aten/src/ATen/Functions.cpp",
200 changes: 200 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -145,6 +145,7 @@

- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
variants: method
tags: inplace_view

- func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
variants: method
@@ -3262,6 +3263,7 @@
CPU: narrow_copy_dense_cpu
SparseCPU, SparseCUDA: narrow_copy_sparse
CompositeExplicitAutograd: narrow_copy_dense
tags: view_copy

- func: narrow_copy.SymInt(Tensor self, int dim, int start, SymInt length) -> Tensor
variants: function, method
@@ -11355,3 +11357,201 @@

- func: nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: function

- func: _fw_primal_copy(Tensor self, int level) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _fw_primal_copy
tags: view_copy

- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _make_dual_copy
tags: view_copy

- func: view_as_real_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_as_real_copy
tags: view_copy

- func: view_as_complex_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_as_complex_copy
tags: view_copy

- func: _conj_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _conj_copy
tags: view_copy

- func: _neg_view_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _neg_view_copy
tags: view_copy

- func: as_strided_copy(Tensor self, int[] size, int[] stride, int? storage_offset=None) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: as_strided_copy
tags: view_copy

- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _sparse_broadcast_to_copy
tags: view_copy

- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: diagonal_copy
tags: view_copy

- func: expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: expand_copy
tags: view_copy

- func: permute_copy(Tensor self, int[] dims) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: permute_copy
tags: view_copy

- func: _reshape_alias_copy(Tensor self, int[] size, int[] stride) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _reshape_alias_copy
tags: view_copy

- func: select_copy.int(Tensor self, int dim, int index) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: select_copy_int
tags: view_copy

- func: detach_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: detach_copy
tags: view_copy

- func: slice_copy.Tensor(Tensor self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: slice_copy_Tensor
tags: view_copy

- func: split_copy.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[]
variants: function
dispatch:
CompositeExplicitAutograd: split_copy_Tensor
tags: view_copy

- func: split_with_sizes_copy(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
variants: function
dispatch:
CompositeExplicitAutograd: split_with_sizes_copy
tags: view_copy

- func: squeeze_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: squeeze_copy
tags: view_copy

- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: squeeze_copy_dim
tags: view_copy

- func: t_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: t_copy
tags: view_copy

- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: transpose_copy_int
tags: view_copy

- func: unsqueeze_copy(Tensor self, int dim) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: unsqueeze_copy
tags: view_copy

- func: _indices_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _indices_copy
tags: view_copy

- func: _values_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: _values_copy
tags: view_copy

- func: indices_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: indices_copy
tags: view_copy

- func: values_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: values_copy
tags: view_copy

- func: crow_indices_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: crow_indices_copy
tags: view_copy

- func: col_indices_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: col_indices_copy
tags: view_copy

- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[]
variants: function
dispatch:
CompositeExplicitAutograd: unbind_copy_int
tags: view_copy

- func: view_copy(Tensor self, int[] size) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_copy
tags: view_copy

- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_copy_dtype
tags: view_copy

- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: unfold_copy
tags: view_copy

- func: alias_copy(Tensor self) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: alias_copy
tags: view_copy
10 changes: 10 additions & 0 deletions aten/src/ATen/native/tags.yaml
@@ -0,0 +1,10 @@
# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`

- tag: inplace_view
desc: |
This tag indicates if an operator *only* modifies the tensor metadata
- tag: view_copy
desc: |
This tag indicates operators that are *_copy* variants
of view/aliasing operators. If an operator has a view_copy tag,
then it should have the name {op}_copy, where {op} is a view operator.
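
As a usage sketch of that naming convention (hypothetical in the sense that it assumes the function variants of the new entries above are exposed in the torch namespace), the view op transpose pairs with a generated transpose_copy that returns the same values without aliasing the input:

import torch

t = torch.arange(6.).reshape(2, 3)
tc = torch.transpose_copy(t, 0, 1)   # non-aliasing counterpart of t.transpose(0, 1)

assert torch.equal(t.transpose(0, 1), tc)
assert tc.data_ptr() != t.data_ptr()
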
20 changes: 20 additions & 0 deletions aten/src/ATen/templates/CompositeViewCopyKernels.cpp
@@ -0,0 +1,20 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}

#include <ATen/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#else
#include <ATen/ops/clone.h>
$ops_headers
#endif

namespace at {
namespace native {


${CompositeViewCopyKernel_Definitions}

} // namespace native
} // namespace at
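
The ${CompositeViewCopyKernel_Definitions} placeholder above is filled in by the code generator. As a rough, hypothetical sketch of the idea (the render_view_copy_kernel helper and its parameters are made up for illustration; the real logic lives under tools/codegen and handles full signatures, TensorList returns, per-operator headers, and so on), each composite {view}_copy definition can be thought of as calling the underlying view op and cloning the result, which is why the template pulls in ATen/ops/clone.h:

# Hypothetical sketch: render a C++ composite {view}_copy definition from the
# name and signature of a view op by composing the view call with clone().
def render_view_copy_kernel(view_name: str, params: str, call_args: str) -> str:
    return (
        f"at::Tensor {view_name}_copy({params}) {{\n"
        f"  auto output = at::{view_name}({call_args});\n"
        f"  return output.clone();\n"
        f"}}\n"
    )

print(render_view_copy_kernel(
    "view", "const at::Tensor & self, at::IntArrayRef size", "self, size"))
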
18 changes: 18 additions & 0 deletions test/test_view_ops.py
@@ -908,6 +908,24 @@ def run_test(device, op):
op = partial(fn, source=0, destination=1)
run_test(device, op)

# Testing that the generated view_copy kernel and its derivative are implemented correctly
def test_view_copy(self, device):
a = torch.randn(4, device=device, requires_grad=True)
a_ref = a.clone().detach().requires_grad_()
a_view = a_ref.view(2, 2)
a_view_copy = torch.view_copy(a, (2, 2))

# view_copy ops don't preserve view relationship
self.assertTrue(self.is_view_of(a_ref, a_view))
self.assertFalse(self.is_view_of(a, a_view_copy))

a_view_copy.sum().backward()
a_view.sum().backward()

# forward and backward give the same shape + result
self.assertEqual(a_view_copy, a_view)
self.assertEqual(a.grad, a_ref.grad)

class TestOldViewOps(TestCase):
def test_ravel(self, device):

45 changes: 40 additions & 5 deletions tools/autograd/load_derivatives.py
@@ -14,13 +14,38 @@
tensorGeometryT, scalarTypeT, SpecialArgName,
OptionalCType, stringT)
from tools.codegen.api import cpp
from tools.codegen.gen import parse_native_yaml
from tools.codegen.gen import parse_native_yaml, get_grouped_by_view_native_functions
from tools.codegen.context import with_native_function
from tools.codegen.model import FunctionSchema, NativeFunction, Variant, Type
from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader
from tools.codegen.model import (
FunctionSchema, NativeFunction, Variant, Type,
NativeFunctionsViewGroup, OperatorName
)
from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader, concatMap

_GLOBAL_LOAD_DERIVATIVE_CACHE = {}

# This function directly adds derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
infos: List[DifferentiabilityInfo],
view_groups: List[NativeFunctionsViewGroup]
) -> List[DifferentiabilityInfo]:
# Get the map from each view op's name to its corresponding view group
view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
g.view.func.name: g for g in view_groups}

view_copy_differentiability_infos = []
for info in infos:
maybe_view_group = view_name_to_group.get(info.func.func.name, None)
if maybe_view_group is not None and maybe_view_group.view_copy is not None:
view_copy_info = info.create_view_copy_from_view_derivative(maybe_view_group)
if view_copy_info is not None:
view_copy_differentiability_infos.append(view_copy_info)

return view_copy_differentiability_infos

def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
# Do some caching as this is a deterministic function
global _GLOBAL_LOAD_DERIVATIVE_CACHE
@@ -30,15 +55,24 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
with open(derivatives_yaml_path, 'r') as f:
definitions = yaml.load(f, Loader=YamlLoader)

functions = parse_native_yaml(native_yaml_path).native_functions
funcs = parse_native_yaml(native_yaml_path).native_functions
# From the parsed native functions, separate out the (generated) view_copy functions,
# so we can generate derivatives for them separately.
native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
native_functions_without_view_copies = concatMap(
# We need to pull out the view_inplace ops too, since they might have their own derivative entries.
lambda g: [g] if isinstance(g, NativeFunction) else list(g.functions(include_copy=False)),
native_functions_with_view_groups
)
view_groups = [g for g in native_functions_with_view_groups if isinstance(g, NativeFunctionsViewGroup)]

# What's the difference between function schema v.s. signature?
# function schema is the complete declaration including mutability annotation / default value and etc.
# signature is the canonical schema for a group of functions (in-place/out/functional variants)
# that are semantically related.
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list)
functions_by_schema: Dict[str, NativeFunction] = dict()
for function in functions:
for function in native_functions_without_view_copies:
functions_by_signature[function.func.signature()].append(function)
assert str(function.func) not in functions_by_schema
functions_by_schema[str(function.func)] = function
@@ -50,6 +84,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
infos = [
create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
for defn in definitions]
infos += add_view_copy_derivatives(infos, view_groups)

_GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
