Relax use_count constraints for swap_tensors when AccumulateGrad holds a reference (pytorch#127313)

### Before this PR:
`torch.utils.swap_tensors(a, b)` required the `use_count` of both `a` and `b` to be 1.

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here would fail due to the reference held by the AccumulateGrad node, which is not cleaned up after backward()
# torch.utils.swap_tensors(a, b)
del out
# Calling swap_tensors here succeeds
torch.utils.swap_tensors(a, b)
```
### After this PR:
`torch.utils.swap_tensors(a, b)` requires the `use_count` of `a` and `b` to be either 1, or 2 if the second reference is held by an `AccumulateGrad` node.

In the latter case, a pre-hook is registered on the `AccumulateGrad` node so that it raises an error if it is ever executed (i.e. if the user attempts to backward through the old graph).

```python
a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, 4)
out = a * 2
out.sum().backward()
# Calling swap_tensors here is ok
torch.utils.swap_tensors(a, b)
# If we ever backward through this AccumulateGrad node, it will raise an error saying it was poisoned by swap_tensors
```
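
A hypothetical continuation of the snippet above (not part of this PR's description): after the swap, `b` wraps the tensor that owns the poisoned `AccumulateGrad` node, so building a new graph on it and running backward should surface the error.

```python
# Sketch only: `b` now holds the original leaf's TensorImpl and its poisoned AccumulateGrad node
out2 = b * 3
try:
    out2.sum().backward()
except RuntimeError as e:
    print(e)  # message says the AccumulateGrad node was poisoned by swap_tensors
```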

### Application to `nn.Module`

This issue is especially pertinent in the context of `nn.Module`, where parameters have `AccumulateGrad` nodes initialized after the forward pass. Specifically, this is intended to address pytorch#126814 (comment). Previously, the snippet below would fail at `m.cpu()`; we want users to be able to do something like the following, and instead only raise an error if the user ever attempts to backward through the poisoned `AccumulateGrad` node:

```python
import torch
import torch.nn as nn
m = nn.Linear(3, 5)
inp = torch.randn(2, 3)
out = m(inp)
out.sum().backward()
m.cpu()
```
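
As a sketch of the intended behavior after this PR (mirroring part (2) of the updated `test_nn.py` test, not the original snippet): fresh forward/backward passes after the conversion keep working, since the swapped-in parameters get new `AccumulateGrad` nodes; only graphs built before the swap are poisoned.

```python
# Sketch: training continues to work after the conversion
out2 = m(inp)
out2.sum().backward()  # fine: fresh AccumulateGrad nodes were created for the swapped-in parameters
```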

Pull Request resolved: pytorch#127313
Approved by: https://github.com/soulitzer
mikaylagawarecki authored and pytorchmergebot committed May 30, 2024
1 parent d44ab8b commit cd06ae0
Showing 7 changed files with 70 additions and 23 deletions.
17 changes: 15 additions & 2 deletions test/test_modules.py
@@ -863,7 +863,8 @@ def test_errors(self, device, dtype, module_info, training):
else:
raise NotImplementedError(f"Unknown error type {error_input.error_on}")

@modules([module for module in module_db if not module.is_lazy])
# Only run this test for float32 because the test loops over all the dtypes
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@parametrize('set_grad', [True, False])
@wrapSwapTensorsTest()
@@ -879,6 +880,7 @@ def test_to(self, device, dtype, module_info, training, swap, set_grad):

for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

m = module_cls(*c_args, **c_kwargs)

@@ -896,13 +898,25 @@ def _to(m, set_grad=False):
setattr(m, n, new_b)
_to(m, set_grad=set_grad)

# Check .to() can be run after forward and backward with swap
has_params = len(list(m.parameters())) > 0
if swap and not set_grad and has_params:
out = m(*args, **kwargs)
if isinstance(out, tuple):
out = out[0]
out.sum().backward()
m.to(dtype=torch.half)
# reset
m.to(dtype=torch.float32)

prev_device, prev_dtype = device, dtype
for device_, dtype_ in product(devices, dtypes):
# if device/dtype do not change, grad.to(device, dtype) is a no-op so
# swapping will not change ._cdata
# parameters will be wrapped in an nn.Parameter before swapping
# which will cause the ._cdata to change
g_no_swap = device_ == prev_device and dtype_ == prev_dtype
prev_prev_device, prev_prev_dtype = prev_device, prev_dtype
prev_device, prev_dtype = device_, dtype_

p_ids_before = [id(p) for p in m.parameters()]
@@ -940,7 +954,6 @@ def _to(m, set_grad=False):
self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))


@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()
22 changes: 16 additions & 6 deletions test/test_nn.py
@@ -1594,19 +1594,29 @@ def add_one_inplace(t):
finally:
torch.__future__.set_overwrite_module_params_on_conversion(False)

def test_swap_module_params_fails_after_forward(self):
def test_swap_module_params_poisons_acc_grad(self):
try:
torch.__future__.set_swap_module_params_on_conversion(True)
# (1) backward cannot be run after _apply
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
# additionally, if any Tensors are saved for backward, their use_count will be bumped
m = torch.nn.Linear(2, 3)
inp = torch.randn(2, 2)
# forward will init AccumulateGrad nodes, which bumps use_count of parameters' at::Tensors
out = m(inp)
with self.assertRaisesRegex(RuntimeError, re.escape("_apply(): Couldn't swap Linear.weight")):
m.half()
del out
# works as expected now
m.half()
self.assertTrue(all(p.dtype == torch.float16 for p in m.parameters()))
with self.assertRaisesRegex(RuntimeError, "Trying to execute AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()
# (2) _apply can be run after backward()
# After running backward, all the references generated by "save for backward" will be cleared
# So the use_count will be 2 (1 from Tensor itself, and 1 from AccumulateGrad node), swap_tensors
# should allow this.
inp2 = torch.randn(2, 2, dtype=torch.half)
out2 = m(inp2)
out2.sum().backward()
m.float()
self.assertTrue(all(p.dtype == torch.float32 for p in m.parameters()))
out3 = m(inp)
finally:
torch.__future__.set_swap_module_params_on_conversion(False)

7 changes: 2 additions & 5 deletions test/test_torch.py
@@ -10623,12 +10623,9 @@ def test_swap_basic(self):
if t1.is_floating_point():
t3 = t1.clone().detach().requires_grad_(True)
out = t3 * 2
with self.assertRaisesRegex(RuntimeError, "Expected single reference to a's"):
torch.utils.swap_tensors(t3, t2)
del out
# Now succeeds
torch.utils.swap_tensors(t3, t2)
torch.utils.swap_tensors(t1, t2)
with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"):
out.sum().backward()

wr = weakref.ref(t1)
with self.assertRaisesRegex(RuntimeError, "has weakref"):
12 changes: 2 additions & 10 deletions torch/csrc/Module.cpp
@@ -375,22 +375,14 @@ PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
THPVariable* a = reinterpret_cast<THPVariable*>(a_);
THPVariable* b = reinterpret_cast<THPVariable*>(b_);

TORCH_CHECK(
a->cdata->use_count() == 1,
"Expected single reference to a's Tensor object but got ",
a->cdata->use_count());
TORCH_CHECK(
b->cdata->use_count() == 1,
"Expected single reference to b's Tensor object but got ",
b->cdata->use_count());
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
a->cdata->weak_use_count() == 1,
"Expected no weakrefs to a's Tensor object but got ",
"Expected no weakrefs to t1's Tensor object but got ",
a->cdata->weak_use_count() - 1);
TORCH_CHECK(
b->cdata->weak_use_count() == 1,
"Expected no weakrefs to b's Tensor object but got ",
"Expected no weakrefs to t2's Tensor object but got ",
b->cdata->weak_use_count() - 1);

// Swap the Tensor Impl
8 changes: 8 additions & 0 deletions torch/csrc/autograd/python_variable.cpp
@@ -1615,6 +1615,13 @@ int THPVariable_set_imag(PyObject* self, PyObject* imag, void* unused) {
END_HANDLE_TH_ERRORS_RET(-1)
}

PyObject* THPVariable__use_count(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
const auto& t = THPVariable_Unpack(self);
return THPUtils_packUInt64(t.use_count());
END_HANDLE_TH_ERRORS
}

// properties are registered here because we are currently only able to bind
// them manually. TODO: make declarable in native_functions
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
@@ -1766,6 +1773,7 @@ static PyMethodDef extra_methods[] = {
THPVariable_rev_view_func_unsafe,
METH_O,
nullptr},
{"_use_count", THPVariable__use_count, METH_NOARGS, nullptr},
{nullptr}};

struct THPVariableMeta {
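
For reference, the binding added above is exposed on `Tensor` as the private method `_use_count()`; a minimal usage sketch (the counts shown are the typical values, the exact number depends on what else holds a reference):

```python
import torch

t = torch.randn(2, 3, requires_grad=True)
print(t._use_count())  # typically 1: only the Python wrapper references the TensorImpl
out = (t * 2).sum()
out.backward()
print(t._use_count())  # typically 2: the AccumulateGrad node also holds a reference
```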
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -357,6 +357,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._is_any_true,
Tensor._addmm_activation,
Tensor.to_padded_tensor,
Tensor._use_count,
}


26 changes: 26 additions & 0 deletions torch/utils/__init__.py
@@ -46,6 +46,32 @@ def swap_attr(name):
setattr(t1, name, (getattr(t2, name)))
setattr(t2, name, tmp)

def error_pre_hook(grad_outputs):
raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors "
"this can happen when you try to run backward on a tensor that was swapped. "
"For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
"you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
"between running forward and backward. To resolve this, please only change the "
"device/dtype before running forward (or after both forward and backward).")

def check_use_count(t, name='t1'):
use_count = t._use_count()
error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count} "
f"make sure you are not holding references to the tensor in other places.")
if use_count > 1:
if use_count == 2 and t.is_leaf:
accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
# Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
if t._use_count() == 2:
accum_grad_node.register_prehook(error_pre_hook)
else:
raise RuntimeError(error_str)
else:
raise RuntimeError(error_str)

check_use_count(t1, 't1')
check_use_count(t2, 't2')

# Swap the types
# Note that this will fail if there are mismatched slots
swap_attr("__class__")
