diff --git a/test/test_autograd.py b/test/test_autograd.py
index f5976548689c3..b60737c1655aa 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: autograd"]
 
+import contextlib
 import gc
 import io
 import math
@@ -31,7 +32,8 @@
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                   slowTest, IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
-                                                  disable_gc, gradcheck, gradgradcheck)
+                                                  disable_gc, gradcheck, gradgradcheck,
+                                                  parametrize, instantiate_parametrized_tests)
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
@@ -4283,7 +4285,8 @@ def backward(ctx, x):
 
     @slowTest
-    def test_checkpointing(self):
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing(self, input_requires_grad):
         num_inp = 2000
         nz_inp = 10
         nz_out = 10
@@ -4300,7 +4303,7 @@ def test_checkpointing(self):
         for r in range(num_inp):
             data_r = torch.empty(1, nz_inp)
             data_r.uniform_()
-            data_r.requires_grad = True
+            data_r.requires_grad = input_requires_grad
             feat_r = checkpoint(module, data_r)
             feat_combined.append(feat_r)
 
@@ -4308,6 +4311,50 @@
         mean_combined = torch.stack(feat_combined).mean()
         mean_combined.backward()
 
+    @slowTest
+    @parametrize("input_requires_grad", [True, False])
+    def test_checkpointing_without_reentrant(self, input_requires_grad):
+        """
+        Basic test for checkpoint without reentrant autograd.
+        """
+        num_inp = 2000
+        nz_inp = 10
+        nz_out = 10
+        nz_bottleneck = 1000
+
+        # small proxy network for some complex reasoning we want to do per input
+        module = nn.Sequential(
+            nn.Linear(nz_inp, nz_bottleneck),
+            nn.ReLU(),
+            nn.Linear(nz_bottleneck, nz_inp)
+        )
+
+        # Run model with and without checkpointing and verify gradients are
+        # equivalent, regardless of whether inputs require grads or not.
+        module_copy = deepcopy(module)
+
+        feat_combined = []
+        feat_combined_no_checkpoint = []
+        for r in range(num_inp):
+            data_r = torch.empty(1, nz_inp)
+            data_r.uniform_()
+            data_r.requires_grad = input_requires_grad
+            data_r_copy = data_r.clone()
+            feat_r = checkpoint(module, data_r, use_reentrant=False)
+            feat_combined.append(feat_r)
+            feat_r_no_checkpoint = module_copy(data_r)
+            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
+
+
+        # compute mean as a proxy for some joint reasoning
+        mean_combined = torch.stack(feat_combined).mean()
+        mean_combined.backward()
+        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
+        mean_combined_no_checkpoint.backward()
+
+        for checkpoint_param, param in zip(module.parameters(), module_copy.parameters()):
+            self.assertEqual(checkpoint_param.grad, param.grad)
+
     def test_checkpoint_valid_reset_on_error(self):
         a = torch.randn(2, 2, requires_grad=True)
 
@@ -4318,6 +4365,156 @@ def test_checkpoint_valid_reset_on_error(self):
         c = checkpoint(torch.exp, a).sum()
         c.backward()
 
+    @parametrize("use_reentrant", [True, False])
+    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
+        class NoGradModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+                self.lin2 = nn.Linear(2, 2, bias=False)
+
+            def forward(self, x):
+                with torch.no_grad():
+                    return self.lin2(self.linear(x))
+
+        module = NoGradModule()
+
+        err_ctx = (
+            self.assertRaisesRegex(
+                RuntimeError,
+                "none of output has requires_grad=True"
+            )
+            if use_reentrant
+            else contextlib.suppress()
+        )
+
+        a = torch.randn(2, 2, requires_grad=True)
+        for _ in range(3):
+            with err_ctx:
+                # out does not require grad
+                out = checkpoint(module, a, use_reentrant=use_reentrant)
+                # Make loss require grad, otherwise we would run into
+                # "element 0 of tensors does not require grad and does not have a grad_fn"
+                out += a
+                out.sum().backward()
+
+    def test_checkpointing_without_reentrant_correct_grad(self):
+        """
+        Verifies that correct gradients are calculated for checkpoint
+        without reentrant autograd, for both backward() and autograd.grad().
+        """
+        a = torch.randn(2, 2, requires_grad=True)
+
+        b = torch.exp(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        a.grad = None
+        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
+        d_grad, = torch.autograd.grad(d, (a,))
+
+        self.assertEqual(b_grad, c_grad)
+        self.assertEqual(b_grad, d_grad)
+
+    def test_checkpointing_without_reentrant_dataparallel(self):
+        """
+        Verifies gradient correctness when checkpoint without reentrant autograd
+        is used in conjunction with DataParallel.
+        """
+        class LinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2, bias=False)
+
+            def forward(self, inp):
+                return self.linear(inp)
+
+        a = torch.randn(2, 2, requires_grad=True)
+        if torch.cuda.is_available():
+            a = a.cuda()
+
+        model = LinearModule()
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        b = deepcopy(model)(a).sum()
+        b.backward()
+        b_grad = a.grad
+
+        a.grad = None
+
+        module = torch.nn.DataParallel(deepcopy(model))
+        c = checkpoint(module, a, use_reentrant=False).sum()
+        c.backward()
+        c_grad = a.grad
+
+        self.assertEqual(b_grad, c_grad)
+
+    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
+        """
+        Ensures that gradient hooks are only called once per tensor.
+ """ + w = torch.randn(10, 10, requires_grad=True) + count = 0 + + def hook(grad): + nonlocal count + count += 1 + + w.register_hook(hook) + x = torch.rand(10, 10, requires_grad=True) + h = w * x # Using w outside the checkpoint + out = checkpoint(lambda x: w * x, h, use_reentrant=False) # Using w inside the checkpoint + + out.sum().backward() + # should only call hook once + self.assertEqual(count, 1) + + def test_checkpointing_without_reentrant_arbitrary_input_output(self): + """ + Ensures checkpointing without reentrant autograd works with functions + with arbitrary input/output structures. + """ + + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(5, 5, bias=False) + + def forward(self, dict_input): + tensor = dict_input["tensor"] + return { + "result": self.layer(tensor) + } + + model_no_checkpoint = MyModel() + model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint) + + inp = { + "tensor": torch.randn(5, 5) + } + + out_no_checkpoint = model_no_checkpoint(inp)["result"].sum() + + out_checkpoint = checkpoint( + model_checkpoint_without_reentrant, + inp, + use_reentrant=False + )["result"].sum() + + self.assertEqual(out_checkpoint, out_no_checkpoint) + + out_no_checkpoint.backward() + out_checkpoint.backward() + + for param, checkpoint_param in zip(model_no_checkpoint.parameters(), model_checkpoint_without_reentrant.parameters()): + self.assertEqual(param.grad, checkpoint_param.grad) + def test_callback_adds_callback(self): called = [0] @@ -9108,5 +9305,7 @@ def fn(x1, x2): except_for=None ) +instantiate_parametrized_tests(TestAutograd) + if __name__ == '__main__': run_tests() diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 7238c5926b4cc..5afb24992c82e 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -1,6 +1,6 @@ import torch import warnings -from typing import Any, Iterable, List, Tuple +from typing import Any, Iterable, List, Tuple, Union def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: @@ -142,7 +142,7 @@ def backward(ctx, *args): return (None, None) + grads -def checkpoint(function, *args, **kwargs): +def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): r"""Checkpoint a model or part of the model Checkpointing works by trading compute for memory. Rather than storing all @@ -165,10 +165,6 @@ def checkpoint(function, *args, **kwargs): consisting of Tensors, these Tensors nested in custom structures will not be considered as part of autograd. - .. warning:: - Checkpointing currently only supports :func:`torch.autograd.backward` - and only if its `inputs` argument is not passed. :func:`torch.autograd.grad` - is not supported. .. warning:: If :attr:`function` invocation during backward does anything different @@ -177,18 +173,30 @@ def checkpoint(function, *args, **kwargs): detected. .. warning:: - If checkpointed segment contains tensors detached from the computational - graph by `detach()` or `torch.no_grad()`, the backward pass will raise an - error. This is because `checkpoint` makes all the outputs require - gradients which causes issues when a tensor is defined to have no - gradient in the model. To circumvent this, detach the tensors outside of - the `checkpoint` function. + If ``use_reentrant=True`` is specified, then if the checkpointed segment + contains tensors detached from the computational graph by `detach()` or + `torch.no_grad()`, the backward pass will raise an error. 
+        because `checkpoint` makes all the outputs require gradients which
+        causes issues when a tensor is defined to have no gradient in the model.
+        To circumvent this, detach the tensors outside of the `checkpoint`
+        function. Note that the checkpointed segment can contain tensors
+        detached from the computational graph if ``use_reentrant=False`` is
+        specified.
 
     .. warning::
-        At least one of the inputs needs to have :code:`requires_grad=True` if
-        grads are needed for model inputs, otherwise the checkpointed part of the
-        model won't have gradients. At least one of the outputs needs to have
-        :code:`requires_grad=True` as well.
+        If ``use_reentrant=True`` is specified, at least one of the inputs needs
+        to have :code:`requires_grad=True` if grads are needed for model inputs,
+        otherwise the checkpointed part of the model won't have gradients. At
+        least one of the outputs needs to have :code:`requires_grad=True` as
+        well. Note that this does not apply if ``use_reentrant=False`` is
+        specified.
+
+    .. warning::
+        If ``use_reentrant=True`` is specified, checkpointing currently only
+        supports :func:`torch.autograd.backward` and only if its `inputs`
+        argument is not passed. :func:`torch.autograd.grad` is not supported.
+        If ``use_reentrant=False`` is specified, checkpointing will work with
+        :func:`torch.autograd.grad`.
 
     Args:
         function: describes what to run in the forward pass of the model or
@@ -198,6 +206,13 @@ def checkpoint(function, *args, **kwargs):
             first input as ``activation`` and the second input as ``hidden``
         preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
             the RNG state during each checkpoint.
+        use_reentrant(bool, optional, default=True): Use the checkpointing
+            implementation that requires re-entrant autograd.
+            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
+            implementation that does not require re-entrant autograd. This
+            allows ``checkpoint`` to support additional functionality, such as
+            working as expected with ``torch.autograd.grad``. Note that future
+            versions of PyTorch will default to ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`
 
     Returns:
@@ -208,7 +223,14 @@ def checkpoint(function, *args, **kwargs):
     if kwargs:
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
-    return CheckpointFunction.apply(function, preserve, *args)
+    if use_reentrant:
+        return CheckpointFunction.apply(function, preserve, *args)
+    else:
+        return _checkpoint_without_reentrant(
+            function,
+            preserve,
+            *args
+        )
 
 
 def checkpoint_sequential(functions, segments, input, **kwargs):
@@ -275,3 +297,78 @@ def forward(input):
         input = checkpoint(run_function(start, end, functions), input,
                            preserve_rng_state=preserve)
     return run_function(end + 1, len(functions) - 1, functions)(input)
+
+def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args):
+    """Checkpointing without re-entrant autograd
+    Args:
+        function: describes what to run in the forward pass of the model or
+            part of the model. It should also know how to handle the inputs
+            passed as the tuple. For example, in LSTM, if user passes
+            ``(activation, hidden)``, :attr:`function` should correctly use the
+            first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
+            the RNG state during each checkpoint.
+        *args: Arguments to pass in to the given ``function``.
+ """ + had_autocast_in_fwd = torch.is_autocast_enabled() + + if preserve_rng_state: + fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function. + # If they do so, we raise an error.) + had_cuda_in_fwd = False + if torch.cuda._initialized: + had_cuda_in_fwd = True + fwd_gpu_devices, fwd_gpu_states = get_device_states(*args) + + storage: List[Union[torch.Tensor, None]] = [] + counter = 0 + + def pack(x): + nonlocal counter + counter += 1 + # TODO(varal7): Instead of returning indices, we can return things metadata (such as + # size, device, ...) to catch certain cases of undeterministic behavior of the forward + return counter - 1 + + def unpack(x): + if len(storage) == 0: + + def inner_pack(inner): + storage.append(inner) + return None + + def inner_unpack(packed): + raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.") + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. + rng_devices = [] + if preserve_rng_state and had_cuda_in_fwd: + rng_devices = fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state): + if preserve_rng_state: + torch.set_rng_state(fwd_cpu_state) + if had_cuda_in_fwd: + set_device_states(fwd_gpu_devices, fwd_gpu_states) + with torch.enable_grad(), torch.cuda.amp.autocast(had_autocast_in_fwd): + with torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + _unused = function(*args) + + return storage[x] + + with torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args) + if torch.cuda._initialized and not had_cuda_in_fwd: + # Cuda was not initialized before running the forward, so we didn't + # stash the CUDA state. + raise RuntimeError( + "PyTorch's CUDA state was initialized in the forward pass " + "of a Checkpoint, which is not allowed. Please open an issue " + "if you need this feature.") + + return output