forked from FlagAI-Open/FlagAI
Commit
Merge pull request FlagAI-Open#33 from baai-open-internal/vit-checkpointing-activations
add checkpointing activations for VIT when env_type = pytorch/deepspeed
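Only the new helper file is shown in this excerpt; the ViT-side wiring lives in the other changed files. As a rough illustration of how a model might gate the feature, here is a minimal sketch, assuming a hypothetical `checkpoint_activations` flag and the `checkpoint_seq` helper added in the diff below:

```python
# Hypothetical sketch only -- not the commit's actual ViT code.
# Assumes checkpoint_seq (defined in the diff below) is in scope.
import torch.nn as nn

class ViTBlocks(nn.Module):
    def __init__(self, depth=12, dim=192, heads=3, checkpoint_activations=False):
        super().__init__()
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
            for _ in range(depth)
        ])
        self.checkpoint_activations = checkpoint_activations

    def forward(self, x):
        if self.checkpoint_activations and self.training:
            # Recompute activations in backward to trade compute for memory.
            return checkpoint_seq(self.blocks, x)
        return self.blocks(x)
```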
Showing 5 changed files with 102 additions and 28 deletions.
@@ -0,0 +1,70 @@
import os

# Select the activation-checkpointing implementation based on the runtime:
# DeepSpeed's checkpoint when ENV_TYPE=deepspeed, PyTorch's otherwise.
if os.getenv('ENV_TYPE') == 'deepspeed':
    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
else:
    from torch.utils.checkpoint import checkpoint
import torch
from itertools import chain


def checkpoint_seq(
    functions,
    x,
    every=1,
    flatten=False,
    skip_last=False,
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. Each checkpointed segment runs in
    :func:`torch.no_grad` manner, i.e., without storing the intermediate
    activations. The inputs of each checkpointed segment are saved for
    re-running the segment in the backward pass.
    See :func:`~torch.utils.checkpoint.checkpoint` for how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`x`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    def run_function(start, end, functions):
        # Build a closure that runs functions[start..end] (inclusive) on its input.
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        # Flatten nested nn.Sequential containers into a single flat iterable.
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    # Checkpoint consecutive segments of `every` functions each; only each
    # segment's input is stored, and activations are recomputed in backward.
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(run_function(start, end, functions), x)
    if skip_last:
        # Run the remaining last function(s) without checkpointing.
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x
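A quick usage sketch of the helper above (this example is not part of the commit): with six blocks, `every=2`, and `skip_last=True`, the loop checkpoints segments (0,1), (2,3), and (4), then runs block 5 normally.

```python
import torch
import torch.nn as nn

# Six linear blocks; the input must require grad so the checkpointed
# segments produce gradients (see the warning in the docstring).
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(6)])
x = torch.randn(4, 16, requires_grad=True)

# Segments (0,1), (2,3), and (4) are checkpointed; block 5 runs
# normally because skip_last=True excludes it from checkpointing.
out = checkpoint_seq(model, x, every=2, skip_last=True)
out.sum().backward()  # each segment's forward is re-run here
```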