[aotd] Support HOP effects in backward (pytorch#132638)
Support for effectful operations in backward:

1/ AOTD collects metadata from the forward fn only, so effectful ops can appear in backward that were never used in forward => token discovery is allowed while tracing the joint function.

FunctionalTensorMode holds `_tokens`; in the joint function, after tracing forward, we memoize `_tokens` as `_tokens_forward_output` (a minimal sketch of this bookkeeping follows the numbered list).

2/ Tokens are added as primal inputs (forward) in `EffectTokensWrapper`.
Tokens that will be used in backward end up among the partitioner's saved values; we do not control the positions at which they are saved in the forward outputs.

If new tokens are discovered in backward after tracing the joint fn, they are manually appended at the end of the primals in the resulting graph (see `_aot_autograd/utils.py`).

3/ All effectful ops in backward are marked with the 'must_be_in_backward' partitioner_tag, to prevent the partitioner from placing them in forward.

For that, the functional tensor mode gets a new optional state, `self._effects_partitioner_tag`, for effectful ops, which is set after tracing forward.

There are additional changes in the partitioner to improve the handling of 'must_be_in_backward'.

4/ Token unlifting now runs for both forward and backward.
- Since tokens saved for backward are not placed at static positions, we identify the input and output tokens to erase via the inputs and outputs of `with_effects` operations.
- The forward can have input tokens, discovered in backward, that are not used by `with_effects` ops in forward but are saved for backward. We identify them by their position in the forward inputs.

5/ Add AOT debug logging for the graphs before unlifting tokens and before adding the extra primals for backward tokens.
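
The forward/backward token split in 1/ boils down to a small bookkeeping step in the joint function. Below is a minimal, self-contained sketch of that bookkeeping in plain Python, with made-up names (`FakeFunctionalMode`, `run_forward`, `run_backward`) — it is not the actual AOTAutograd code:

```
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class FakeFunctionalMode:
    # Stand-ins for FunctionalTensorMode's token state described above.
    _tokens: Dict[Any, Any] = field(default_factory=dict)
    _tokens_forward_output: Dict[Any, Any] = field(default_factory=dict)


def trace_joint(mode, run_forward, run_backward, primals, tangents):
    fwd_outs = run_forward(*primals)            # effectful fwd ops populate mode._tokens
    mode._tokens_forward_output = mode._tokens  # memoize the forward token chain
    mode._tokens = {}                           # fresh, independent chain for backward
    grads = run_backward(fwd_outs, tangents)    # backward tokens are discovered here
    fwd_tokens = tuple(mode._tokens_forward_output.values())
    bwd_tokens = tuple(mode._tokens.values())
    # Forward outputs carry the forward tokens up front; backward outputs carry
    # the backward tokens at the end (matching points 1/ and 2/ above).
    return (*fwd_tokens, *fwd_outs), (*grads, *bwd_tokens)


# Toy usage: the forward "emits" one token, the backward another.
mode = FakeFunctionalMode()

def fwd(x):
    mode._tokens["print_fwd"] = "tok_f"   # pretend an effectful op ran in forward
    return (x * 2,)

def bwd(fwd_outs, tangents):
    mode._tokens["print_bwd"] = "tok_b"   # effectful op that only runs in backward
    return (tangents[0] * 2,)

fwd_graph_outs, bwd_graph_outs = trace_joint(mode, fwd, bwd, (3.0,), (1.0,))
```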

Tests:
```
python test/higher_order_ops/test_with_effects.py
```
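
For illustration, here is a hedged sketch of the kind of program this change targets: an effectful op (`torch.ops.aten._print`) that is only reached while tracing the backward half of the joint graph. It is not taken from the test suite; `LogGradient` and `fn` are made-up names, and whether this exact pattern compiles end to end depends on Dynamo's handling of `autograd.Function`.

```
import torch


class LogGradient(torch.autograd.Function):
    # Hypothetical example: the side effect happens only in backward.
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # Effectful op: under AOTAutograd its token can only be discovered
        # while tracing the backward part of the joint graph.
        torch.ops.aten._print("computing gradients")
        return grad_out


def fn(x):
    return LogGradient.apply(x).sin().sum()


x = torch.randn(4, requires_grad=True)
# An AOTAutograd backend exercises the backward-token handling added here.
torch.compile(fn, backend="aot_eager")(x).backward()
```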

Pull Request resolved: pytorch#132638
Approved by: https://github.com/bdhirsh
IvanKobzarev authored and pytorchmergebot committed Aug 23, 2024
1 parent 7fd3b69 commit 8ae4f82
Showing 12 changed files with 566 additions and 102 deletions.
1 change: 1 addition & 0 deletions test/dynamo/test_logging.py
@@ -697,6 +697,7 @@ def fn(a):
"fusion",
"overlap",
"aot_graphs",
"aot_graphs_effects",
"post_grad_graphs",
"compiled_autograd",
"compiled_autograd_verbose",
286 changes: 263 additions & 23 deletions test/higher_order_ops/test_with_effects.py

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
@@ -46,7 +46,10 @@ def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
# FunctionalTensorMode must be enabled here.
# See Note [Accessing .grad_fn on FunctionalTensor]
with enable_python_dispatcher(), FunctionalTensorMode(
pre_dispatch=aot_config.pre_dispatch, export=aot_config.is_export
pre_dispatch=aot_config.pre_dispatch,
export=aot_config.is_export,
# Allow token discovery for joint fn tracing as tokens can be used in backward.
_allow_token_discovery=True,
):
fx_g = make_fx(
f,
@@ -191,7 +194,7 @@ def _map_assigned_buffer_to_proxy(_mod, name, buffer):
# See Note [Side-Effectful Tokens in AOTAutograd]
num_tokens = len(fw_metadata.tokens)
if num_tokens != 0 and config.unlift_effect_tokens:
unlift_tokens(fw_module, fw_metadata)
unlift_tokens(fw_module, fw_metadata, aot_config)
saved_updated_flat_args_subclasses_desugared = (
saved_updated_flat_args_subclasses_desugared[num_tokens:]
)
29 changes: 20 additions & 9 deletions torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -384,10 +384,16 @@ def aot_dispatch_autograd(
)

# See Note [Side-Effectful Tokens in AOTAutograd]
if num_tokens != 0 and config.unlift_effect_tokens:
unlift_tokens(fw_module, fw_metadata)
if config.unlift_effect_tokens and (
num_tokens > 0 or fw_metadata.num_backward_tokens > 0
):
unlift_tokens(fw_module, fw_metadata, aot_config, bw_module)

num_inner_fwd_outputs -= num_tokens
joint_inputs = (joint_inputs[0][num_tokens:], joint_inputs[1])
joint_inputs = (
joint_inputs[0][num_tokens:],
joint_inputs[1],
)

fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
@@ -484,16 +490,21 @@ def aot_dispatch_autograd(
# (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensors,
# so we need to figure out which subclass fw inputs they map to.
if maybe_subclass_meta is None:
num_backward_tokens: int = inner_meta.num_backward_tokens
assert (
len(bw_outs)
== len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
== len(fw_metadata.input_info)
+ inner_meta.num_outputs_rng_offset
+ num_backward_tokens
)
bw_outs_no_rng = bw_outs
if inner_meta.num_outputs_rng_offset > 0:
bw_outs_no_rng = bw_outs[: -inner_meta.num_outputs_rng_offset]
assert len(bw_outs_no_rng) == len(fw_metadata.input_info)
bw_outs_no_rng_no_tokens = bw_outs
if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0:
bw_outs_no_rng_no_tokens = bw_outs[
: -(inner_meta.num_outputs_rng_offset + num_backward_tokens)
]
assert len(bw_outs_no_rng_no_tokens) == len(fw_metadata.input_info)

for i, (bw_out) in enumerate(bw_outs_no_rng):
for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens):
# If our input experiences a metadata mutation inside the graph (e.g. set_()),
# we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation
metadata_mutation_in_graph = (
45 changes: 33 additions & 12 deletions torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -659,7 +659,7 @@ def post_compile(
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
if num_tokens > 0:
# Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
# Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
old_args = args
args = [*([None] * num_tokens), *args]
old_args.clear()
@@ -1730,20 +1730,23 @@ def backward(ctx, *flat_args):
# Add the seed and offset to args
rng_args = CUDARngStateHelper.get_torch_state_as_tuple()

bw_tokens = [None] * CompiledFunction.metadata.num_backward_tokens

# - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first
# in the bw output order.

# Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls.
# There are tests that count these calls, so dereference once and save to a local variable.
ctx_saved_tensors = ctx.saved_tensors
num_ctx_saved_tensors = len(ctx_saved_tensors)
all_args = [
*ctx.symints,
*ctx.saved_tensors,
*ctx_saved_tensors,
*flat_bw_args_with_grads,
*bw_tokens,
*rng_args,
]
del flat_bw_args_with_grads

tangents_start_idx = (
len(all_args) - num_flat_bw_args_with_grads - len(rng_args)
)
tangents_end_idx = len(all_args) - len(rng_args)
del ctx_saved_tensors

# Note: [AOTAutograd Backward Guards]
# During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
@@ -1771,9 +1774,8 @@ def backward(ctx, *flat_args):
len(CompiledFunction.metadata.output_types)
== num_flat_bw_args_with_grads
)
grad_output_types = [
type(x) for x in all_args[-num_flat_bw_args_with_grads:]
]

grad_output_types = [type(x) for x in flat_bw_args_with_grads]
# In general, we can add more asserts/guards here for when we partitioned
# with incorrect assumptions about the grad_outputs.
# Normalize FakeTensor -> torch.Tensor
@@ -1791,6 +1793,17 @@
Expected grad_output types: {str(CompiledFunction.metadata.output_types)}
Got grad_output types: {str(grad_output_types)}"""

del flat_bw_args_with_grads

tangents_start_idx = (
len(all_args)
- num_flat_bw_args_with_grads
- len(rng_args)
- len(bw_tokens)
)
assert tangents_start_idx == len(ctx.symints) + num_ctx_saved_tensors
tangents_end_idx = len(all_args) - len(rng_args) - len(bw_tokens)

# TODO: figure out how to refactor the backward properly
# so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
@@ -1855,7 +1868,9 @@ def get_types_for_tangents(tangents):
all_args = unwrap_tensor_subclasses(
all_args, is_joint_structure=False
)
tangents_start_idx = len(all_args) - len_tangents - len(rng_args)
tangents_start_idx = (
len(all_args) - len_tangents - len(rng_args) - len(bw_tokens)
)
tangents_end_idx = tangents_start_idx + len_tangents

# Make the tangents contiguous. Note that we must do this after subclass desugaring
@@ -1968,6 +1983,12 @@ def call_compiled_backward():
steal_args=True,
disable_amp=disable_amp,
)

# Toss out the backward output tokens
num_bw_tokens = CompiledFunction.metadata.num_backward_tokens
if num_bw_tokens > 0:
out = out[:-num_bw_tokens]

# TODO: replace this with FunctionalizedRngRuntimeWrapper.post_compile
out = FunctionalizedRngRuntimeWrapper()._functionalized_rng_runtime_epilogue(
CompiledFunction.metadata, out, offset_index=len(out) - 1
5 changes: 5 additions & 0 deletions torch/_functorch/_aot_autograd/schemas.py
@@ -351,6 +351,10 @@ class ViewAndMutationMeta:
# and backward output.
bw_donated_idxs: Optional[List[int]] = None

# Number of tokens used in backward, appended at the end of backward outputs.
# Filled after tracing joint function.
num_backward_tokens: int = 0

def __post_init__(self):
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
@@ -566,6 +570,7 @@ def __eq__(self, other):
x.shape == y.shape and x.dtype == y.dtype
for x, y, in zip(self.traced_tangents, other.traced_tangents)
)
and self.num_backward_tokens == other.num_backward_tokens
)


58 changes: 47 additions & 11 deletions torch/_functorch/_aot_autograd/traced_function_transforms.py
@@ -235,7 +235,24 @@ def inner_fn(primals: List[Any], tangents: List[Any]):
backward_out: Tuple[Tensor, ...] = ()
# Call the backwards pass
if grad_primals:
with fx_traceback.preserve_node_meta():
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
if functional_tensor_mode is not None:
# Side-Effect Tokens:
# We want to have independent chains of tokens for forward and backward.
# functional_tensor_mode._tokens is used by both.
# We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
# to return them as joint graph outputs.
# We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
# Joint graph tracing allows token discovery,
# so all tokens used in backward will be created and added as graph inputs during tracing.
functional_tensor_mode._tokens_forward_output = (
functional_tensor_mode._tokens
)
functional_tensor_mode._tokens = {}

with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta():
# for full graph export, we always export a joint graph where we assume no tangents are needed.
if aot_config.no_tangents:
assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
@@ -348,6 +365,14 @@ def set_partitioner_tag(tag: str):
fx_traceback.current_meta[meta_key] = original_val


def set_partitioner_tag_is_backward():
return set_partitioner_tag("is_backward")


def set_partitioner_tag_must_be_in_backward():
return set_partitioner_tag("must_be_in_backward")


# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
@@ -439,9 +464,7 @@ def _functionalized_f_helper(*args):
# Not banning here mutations on inpt_info.requires_grad -
# we'll check at runtime and fail only when backward is under torch.is_grad_enabled (create_graph)
# Add node meta for copy_ for partitioner that this node should be in backward graph.
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag(
"must_be_in_backward"
):
with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
before.copy_(after)
meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
idx
@@ -649,13 +672,13 @@ def inner_fn(*args):
if trace_joint:
assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
tokens = args[0][:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = (args[0][num_tokens:], *args[1:])
else:
tokens = args[:num_tokens]
assert all(token.numel() == 0 for token in tokens)
args = args[num_tokens:]

assert all(token.numel() == 0 for token in tokens)

# Populate the current FunctionalTensorMode with the tokens per
# operator. See Note [FunctionalTensorMode is Stateful]
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
@@ -671,17 +694,30 @@

# Return both the tokens and the outputs
# See Note [Side-Effectful Tokens in AOTAutograd]
f_out_tokens = functional_tensor_mode._tokens.values()
out_tokens = [from_fun(t) for t in f_out_tokens]
if trace_joint:
assert len(outs) == 2
assert len(functional_tensor_mode._tokens_forward_output) == num_tokens
fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values()

bwd_out_tokens = functional_tensor_mode._tokens.values()

f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens]
f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens]

meta.num_backward_tokens = len(bwd_out_tokens)
return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens))

out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
return (*out_tokens, *outs)

# Additionally pass in tokens as inputs
# See Note [Side-Effectful Tokens in AOTAutograd]
additional_token_inputs = [torch.tensor([])] * len(meta.tokens)
additional_fwd_token_inputs = [torch.tensor([])] * num_tokens

if trace_joint:
args = ([*additional_token_inputs, *args[0]], *args[1:])
args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
else:
args = [*additional_token_inputs, *args]
args = [*additional_fwd_token_inputs, *args]
return inner_fn, args

