[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (pytorch#69508)

Summary:
Pull Request resolved: pytorch#69508

Original Phabricator Diff: D32704467 (pytorch@e032dae)

Reland; the fix is to not test traditional checkpointing when the input does not require grad, as that is unsupported as documented.

Original PR body:

Resubmission of pytorch#62964 with the suggestions and tests discussed in pytorch#65537.

Adds a `use_reentrant=False` flag to the `checkpoint` function. When `use_reentrant=False` is specified, a checkpointing implementation that uses SavedVariableHooks instead of re-entrant autograd is used. This makes it more composable with things such as `autograd.grad` as well as DDP (thorough distributed testing still needs to be added). A usage sketch follows this list.

As discussed in pytorch#65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] Works when the input does not require grad but Tensors that require grad are captured (like the first layer in an nn)
- [x] Works for functions with arbitrary input/output objects
- [x] Distributed tests (next PR)

Note that this is only for `torch.utils.checkpoint`; if this approach looks good overall, we will do something similar for `checkpoint_sequential`.

ghstack-source-id: 144948501

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32902634

fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe
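A minimal sketch of how the non-reentrant path is meant to be used, assuming a PyTorch build that includes the `use_reentrant` flag; the model, tensor names, and shapes are made up for illustration:

```python
# Sketch: non-reentrant activation checkpointing composed with autograd.grad.
# Assumes a PyTorch version that exposes the `use_reentrant` keyword.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
x = torch.randn(2, 16)  # the input itself does not require grad

# use_reentrant=False selects the SavedVariableHooks-based implementation
# instead of re-running the forward inside a nested backward pass.
out = checkpoint(model, x, use_reentrant=False)
loss = out.sum()

# Composes with torch.autograd.grad, which the re-entrant
# implementation does not support.
grads = torch.autograd.grad(loss, tuple(model.parameters()))
print([g.shape for g in grads])
```

This also covers the case tested in the PR where the checkpointed input does not require grad but captured Tensors (the layer parameters) do.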
Commit: 049debd (1 parent: 3456c2c)
2 changed files with 313 additions and 18 deletions.