AOTAutograd: handle set_(), detect metadata mutations that cancel out (pytorch#111554)

This should be enough to get @voznesenskym's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything no-op out as expected. Main changes are:

(1) Graph break on `aten::set_.source_Tensor_storage_offset` (we could support it, but it isn't needed, and graph breaking seems safer).
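
For reference, a rough sketch of how the two overloads look from Python (the shapes are made up for illustration; only the second call hits the overload we graph break on):

```
import torch

x = torch.zeros(4)
y = torch.ones(4)

# aten::set_.source_Tensor: handled by the new functionalization kernel below
x.set_(y)

# aten::set_.source_Tensor_storage_offset: Dynamo graph breaks on this overload
x.set_(y, 0, (2, 2), (2, 1))
```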

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd, and it was wrong (it would just `clone()` and call `set_()`, which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations and metadata mutations due to `set_()` calls. This matters mainly at runtime: the former requires calling `as_strided_()` to fix up metadata, while the latter requires calling `set_()`.
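
Roughly, the runtime epilogue needs to branch on this distinction along these lines (an illustrative sketch only, not the actual wrapper code; `info`, `inpt_old`, and `inpt_new` are made-up names):

```
# Illustrative sketch of the runtime epilogue for one mutated graph input.
# `info` stands in for InputAliasInfo; `inpt_old` is the user's original input,
# `inpt_new` is the corresponding updated tensor returned by the compiled graph.
def apply_input_mutation(info, inpt_old, inpt_new):
    if info.mutates_storage_metadata:
        # set_() swapped out the input's storage entirely: re-attach the input
        # to the new storage (picking up the new sizes/strides/offset too).
        inpt_old.set_(inpt_new)
    elif info.mutates_metadata:
        # A "regular" metadata mutation (e.g. t_()): same storage, new metadata.
        inpt_old.as_strided_(
            inpt_new.size(), inpt_new.stride(), inpt_new.storage_offset()
        )
```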

(4) Made AOTAutograd's detection of metadata mutations / `set_()` mutations smarter, so it detects no-ops (when the storage and metadata all end up the same as they started).
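
As an illustrative sketch of what "no-op" means here (not the exact AOTAutograd check):

```
import torch

# Illustrative only: an input counts as "not really mutated" if, after tracing,
# it still points at the same storage with the same sizes/strides/offset.
def input_mutation_is_noop(inpt_old: torch.Tensor, inpt_new: torch.Tensor) -> bool:
    same_storage = (
        inpt_old.untyped_storage().data_ptr() == inpt_new.untyped_storage().data_ptr()
    )
    same_metadata = (
        inpt_old.size() == inpt_new.size()
        and inpt_old.stride() == inpt_new.stride()
        and inpt_old.storage_offset() == inpt_new.storage_offset()
    )
    return same_storage and same_metadata
```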

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `was_storage_changed()`, which can more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.
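
In pseudocode (the queries are C++ methods on `FunctionalTensorWrapper`; the Python spelling below is just for illustration), the three kinds of mutation can now be told apart roughly like this:

```
# Pseudocode over the FunctionalTensorWrapper queries; a single input can fall
# into more than one bucket (e.g. both a data mutation and a set_() call).
def mutation_kinds(wrapper):
    kinds = set()
    if wrapper.has_data_mutation():      # the storage's generation counter advanced
        kinds.add("data mutation")
    if wrapper.was_storage_changed():    # the storage was swapped out by set_()
        kinds.add("set_()")
    if wrapper.has_metadata_mutation():  # sizes/strides/offset changed in place
        kinds.add("metadata mutation")
    return kinds
```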

**This PR is still silently incorrect in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data mutation **and** an `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if `x_new` is a graph intermediate then we will functionalize the mutation. If `x_new` is a different graph input, then we will perform the usual `copy_()` on that other graph input.

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```
def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x_view.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations; a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: pytorch#111554
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed Nov 13, 2023
1 parent 1d9919c commit 3afb4e5
Showing 8 changed files with 357 additions and 71 deletions.
29 changes: 29 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -232,6 +232,35 @@ void FunctionalTensorWrapper::replace_(const Tensor& other) {
mutation_counter_++;
}

bool FunctionalTensorWrapper::has_data_mutation() {
// Current tensor's data was mutated if its storage saw any mutations.
return functional_storage_impl()->generation() > 0;
}

void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
// self.set_(src) will cause self to have all of the tensor properties of src.
value_ = other->value_;
generation_ = other->generation_;
view_metas_ = other->view_metas_;
// FREEZE the old storage, preventing any further mutations to it.
// Handling mutations to the old storage properly in all cases is a huge pain, so we ban it.
functional_storage_impl()->freeze();
// Unsafely swap out the storage with other's storage,
// disconnecting `self` with its view chain
storage_ = other->storage_;
// Explicitly mark the tensor as having its storage changed by set_().
// Otherwise, we don't actually have a 100% accurate way to check this.
// (We could check whether the updated value has a different storage than the original value,
// but that alone wouldn't let us determine whether the tensor **also**
// experienced a data mutation.)
was_storage_changed_ = true;

auto sizes_ = value_.sym_sizes();
auto strides_ = value_.sym_strides();
auto storage_offset_ = value_.sym_storage_offset();
set_sizes_and_strides(sizes_, strides_, storage_offset_);
}

void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) {
// Note [resize_() in functionalization pass]
// resize_() is a special operator in functionalization because it can reallocate its underlying storage.
14 changes: 14 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.h
@@ -122,6 +122,18 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
// tensor by replaying the views off of the alias.
void mutate_view_meta(at::functionalization::ViewMeta meta);

// Custom implementation of self.set_(src)
void set__impl(const FunctionalTensorWrapper* other);

// Returns whether the current tensor's data was ever mutated
bool has_data_mutation();
// Returns whether the current FunctionalTensorWrapper
// experienced a set_() call.
bool was_storage_changed() {
return was_storage_changed_;
}

// The functionalization pass can be used to remove mutations.
// It does so by replacing any mutation op with its corresponding
// out-of-place op, followed by a call to replace_(). e.g:
@@ -195,6 +207,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
uint64_t mutation_hidden_from_autograd_counter_ = 0;
bool has_metadata_mutation_ = false;
bool is_multi_output_view_ = false;
// Did the tensor experience a set_() call.
bool was_storage_changed_ = false;

size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
25 changes: 25 additions & 0 deletions aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -299,6 +299,28 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
return out;
}

static at::Tensor& set__functionalize(at::Tensor& self, const at::Tensor& src) {
// error case
TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(self) || !at::functionalization::impl::isFunctionalTensor(src),
"set__functionalize: Tried to mutate a non-functional tensor with a functional tensor, which is not allowed");

// nop case
if (!at::functionalization::impl::isFunctionalTensor(self) && !at::functionalization::impl::isFunctionalTensor(src)) {
at::AutoDispatchSkipFunctionalize guard;
return self.set_(src);
}

TORCH_CHECK(at::functionalization::impl::isFunctionalTensor(src),
"set__functionalize: We do not currently support x.set_(y) where y is not a FunctionalTensor. Please file an issue");

TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self));
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(src));
auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self);
auto src_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(src);
self_impl->set__impl(src_impl);
return self;
}

TORCH_LIBRARY_IMPL(_, Functionalize, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}
@@ -310,4 +332,7 @@ TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
// The overloads of set_() that take in a storage should never
// appear with torch.compile, because dynamo graph breaks on them
m.impl("set_.source_Tensor", TORCH_FN(set__functionalize));
}
126 changes: 111 additions & 15 deletions test/functorch/test_aotdispatch.py
@@ -455,6 +455,97 @@ def forward(self, primals_1):
mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
return [mul, mul_1]""")

def test_input_mutation_set__input_mutation(self):
def f(a):
b = torch.arange(9, dtype=a.dtype).reshape(3, 3)
with torch.no_grad():
a.set_(b)
return a * b
inp = [torch.ones(3, 3, requires_grad=True)]
self.verify_aot_autograd(f, inp, test_mutation=True)
inp = [torch.ones(3, 3, requires_grad=False)]
self.verify_aot_autograd(f, inp, test_mutation=True)

def test_set__steals_view_chain(self):
def f(a, b):
a_ = a.mul(2)
b_ = b.mul(2)
b_slice = b_[1].view(3, 3)
# a_ should inherit the view chain from b_slice
a_.set_(b_slice)
# Also mutates b_, since a_ now aliases b_slice's storage
a_.view(-1).mul_(2)
return a_ * b_slice
inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 9, requires_grad=False)]
self.verify_aot_autograd(f, inp)

def test_set__and_data_mutation_good(self):
def f(a, b):
# The data mutation happens *after* the set_(). This is ok (see the graph below)
with torch.no_grad():
a.set_(b)
b.mul_(2)
return a + b
inp = [torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True)]
fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 3, requires_grad=False)]
self.verify_aot_autograd(f, inp, test_mutation=True)
# Important things to note:
# - "return a.set_(b)" desugars into "return b"
# - Both a and b are recorded as experiencing mutations,
# which is why we see "b_updated" (output of the mul) twice in the graph outputs.
# a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage).
# - the runtime epilogue for a is "a.set_(mul)"
# - the runtime epilogue for b is "b.copy_(mul)"
self.assertExpectedInline(fw_graph.code.strip(), """\
def forward(self, primals_1, primals_2):
clone = torch.ops.aten.clone.default(primals_2); primals_2 = None
mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
add = torch.ops.aten.add.Tensor(mul, mul)
return [mul, mul, add]""")

# This is a (hopefully) extremely rare case that is difficult to handle,
# so we ban it.
def test_set__and_data_mutation_bad(self):
def f(a):
a_view = a.view(-1)
tmp = torch.ones(3, 3, requires_grad=True)
# Now, any mutations on either tmp or a
# will be tracked as graph input mutations.
with torch.no_grad():
a.set_(tmp)
# BAD: a_view is now detached from every graph input,
# so we won't recognize that this caused an input mutation!
a_view.mul_(2)
return a + tmp
inp = [torch.ones(3, 3, requires_grad=True)]
with self.assertRaisesRegex(RuntimeError, "cannot mutate tensors with frozen storage"):
self.verify_aot_autograd(f, inp, test_mutation=True)

def test_input_mutation_set__nop(self):
def f(a):
b = torch.arange(9, dtype=a.dtype)
a_old = torch.ops.aten.alias.default(a)
with torch.no_grad():
a.set_(b)
a.set_(a_old)
return a + b.reshape(3, 3)
inp = [torch.ones(3, 3, requires_grad=True)]
fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
inp = [torch.ones(3, 3, requires_grad=False)]
self.verify_aot_autograd(f, inp, test_mutation=True)
# Things to note:
# - There are no set_() calls in the graph (we functionalize a.set_(b) into "b")
# - There is only **1** graph output. We properly realized that the two set_() calls
# undo each other, and so effectively no inputs are mutated.
self.assertExpectedInline(fw_graph.code.strip(), """\
def forward(self, primals_1):
arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
alias = torch.ops.aten.alias.default(primals_1); primals_1 = None
view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
add = torch.ops.aten.add.Tensor(alias, view); alias = view = None
return [add]""")

def test_input_mutation_simple_with_none_and_nontensor(self):
# Tensor, None, int
def f(a, b, c):
@@ -1597,10 +1688,9 @@ def inp_callable(req_grad):
# Expectation: fwd() takes in 2 args, and we don't construct a synthetic base.
self.assertExpectedInline(fw_graph.code.strip(), """\
def forward(self, primals_1, primals_2):
view = torch.ops.aten.view.default(primals_1, [4]); primals_1 = None
t = torch.ops.aten.t.default(view); view = None
add = torch.ops.aten.add.Tensor(t, primals_2); primals_2 = None
return [t, add]""")
t = torch.ops.aten.t.default(primals_1); primals_1 = None
add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None
return [add]""")

def test_input_mutation_aliases_and_none_require_gradients(self):
def f(a, b, c):
@@ -1639,7 +1729,7 @@ def test_input_mutation_aliases_bases_out_of_order(self):
# So we don't need to do the base construction / deconstruction
def f(a, b, c, d):
b.add_(1)
d.t_()
d.unsqueeze_(0)
return a + c + d, b.view(-1)

def inp_callable(req_grad):
@@ -1668,11 +1758,11 @@ def forward(self, primals_1, primals_2, primals_3):
as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None
as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
t_1 = torch.ops.aten.t.default(as_strided_3); as_strided_3 = None
add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_3, 0); as_strided_3 = None
add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
as_strided_11 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
view_1 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None
return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950
view_2 = torch.ops.aten.view.default(as_strided_11, [-1]); as_strided_11 = None
return [as_strided_scatter, add_2, view_2, unsqueeze_1]""") # noqa: B950

@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
def test_synthetic_base_base_attribute_is_none(self):
@@ -1910,7 +2000,7 @@ def f(x, y):
def test_dupe_arg_torture(self):
def f(x, y):
x.t_()
y.t_()
y.unsqueeze_(0)
return x + y

x = torch.randn(3, 3, requires_grad=True).clone()
@@ -1971,8 +2061,8 @@ def test_invalid_dupe_fake(self, counter):
def _test_invalid_dupe(self, counter, fake):
class F(torch.nn.Module):
def forward(self, x, y):
x.t_()
y.t_()
x.unsqueeze_(0)
y.unsqueeze_(0)
return (x + y,)

x = torch.randn(3, 3, requires_grad=True).clone()
@@ -1991,16 +2081,22 @@ def forward(self, x, y):
fxy = aot_module_simplified(F(), (x, y), nop)

fxy(x, y)
x = torch.randn(3, 3, requires_grad=True).clone()
y = torch.randn(3, 3, requires_grad=True).clone()
fxy(x, x) # is ok!

if fake:
fxx = aot_module_simplified(F(), (fake_x, fake_x), nop)
else:
fxx = aot_module_simplified(F(), (x, x), nop)

x = torch.randn(3, 3, requires_grad=True).clone()
y = torch.randn(3, 3, requires_grad=True).clone()
fxx(x, x)
# Note This should not raise! Once we have guards in place here,
# we will have this working correctly, as it should recompile.
x = torch.randn(3, 3, requires_grad=True).clone()
y = torch.randn(3, 3, requires_grad=True).clone()
self.assertExpectedRaisesInline(
AssertionError, lambda: fxx(x, y),
"""At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950
@@ -2621,7 +2717,7 @@ def fn(p, x):
x.t_()
return (x * 2,)
mod = TestMod(fn)
inp = torch.randn(2)
inp = torch.randn(2, 4)
with self.assertRaisesRegex(
RuntimeError, "Found an input that received a metadata mutation"
):
@@ -3330,7 +3426,7 @@ def f(a, b):
def test_aot_dispatch_input_metadata_mutation(self):
def f(a, b):
a.t_()
b.t_()
b.unsqueeze_(0)
return a + b

b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
@@ -3375,7 +3471,7 @@ def f(a, b):
def test_aot_dispatch_input_data_and_metadata_mutation(self):
def f(a, b):
a.t_()
b.t_()
b.unsqueeze_(0)
a.mul_(2)
b.mul_(3)
return a + b
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/tensor.py
@@ -598,6 +598,14 @@ def has_bool_key(v):
elif name in ("resize_", "resize_as_"):
# Handling resizing in its full generality is difficult.
unimplemented(f"Tensor.{name}")
elif name == "set_" and len(args) > 1:
# torch.Tensor.set_() has several overloads.
# aten::set_.source_Tensor(Tensor) gets special handling
# in AOTAutograd and functionalization, because it is the most common
# overload and is used by FSDP.
# Graph-breaking on aten::set_.source_Tensor_storage_offset for now,
# unless we find that we need to make it work.
unimplemented("Tensor.set_.source_Tensor_storage_offset")
elif (
name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
):