[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (pytorch#69508)

Summary:
Pull Request resolved: pytorch#69508

Original Phabricator Diff: D32704467 (pytorch@e032dae)

Reland; the fix is to not test the traditional (re-entrant) checkpoint when the input does not require grad, as that is unsupported, as documented.

Original PR body:

Resubmission of pytorch#62964 with the
suggestions and tests discussed in
pytorch#65537.

Adds a `use_reentrant=False` flag to the `checkpoint` function. When
`use_reentrant=False` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad`, as well as with DDP (thorough
distributed testing still needs to be added).
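
A minimal usage sketch of the new flag (the module and shapes below are made up for illustration): with `use_reentrant=False`, the checkpointed region can be differentiated through `torch.autograd.grad`, which the re-entrant implementation does not support.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# use_reentrant=False selects the SavedVariableHooks-based implementation.
out = checkpoint(model, x, use_reentrant=False).sum()

# Differentiating through the checkpointed region with autograd.grad
# (rather than .backward()) works on this path.
(grad_x,) = torch.autograd.grad(out, (x,))
```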

As discussed in pytorch#65537, the tests that we need to add are:

- [x] Gradient hooks are called once
- [x] works when the input does not require grad but Tensors that do require grad are captured (like the first layer of a NN)
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)
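
For background on the SavedVariableHooks mechanism mentioned above, the snippet below illustrates the underlying `torch.autograd.graph.saved_tensors_hooks` API with a made-up model. It offloads saved activations to CPU rather than recomputing them, so it is only an illustration of the hook mechanism, not the checkpoint implementation added by this PR.

```python
import torch
import torch.nn as nn

def pack_to_cpu(tensor):
    # Called when autograd would otherwise keep `tensor` alive for backward;
    # store a CPU copy plus the original device instead.
    return tensor.device, tensor.detach().cpu()

def unpack_from_cpu(packed):
    # Called during backward, when the saved tensor is actually needed.
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()
loss.backward()  # the unpack hook restores the saved tensors here
```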

Note that this is only for `torch.utils.checkpoint`; if this approach looks good overall, we will do something similar for `checkpoint_sequential`.
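
For reference, a sketch of the existing `checkpoint_sequential` usage that the note above refers to (segment count and model are made up); as of this PR it still uses only the re-entrant implementation.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])
# The re-entrant path needs at least one input that requires grad
# (see the reland note above).
x = torch.randn(4, 16, requires_grad=True)

# Checkpoint the sequential model in 2 segments; a use_reentrant=False
# option here would be a follow-up, per the note above.
out = checkpoint_sequential(model, 2, x).sum()
out.backward()
```
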
ghstack-source-id: 144948501

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32902634

fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
rohan-varma authored and facebook-github-bot committed Dec 8, 2021
1 parent 3456c2c commit 049debd
Showing 2 changed files with 313 additions and 18 deletions.
200 changes: 199 additions & 1 deletion test/test_autograd.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: autograd"]

import contextlib
import gc
import io
import math
@@ -31,7 +32,8 @@
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
disable_gc, gradcheck, gradgradcheck)
disable_gc, gradcheck, gradgradcheck,
parametrize, instantiate_parametrized_tests)
from torch.autograd import Variable, Function, detect_anomaly, kineto_available
from torch.autograd.function import InplaceFunction
import torch.autograd.forward_ad as fwAD
@@ -4308,6 +4310,50 @@ def test_checkpointing(self):
mean_combined = torch.stack(feat_combined).mean()
mean_combined.backward()

@slowTest
@parametrize("input_requires_grad", [True, False])
def test_checkpointing_without_reentrant(self, input_requires_grad):
"""
Basic test for checkpoint without reentrant autograd.
"""
num_inp = 2000
nz_inp = 10
nz_out = 10
nz_bottleneck = 1000

# small proxy network for some complex reasoning we want to do per input
module = nn.Sequential(
nn.Linear(nz_inp, nz_bottleneck),
nn.ReLU(),
nn.Linear(nz_bottleneck, nz_inp)
)

# Run model with and without checkpointing and verify gradients are
# equivalent, regardless of whether the inputs require grad.
module_copy = deepcopy(module)

feat_combined = []
feat_combined_no_checkpoint = []
for r in range(num_inp):
data_r = torch.empty(1, nz_inp)
data_r.uniform_()
data_r.requires_grad = input_requires_grad
data_r_copy = data_r.clone()
feat_r = checkpoint(module, data_r, use_reentrant=False)
feat_combined.append(feat_r)
feat_r_no_checkpoint = module_copy(data_r)
feat_combined_no_checkpoint.append(feat_r_no_checkpoint)


# compute mean as a proxy for some joint reasoning
mean_combined = torch.stack(feat_combined).mean()
mean_combined.backward()
mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
mean_combined_no_checkpoint.backward()

for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
self.assertEqual(checkpoint_param.grad, param.grad)

def test_checkpoint_valid_reset_on_error(self):
a = torch.randn(2, 2, requires_grad=True)

@@ -4318,6 +4364,156 @@ def test_checkpoint_valid_reset_on_error(self):
c = checkpoint(torch.exp, a).sum()
c.backward()

@parametrize("use_reentrant", [True, False])
def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
class NoGradModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2, bias=False)
self.lin2 = nn.Linear(2, 2, bias=False)

def forward(self, x):
with torch.no_grad():
return self.lin2(self.linear(x))

module = NoGradModule()

err_ctx = (
self.assertRaisesRegex(
RuntimeError,
"none of output has requires_grad=True"
)
if use_reentrant
else contextlib.suppress()
)

a = torch.randn(2, 2, requires_grad=True)
for _ in range(3):
with err_ctx:
# out does not require grad
out = checkpoint(module, a, use_reentrant=use_reentrant)
# Make loss require grad, otherwise we would run into
# "element 0 of tensors does not require grad and does not have a grad_fn"
out += a
out.sum().backward()

def test_checkpointing_without_reentrant_correct_grad(self):
"""
Verifies that correct gradients are calculated for checkpoint
without reentrant autograd, for both backward() and autograd.grad().
"""
a = torch.randn(2, 2, requires_grad=True)

b = torch.exp(a).sum()
b.backward()
b_grad = a.grad

a.grad = None
c = checkpoint(torch.exp, a, use_reentrant=False).sum()
c.backward()
c_grad = a.grad

a.grad = None
d = checkpoint(torch.exp, a, use_reentrant=False).sum()
d_grad, = torch.autograd.grad(d, (a,))

self.assertEqual(b_grad, c_grad)
self.assertEqual(b_grad, d_grad)

def test_checkpointing_without_reentrant_dataparallel(self):
"""
Verifies gradient correctness when checkpoint without reentrant autograd
is used in conjunction with DataParallel.
"""
class LinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 2, bias=False)

def forward(self, inp):
return self.linear(inp)

a = torch.randn(2, 2, requires_grad=True)
if torch.cuda.is_available():
a = a.cuda()

model = LinearModule()
if torch.cuda.is_available():
model = model.cuda()

b = deepcopy(model)(a).sum()
b.backward()
b_grad = a.grad

a.grad = None

module = torch.nn.DataParallel(deepcopy(model))
c = checkpoint(module, a, use_reentrant=False).sum()
c.backward()
c_grad = a.grad

self.assertEqual(b_grad, c_grad)

def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
"""
Ensures that gradient hooks are only called once per tensor.
"""
w = torch.randn(10, 10, requires_grad=True)
count = 0

def hook(grad):
nonlocal count
count += 1

w.register_hook(hook)
x = torch.rand(10, 10, requires_grad=True)
h = w * x # Using w outside the checkpoint
out = checkpoint(lambda x: w * x, h, use_reentrant=False) # Using w inside the checkpoint

out.sum().backward()
# should only call hook once
self.assertEqual(count, 1)

def test_checkpointing_without_reentrant_arbitrary_input_output(self):
"""
Ensures checkpointing without reentrant autograd works with functions
with arbitrary input/output structures.
"""

class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(5, 5, bias=False)

def forward(self, dict_input):
tensor = dict_input["tensor"]
return {
"result": self.layer(tensor)
}

model_no_checkpoint = MyModel()
model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)

inp = {
"tensor": torch.randn(5, 5)
}

out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()

out_checkpoint = checkpoint(
model_checkpoint_without_reentrant,
inp,
use_reentrant=False
)["result"].sum()

self.assertEqual(out_checkpoint, out_no_checkpoint)

out_no_checkpoint.backward()
out_checkpoint.backward()

for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()):
self.assertEqual(param.grad, checkpoint_param.grad)

def test_callback_adds_callback(self):
called = [0]

@@ -9108,5 +9304,7 @@ def fn(x1, x2):
except_for=None
)

instantiate_parametrized_tests(TestAutograd)

if __name__ == '__main__':
run_tests()