Update AOTAutograd to use FunctionalTensorMode instead of C++ functionalization (pytorch#106406)

Now that `FunctionalTensor` and `FunctionalTensorMode` are lower down in this stack, the changes in this PR are mostly mechanical: everywhere in AOTAutograd that I used to use the C++ functionalization API, I now use the Python functionalization API.

Note that this doesn't actually cause functionalization to run underneath torch_dispatch. I'm saving that re-ordering for later in the stack.
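
In concrete terms, the swap looks roughly like the sketch below, distilled from the hunks in `torch/_functorch/aot_autograd.py` further down; the `run_functionalized` helper name is illustrative only, not part of the PR. The old path wrapped tensors with `torch._to_functional_tensor` between `torch._enable_functionalization(...)` and `torch._disable_functionalization()`, while the new path wraps them in the Python `FunctionalTensor` subclass and traces under `FunctionalTensorMode`, with the C++ `Functionalize` dispatch key excluded (see Note [Disabling Functionalize TLS Above Python Functionalization]).

```python
import torch
import torch.utils._pytree as pytree
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode

def to_fun(t):
    # Wrap plain tensors in the Python FunctionalTensor wrapper subclass.
    return FunctionalTensor.to_functional(t) if isinstance(t, torch.Tensor) else t

def from_fun(t):
    # Unwrap: sync any pending mutations, then pull the plain tensor out of .elem.
    if not isinstance(t, FunctionalTensor):
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t.elem)

def run_functionalized(fn, *args):
    # Old: torch._enable_functionalization(reapply_views=True) ... torch._disable_functionalization()
    # New: a torch_dispatch mode plus a guard that keeps the C++ Functionalize key out of the way.
    f_args = pytree.tree_map(to_fun, args)
    disable_above = torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
    )
    with disable_above, FunctionalTensorMode():
        f_outs = fn(*f_args)
    return pytree.tree_map(from_fun, f_outs)
```

Mutation info is still read off the functional wrappers afterwards, just through the `.elem` field now (e.g. `torch._functionalize_has_metadata_mutation(t.elem)` in the new `has_metadata_mutation` helper).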

Pull Request resolved: pytorch#106406
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#108654, pytorch#109662, pytorch#109632, pytorch#109023
bdhirsh authored and pytorchmergebot committed Sep 22, 2023
1 parent 63526a6 commit b5d6e83
Showing 18 changed files with 211 additions and 127 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
fc48cfe24eb99b883585ca16590a3ef32652300e
5dd9108adb3f5693c833b9e637f4614dc5770057
19 changes: 1 addition & 18 deletions functorch/experimental/_map.py
@@ -4,7 +4,7 @@
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun

from torch._higher_order_ops.cond import (
_has_potential_branch_input_alias,
@@ -27,23 +27,6 @@
from torch.multiprocessing.reductions import StorageWeakRef


# TODO: pull these helpers from AOTAutograd later
def to_fun(t):
if isinstance(t, torch.Tensor):
return FunctionalTensor.to_functional(t)
return t


def from_fun(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t.elem)


# TODO: We add this to prevent dynamo from tracing into map_wrapper,
# remove the wrapper call when it's ready.
class MapWrapper(HigherOrderOperator):
2 changes: 2 additions & 0 deletions test/dynamo/test_aot_autograd.py
@@ -816,6 +816,7 @@ def _prepare_model_args():
0|aten.add.Tensor|l__self___bn1
1|aten._native_batch_norm_legit_functional.default|l__self___bn1
2|aten.relu.default|l__self___relu1
2|aten.detach.default|l__self___relu1
3|aten.add.Tensor|add
4|aten.view.default|flatten
5|aten.view.default|l__self___fc1
@@ -841,6 +842,7 @@ def _prepare_model_args():
6|aten.t.default|
5|aten.view.default|
4|aten.view.default|
2|aten.detach.default|
2|aten.threshold_backward.default|
1|aten.native_batch_norm_backward.default|
0|aten.convolution_backward.default|
46 changes: 27 additions & 19 deletions test/functorch/test_aotdispatch.py
@@ -928,7 +928,8 @@ def forward(self, primals_1):
view = torch.ops.aten.view.default(mul, [-1])
select = torch.ops.aten.select.int(mul, 0, 0)
detach = torch.ops.aten.detach.default(select); select = None
return [view, mul, detach]""")
detach_1 = torch.ops.aten.detach.default(detach); detach = None
return [view, mul, detach_1]""")

def test_output_aliases_intermediate_inplace_view(self):
def f(a):
@@ -1301,8 +1302,8 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
return [as_strided_scatter, add_1]""") # noqa: B950

def test_input_mutation_aliases_other_input2(self):
@@ -1327,8 +1328,8 @@ def forward(self, primals_1):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_3); as_strided_2 = as_strided_3 = None
return [as_strided_scatter, add_1]""") # noqa: B950

def test_input_mutation_aliases_and_output_alias(self):
@@ -1351,8 +1352,8 @@ def forward(self, primals_1):
as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_6, [4]); as_strided_6 = None
return [as_strided_scatter, view_1]""") # noqa: B950

def test_input_aliased_with_mutation_output_alias(self):
@@ -1381,8 +1382,8 @@ def forward(self, primals_1, primals_2):
mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
as_strided_6 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_6, [-1]); as_strided_6 = None
return [as_strided_scatter, add, view_1]""") # noqa: B950

def test_input_metadata_mutation_aliases(self):
@@ -1465,11 +1466,11 @@ def forward(self, primals_1, primals_2, primals_3):
add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t_1 = torch.ops.aten.t.default(as_strided_3); as_strided_3 = None
add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None
as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None
return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@@ -1533,8 +1534,8 @@ def forward(self, primals_1, primals_2):
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None
as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t = torch.ops.aten.t.default(view); view = None
as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
add = torch.ops.aten.add.Tensor(as_strided_3, as_strided_2); as_strided_3 = as_strided_2 = None
view_1 = torch.ops.aten.view.default(add, [-1])
t_1 = torch.ops.aten.t.default(t)
unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
@@ -2169,12 +2170,18 @@ def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_
getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
detach_2: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_1); detach_1 = None
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach_3: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu); relu = None
detach_4: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_3); detach_3 = None
detach_5: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_4); detach_4 = None
detach_6: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_5); detach_5 = None
ones_like: f32[] = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
expand: f32[1, 3, 3, 3] = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, relu, 0); expand = relu = None
detach_7: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_8: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_7); detach_7 = None
threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, detach_8, 0); expand = detach_8 = None
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
getitem_5: f32[1, 3, 3, 3] = native_batch_norm_backward[0]
getitem_6: f32[3] = native_batch_norm_backward[1]
Expand All @@ -2183,7 +2190,7 @@ def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_
getitem_8 = convolution_backward[0]
getitem_9: f32[3, 1, 1, 1] = convolution_backward[1]
getitem_10: f32[3] = convolution_backward[2]; convolution_backward = None
return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
return (getitem_3, getitem_4, add, sum_1, detach_6, getitem_9, getitem_10, getitem_6, getitem_7)
""") # noqa: B950


@@ -2213,7 +2220,8 @@ def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu); relu = None
return (getitem_3, getitem_4, add, sum_1, detach)
detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach); detach = None
return (getitem_3, getitem_4, add, sum_1, detach_1)
""") # noqa: B950
# Some important characteristics of the exported graph below:
# 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input
71 changes: 42 additions & 29 deletions torch/_functorch/aot_autograd.py
@@ -10,6 +10,7 @@
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, NewType
from unittest.mock import patch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

from functorch import make_fx

@@ -27,6 +28,7 @@
from torch._prims_common import CUDARngStateHelper
from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
@@ -676,17 +678,37 @@ def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires

def to_fun(t):
if isinstance(t, Tensor):
out = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, out)
return out
return FunctionalTensor.to_functional(t)
else:
return t

def from_fun(t):
if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
torch._sync(t)
return torch._from_functional_tensor(t)
return torch._from_functional_tensor(t.elem)

def is_fun(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
t_attrs, _ = t.__tensor_flatten__()
t_inners = [getattr(t, attr) for attr in t_attrs]
any_fun = any(is_fun(x) for x in t_inners)
all_fun = all(is_fun(x) for x in t_inners)
assert any_fun == all_fun
return any_fun

return isinstance(t, FunctionalTensor)

def has_metadata_mutation(t):
assert isinstance(t, FunctionalTensor)
return torch._functionalize_has_metadata_mutation(t.elem)

def _get_hints(exprs):
"""
@@ -725,23 +747,16 @@ def run_functionalized_fw_and_collect_metadata(
) -> ViewAndMutationMeta:
memo = {}

def to_fun(t):
def _to_fun(t):
if isinstance(t, Tensor):
if t in memo:
return memo[t]
r = torch._to_functional_tensor(t)
torch._mirror_autograd_meta_to(t, r)
r = to_fun(t)
memo[t] = r
return r
else:
return t

def from_fun(t):
if not isinstance(t, Tensor) or not torch._is_functional_tensor(t):
return t
torch._sync(t)
return torch._from_functional_tensor(t)

@wraps(f)
def inner(*flat_args):
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
@@ -752,30 +767,28 @@ def inner(*flat_args):
input_requires_grad_info: List[bool] = []
output_requires_grad_info: List[bool] = []

flat_f_args = pytree.tree_map(to_fun, flat_args)
flat_f_args = pytree.tree_map(_to_fun, flat_args)

torch._enable_functionalization(reapply_views=True)
try:
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
with disable_above, FunctionalTensorMode():
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_outs = f(*flat_f_args)
finally:
torch._disable_functionalization()

# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)):
if not isinstance(arg, Tensor):
new_arg = arg
else:
torch._sync(f_arg)
new_arg = torch._from_functional_tensor(f_arg)
new_arg = from_fun(f_arg)
if arg is not new_arg:
if StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage()):
mutates_data = False
mutates_metadata = True
else:
mutates_data = True
mutates_metadata = torch._functionalize_has_metadata_mutation(f_arg)
mutates_metadata = has_metadata_mutation(f_arg)
# Only track requires_grad info on *mutated* inputs,
# because they show up in the autograd.Function.forward as outputs
input_requires_grad_info.append(
@@ -1360,12 +1373,12 @@ def create_functionalized_graph(
def functionalized_f_helper(*args):
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
torch._enable_functionalization(reapply_views=True)
try:

# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
with disable_above, FunctionalTensorMode():
# Run the joint
f_outs = fn(*f_args)
finally:
torch._disable_functionalization()

if aot_config.keep_inference_input_mutations and not trace_joint:
# Note: This is a bit annoying. There's a layering issue here, where:
@@ -1395,8 +1408,8 @@ def functionalized_f_helper(*args):
for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args)):
if not isinstance(inpt_f, torch.Tensor):
continue
torch._sync(inpt_f)
inpt_new = torch._from_functional_tensor(inpt_f)
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
if meta.input_info[i].mutates_data and not meta.input_info[i].mutates_metadata:
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
@@ -2068,7 +2081,7 @@ def create_synthetic_base_metadata(
input_metadata_output_info = [
OutputAliasInfo(
output_type=OutputType.alias_of_input,
raw_type=torch.Tensor,
raw_type=FunctionalTensor,
dynamic_dims={i for i, s in enumerate(outer_args[outer_idx].shape) if not is_concrete_int(s)},
base_idx=synthetic_base_info[outer_idx][0],
) for outer_idx in outer_aliased_arg_idx_with_metadata_mutations]