Enable FSDP usage for PyCls models (#176)
Summary:
Pull Request resolved: #176

Add support for running PyCls ViT models with [FSDP](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html).

Authors: Anjali Sridhar, Vaibhav Aggarwal

Reviewed By: sdebnathusc

Differential Revision: D34689650

fbshipit-source-id: 062021781d69383fa4f6fe114cf92e30db068c36
Vaibhav Aggarwal authored and facebook-github-bot committed Mar 21, 2022
1 parent 8709e4f commit 0a97464
Showing 6 changed files with 193 additions and 16 deletions.
54 changes: 54 additions & 0 deletions configs/fsdp/example.yaml
@@ -0,0 +1,54 @@
MODEL:
  TYPE: vit
  NUM_CLASSES: 1000
  VIT:
    PATCH_SIZE: 16
    STEM_TYPE: "patchify"
    NUM_LAYERS: 32
    NUM_HEADS: 16
    HIDDEN_DIM: 1280
    MLP_DIM: 5120
    CLASSIFIER_TYPE: "pooled"
BN:
  USE_PRECISE_STATS: False
OPTIM:
  OPTIMIZER: adamw
  LR_POLICY: cos
  BASE_LR: 0.001
  MIN_LR: 0.005
  MAX_EPOCH: 100
  WEIGHT_DECAY: 0.24
  WARMUP_EPOCHS: 5
  EMA_ALPHA: 1.0e-5
  EMA_UPDATE_PERIOD: 32
  BIAS_USE_CUSTOM_WEIGHT_DECAY: True
  BIAS_CUSTOM_WEIGHT_DECAY: 0.
  MTA: True
LN:
  EPS: 1e-6
  USE_CUSTOM_WEIGHT_DECAY: True
  CUSTOM_WEIGHT_DECAY: 0.
TRAIN:
  DATASET: imagenet
  IM_SIZE: 224
  BATCH_SIZE: 256
  MIXUP_ALPHA: 0.8
  CUTMIX_ALPHA: 1.0
  LABEL_SMOOTHING: 0.1
  AUGMENT: AutoAugment
  MIXED_PRECISION: True
TEST:
  DATASET: imagenet
  IM_SIZE: 224
  BATCH_SIZE: 256
DATA_LOADER:
  NUM_WORKERS: 10
LOG_PERIOD: 100
NUM_GPUS: 8
LAUNCH:
  GPU_TYPE: "volta32gb"
  MODE: "local"
FSDP:
  ENABLED: True
  RESHARD_AFTER_FW: False
  LAYER_NORM_FP32: True
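
Usage note (not part of this commit): assuming the standard pycls launcher, a training run with this config would look roughly like

    python tools/run_net.py --mode train --cfg configs/fsdp/example.yaml

where NUM_GPUS and LAUNCH control how the eight local workers are spawned; the exact entry point and flags are shown here only for illustration.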
14 changes: 9 additions & 5 deletions pycls/core/checkpoint.py
@@ -57,9 +57,7 @@ def has_checkpoint():

def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
"""Saves a checkpoint and also the best weights so far in a best checkpoint."""
# Save checkpoints only from the main process
if not dist.is_main_proc():
return

# Ensure that the checkpoint dir exists
pathmgr.mkdirs(get_checkpoint_dir())
# Record the state
@@ -72,15 +70,21 @@ def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}

# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
# Save checkpoints only from the main process
if not dist.is_main_proc():
return

with pathmgr.open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
# Store the best model and model_ema weights so far
if not pathmgr.exists(get_checkpoint_best()):
pathmgr.copy(checkpoint_file, get_checkpoint_best())
with pathmgr.open(get_checkpoint_best(), "wb") as f:
torch.save(checkpoint, f)
else:
with pathmgr.open(get_checkpoint_best(), "rb") as f:
with open(get_checkpoint_best(), "rb") as f:
best = torch.load(f, map_location="cpu")
# Select the best model weights and the best model_ema weights
if test_err < best["test_err"] or ema_err < best["ema_err"]:
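
The reordering in this file matters for FSDP: under fairscale's FSDP, model.state_dict() gathers the parameter shards and so is effectively a collective call, which means every rank has to build the checkpoint state before the main-process check, and only the file write stays restricted to rank 0. A minimal sketch of the resulting pattern (condensed and slightly simplified from the diff above):

    # Every rank participates in state_dict(); under FSDP this gathers the shards.
    checkpoint = {"model_state": model.state_dict()}
    # Only the main process actually writes the file.
    if not dist.is_main_proc():
        return
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)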
14 changes: 14 additions & 0 deletions pycls/core/config.py
@@ -292,6 +292,8 @@
# Iteration frequency with which to update EMA weights
_C.OPTIM.EMA_UPDATE_PERIOD = 32

# Enable usage of multi-tensor apply optimizers for better performance.
_C.OPTIM.MTA = False

# --------------------------------- Training options --------------------------------- #
_C.TRAIN = CfgNode()
@@ -397,6 +399,18 @@
_C.LAUNCH.TIME_LIMIT = 4200
_C.LAUNCH.EMAIL = ""

# --------------------------------- FSDP keys -----------------------------------------#
_C.FSDP = CfgNode()

# Enable FSDP sharding
_C.FSDP.ENABLED = False

# Enable resharding after the FW pass. This saves memory at the cost of extra communication.
_C.FSDP.RESHARD_AFTER_FW = True

# Enable wrapping LayerNorm in an FSDP wrapper, which allows its weights and stats to remain in FP32.
_C.FSDP.LAYER_NORM_FP32 = True


# ----------------------------------- Misc options ----------------------------------- #
# Optional description of a config
19 changes: 15 additions & 4 deletions pycls/core/optimizer.py
@@ -48,7 +48,11 @@ def construct_optimizer(model):
param_wds = [{"params": p, "weight_decay": w} for (p, w) in zip(params, wds) if p]
# Set up optimizer
if optim.OPTIMIZER == "sgd":
optimizer = torch.optim.SGD(
if cfg.OPTIM.MTA:
optimizer_fn = torch.optim._multi_tensor.SGD
else:
optimizer_fn = torch.optim.SGD
return optimizer_fn(
param_wds,
lr=optim.BASE_LR,
momentum=optim.MOMENTUM,
@@ -57,22 +61,29 @@
nesterov=optim.NESTEROV,
)
elif optim.OPTIMIZER == "adam":
optimizer = torch.optim.Adam(
if cfg.OPTIM.MTA:
optimizer_fn = torch.optim._multi_tensor.Adam
else:
optimizer_fn = torch.optim.Adam
return optimizer_fn(
param_wds,
lr=optim.BASE_LR,
betas=(optim.BETA1, optim.BETA2),
weight_decay=wd,
)
elif optim.OPTIMIZER == "adamw":
optimizer = torch.optim.AdamW(
if cfg.OPTIM.MTA:
optimizer_fn = torch.optim._multi_tensor.AdamW
else:
optimizer_fn = torch.optim.AdamW
return optimizer_fn(
param_wds,
lr=optim.BASE_LR,
betas=(optim.BETA1, optim.BETA2),
weight_decay=wd,
)
else:
raise NotImplementedError
return optimizer


def lr_fun_steps(cur_epoch):
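
A side note on OPTIM.MTA (an observation, not part of this commit): torch.optim._multi_tensor is a private namespace, and on more recent PyTorch releases the same multi-tensor-apply fast path is exposed on the public optimizers through the foreach argument. An equivalent construction would look roughly like:

    # Sketch under the assumption of a newer PyTorch (roughly >= 1.12); the commit
    # itself uses torch.optim._multi_tensor.AdamW when cfg.OPTIM.MTA is enabled.
    optimizer = torch.optim.AdamW(
        param_wds,
        lr=optim.BASE_LR,
        betas=(optim.BETA1, optim.BETA2),
        weight_decay=wd,
        foreach=True,  # use the multi-tensor implementation of the update
    )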
107 changes: 100 additions & 7 deletions pycls/core/trainer.py
@@ -7,6 +7,7 @@

"""Tools for training and testing a model."""

import contextlib
import os
import random
from copy import deepcopy
@@ -24,13 +25,73 @@
import pycls.datasets.loader as data_loader
import torch
import torch.cuda.amp as amp
from fairscale.nn import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler
from pycls.core.config import cfg
from pycls.core.io import cache_url, pathmgr


logger = logging.get_logger(__name__)


def auto_wrap_ln(module, fsdp_config=None, wrap_it=True, assert_on_collision=True):
"""
Auto-wrap all LayerNorm instances with a safer FSDP, especially when the
outer FSDP instance is flattening parameters.
We put each LN in its own full-precision, unflattened FSDP wrapper. Note, LNs still have
a group size == world_size. The input and output for LN are still FP16 in mixed precision mode.
See ``keep_batchnorm_fp32`` here: https://nvidia.github.io/apex/amp.html
This needs to be done on each rank, just like models being wrapped by FSDP on each rank.
Args:
module (nn.Module):
The model (or part of the model) in which LNs are to be pre-wrapped.
fsdp_config (Dict):
Optional fsdp_config to be used.
wrap_it (bool):
Whether or not to wrap the module after setting the config.
Default: True
assert_on_collision (bool):
Whether or not to assert if a wrapper_config already exists on the module.
Default: True
Returns:
Processed module, where LNs are wrapped with a special FSDP instance.
"""
if fsdp_config is None:
fsdp_config = {
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
# Reshard==False is good for performance. When FSDP(checkpoint(FSDP(ln))) is used, this
# **must** be False because the LN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
# No bucketing or small bucketing should be enough for LNs.
"bucket_cap_mb": 0,
# Setting this for SyncBatchNorm. This may have a performance impact. If
# SyncBatchNorm is used, this can be enabled by passing in the `fsdp_config` argument.
"force_input_to_fp32": False,
}

# Assign the config dict to LNs.
for m in module.modules():
if isinstance(m, torch.nn.LayerNorm):
if assert_on_collision:
assert not hasattr(
m, "wrapper_config"
), "Module shouldn't already have a wrapper_config. Is it tagged already by another policy?"
m.wrapper_config = fsdp_config

# Wrap it.
with (
enable_wrap(config_auto_wrap_policy, wrapper_cls=FSDP)
if wrap_it
else contextlib.suppress()
):
return auto_wrap(module)


def setup_env():
"""Sets up environment for training or testing."""
if dist.is_main_proc():
@@ -60,18 +121,48 @@ def setup_model():
"""Sets up a model for training or testing and log the results."""
# Build the model
model = builders.build_model()

if cfg.FSDP.ENABLED:
fsdp_config = {}
fsdp_config["reshard_after_forward"] = cfg.FSDP.RESHARD_AFTER_FW
fsdp_config["mixed_precision"] = cfg.TRAIN.MIXED_PRECISION
if cfg.TRAIN.MIXED_PRECISION:
fsdp_config["clear_autocast_cache"] = True

# Enable LAYER_NORM_FP32 wrapping for mixed precision training only. It is not
# required for full precision training.
ema_model = builders.build_model()
if cfg.TRAIN.MIXED_PRECISION and cfg.FSDP.LAYER_NORM_FP32:

def do_wrap(m):
m = auto_wrap_ln(m)
return FSDP(m, **fsdp_config)

model = do_wrap(model)
ema_model = do_wrap(ema_model)
else:
with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
model = FSDP(model, **fsdp_config)
ema_model = FSDP(ema_model, **fsdp_config)

else:
# Build the model
model = builders.build_model()
ema_model = deepcopy(model)

logger.info("Model:\n{}".format(model)) if cfg.VERBOSE else ()
# Log model complexity
logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
# Transfer the model to the current GPU device
cur_device = torch.cuda.current_device()
model = model.cuda(device=cur_device)
ema_model = ema_model.cuda(device=cur_device)
# Use multi-process data parallel model in the multi-gpu setting
if cfg.NUM_GPUS > 1:
# Make model replica operate on the current device
ddp = torch.nn.parallel.DistributedDataParallel
model = ddp(module=model, device_ids=[cur_device], output_device=cur_device)
return model
return model, ema_model


def get_weights_file(weights_file):
@@ -161,8 +252,7 @@ def train_model():
# Setup training/testing environment
setup_env()
# Construct the model, ema, loss_fun, and optimizer
model = setup_model()
ema = deepcopy(model)
model, ema = setup_model()
loss_fun = builders.build_loss_fun().cuda()
optimizer = optim.construct_optimizer(model)
# Load checkpoint or initial weights
@@ -183,7 +273,10 @@
test_meter = meters.TestMeter(len(test_loader))
ema_meter = meters.TestMeter(len(test_loader), "test_ema")
# Create a GradScaler for mixed precision training
scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
if cfg.FSDP.ENABLED:
scaler = ShardedGradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
else:
scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
@@ -212,7 +305,7 @@ def test_model():
# Setup training/testing environment
setup_env()
# Construct the model
model = setup_model()
model, _ = setup_model()
# Load model weights
test_weights = get_weights_file(cfg.TEST.WEIGHTS)
cp.load_checkpoint(test_weights, model)
@@ -229,7 +322,7 @@ def time_model():
# Setup training/testing environment
setup_env()
# Construct the model and loss_fun
model = setup_model()
model, _ = setup_model()
loss_fun = builders.build_loss_fun().cuda()
# Compute model and loader timings
benchmark.compute_time_model(model, loss_fun)
@@ -240,7 +333,7 @@ def time_model_and_loader():
# Setup training/testing environment
setup_env()
# Construct the model and loss_fun
model = setup_model()
model, _ = setup_model()
loss_fun = builders.build_loss_fun().cuda()
# Create data loaders
train_loader = data_loader.construct_train_loader()
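
One more piece worth calling out: with FSDP each rank only holds gradient shards, so the AMP inf/NaN check has to be synchronized across ranks, which is what fairscale's ShardedGradScaler provides while keeping the amp.GradScaler interface. The training step itself is therefore unchanged; a rough sketch of the usage pattern (the real loop lives in pycls' train_epoch, and the variable names here are illustrative):

    with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
        loss = loss_fun(model(inputs), labels)
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skipped internally if any rank saw inf/NaN gradients
    scaler.update()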
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
black==19.3b0
isort==4.3.21
iopath
fairscale
flake8
pyyaml
matplotlib
