Merge pull request FlagAI-Open#33 from baai-open-internal/vit-checkpointing-activations

add checkpointing activations for ViT when env_type = pytorch/deepspeed
marscrazy authored Jul 21, 2022
2 parents c35d4b6 + dc6fc3d commit c1cec9f
Showing 5 changed files with 102 additions and 28 deletions.
1 change: 1 addition & 0 deletions examples/vit_cifar100/train_DDP.py
@@ -28,6 +28,7 @@
save_interval=1000,
num_checkpoints=1,
hostfile="./hostfile",
training_script="train_DDP.py"
)

def build_cifar():
1 change: 1 addition & 0 deletions examples/vit_cifar100/train_deepspeed.py
@@ -29,6 +29,7 @@
save_interval=1000,
num_checkpoints=1,
hostfile="./hostfile",
training_script="train_deepspeed.py"
)

def build_cifar():
70 changes: 70 additions & 0 deletions flagai/model/vision/helpers.py
@@ -0,0 +1,70 @@

import os
if os.getenv('ENV_TYPE') == 'deepspeed':
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
else:
from torch.utils.checkpoint import checkpoint
import torch
from itertools import chain

def checkpoint_seq(
functions,
x,
every=1,
flatten=False,
skip_last=False,
):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
(sequentially). Therefore, we can divide such a sequence into segments
and checkpoint each segment. All checkpointed segments run in :func:`torch.no_grad`
manner, i.e., not storing the intermediate activations. The inputs of each
checkpointed segment will be saved for re-running the segment in the backward pass.
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning::
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients.
Args:
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True
Returns:
Output of running :attr:`functions` sequentially on :attr:`x`
Example:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2)
"""
def run_function(start, end, functions):
def forward(_x):
for j in range(start, end + 1):
_x = functions[j](_x)
return _x
return forward

if isinstance(functions, torch.nn.Sequential):
functions = functions.children()
if flatten:
functions = chain.from_iterable(functions)
if not isinstance(functions, (tuple, list)):
functions = tuple(functions)

num_checkpointed = len(functions)
if skip_last:
num_checkpointed -= 1
end = -1
for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x)
if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x)
return x
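
For reference, a minimal usage sketch of the new helper on its own, assuming plain PyTorch (ENV_TYPE unset, so the torch.utils.checkpoint backend is selected) and a toy stack of layers standing in for the ViT blocks:

import torch
import torch.nn as nn
from flagai.model.vision.helpers import checkpoint_seq

# Toy sequential stack; in vit.py the real argument is self.blocks.
blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])

# At least one input must require grad, otherwise the checkpointed
# segments produce no gradients (see the warning in the docstring).
x = torch.randn(2, 16, requires_grad=True)

# Checkpoint every 2 layers: only segment inputs are stored, and each
# segment is re-run during backward to recover its activations.
out = checkpoint_seq(blocks, x, every=2)
out.sum().backward()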

53 changes: 27 additions & 26 deletions flagai/model/vision/vit.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
from flagai.model.vision.layers.drop import DropPath
from flagai.model.vision.layers.weight_init import trunc_normal_, lecun_normal_
from flagai.model.base_model import BaseModel
from flagai.model.vision.helpers import checkpoint_seq

class VitConfig:
def __init__(self,
@@ -53,7 +54,7 @@ def __init__(self,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
checkpoint_activations=None):
checkpoint_activations=False):
pass
self.img_size=img_size
self.patch_size=patch_size
@@ -74,7 +75,6 @@ def __init__(self,
self.weight_init=weight_init
self.checkpoint_activations = checkpoint_activations


def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
@@ -206,42 +206,42 @@ def __init__(
block_fn=Block
vit_config = VitConfig(**config)
vit_config.num_classes = num_classes
config = vit_config
# config = vit_config

assert config.global_pool in ('', 'avg', 'token')
assert config.class_token or config.global_pool != 'token'
use_fc_norm = config.global_pool == 'avg' if config.fc_norm is None else config.fc_norm
assert vit_config.global_pool in ('', 'avg', 'token')
assert vit_config.class_token or vit_config.global_pool != 'token'
use_fc_norm = vit_config.global_pool == 'avg' if vit_config.fc_norm is None else vit_config.fc_norm
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU

self.num_classes = num_classes
self.global_pool = config.global_pool
self.num_features = self.embed_dim = config.embed_dim # num_features for consistency with other models
self.num_tokens = 1 if config.class_token else 0
self.grad_checkpointing = False
self.global_pool = vit_config.global_pool
self.num_features = self.embed_dim = vit_config.embed_dim # num_features for consistency with other models
self.num_tokens = 1 if vit_config.class_token else 0
self.grad_checkpointing = vit_config.checkpoint_activations

self.patch_embed = embed_layer(
img_size=config.img_size, patch_size=config.patch_size, in_chans=config.in_chans, embed_dim=config.embed_dim)
img_size=vit_config.img_size, patch_size=vit_config.patch_size, in_chans=vit_config.in_chans, embed_dim=vit_config.embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, config.embed_dim) * .02)
self.pos_drop = nn.Dropout(p=config.drop_rate)
self.cls_token = nn.Parameter(torch.zeros(1, 1, vit_config.embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, vit_config.embed_dim) * .02)
self.pos_drop = nn.Dropout(p=vit_config.drop_rate)

dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.depth)] # stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, vit_config.drop_path_rate, vit_config.depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=config.embed_dim, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, qkv_bias=config.qkv_bias, init_values=config.init_values,
drop=config.drop_rate, attn_drop=config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(config.depth)])
self.norm = norm_layer(config.embed_dim) if not use_fc_norm else nn.Identity()
dim=vit_config.embed_dim, num_heads=vit_config.num_heads, mlp_ratio=vit_config.mlp_ratio, qkv_bias=vit_config.qkv_bias, init_values=vit_config.init_values,
drop=vit_config.drop_rate, attn_drop=vit_config.attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(vit_config.depth)])
self.norm = norm_layer(vit_config.embed_dim) if not use_fc_norm else nn.Identity()

# Classifier Head
self.fc_norm = norm_layer(config.embed_dim) if use_fc_norm else nn.Identity()
self.fc_norm = norm_layer(vit_config.embed_dim) if use_fc_norm else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

if config.weight_init != 'skip':
self.init_weights(config.weight_init)
if vit_config.weight_init != 'skip':
self.init_weights(vit_config.weight_init)

def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
@@ -290,10 +290,11 @@ def forward_features(self, x):
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
# if self.grad_checkpointing and not torch.jit.is_scripting():
# x = checkpoint_seq(self.blocks, x)
# else:
x = self.blocks(x)

if self.config["checkpoint_activations"]:
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
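
With this change, whether forward_features routes through checkpoint_seq is controlled by the checkpoint_activations flag in the model config instead of being hard-coded off. A hedged sketch of the toggle, assuming the model is built from a config dict as in the FlagAI ViT examples; the field values below are illustrative, not the project's defaults:

# Illustrative config dict; keys mirror the VitConfig fields shown above.
vit_cfg = {
    "img_size": 224,
    "patch_size": 16,
    "embed_dim": 768,
    "depth": 12,
    "num_heads": 12,
    "checkpoint_activations": True,  # recompute block activations on backward
}
# With checkpoint_activations=True, forward_features calls
# checkpoint_seq(self.blocks, x) instead of self.blocks(x), trading extra
# compute in the backward pass for lower activation memory.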

5 changes: 3 additions & 2 deletions flagai/trainer.py
@@ -348,7 +348,8 @@ def train(self,
train_dataset=None,
valid_dataset=None,
metric_methods=[],
collate_fn=None):
collate_fn=None,
find_unused_parameters=True):
"""Training Loops"""
"""
Trainer is a simple but unified training and eval loop for PyTorch/DeepSpeed/Megatron-LM.
@@ -416,7 +417,7 @@ def train(self,
model.to(torch.device('cuda', self.local_rank))
model = DDP(model,
device_ids=[self.local_rank],
find_unused_parameters=True)
find_unused_parameters=find_unused_parameters)

elif self.env_type == 'pytorch':
model.to(self.pytorch_device)
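
The new keyword lets callers disable DDP's unused-parameter scan when every registered parameter receives a gradient, avoiding the extra traversal it costs per step. A minimal sketch, assuming a Trainer set up as in the train_DDP.py example above; argument names other than find_unused_parameters follow the usual FlagAI examples and are not verified here:

trainer.train(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=val_dataset,
    metric_methods=[],
    collate_fn=collate_fn,
    find_unused_parameters=False,  # skip DDP's per-step unused-parameter check
)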
