[FSDP/Checkpoint] Activation offload support in checkpoint_wrapper (pytorch#70165)

Summary:
Pull Request resolved: pytorch#70165

Implements activation offload support in the checkpoint_wrapper API via
save_on_cpu hooks. Rather than modifying the torch.utils.checkpoint
implementation, we compose offloading with checkpointing by running the
checkpointed forward under a save_on_cpu hook.
ghstack-source-id: 146078900

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D33228820

fbshipit-source-id: 98b4da0828462c41c381689ee07360ad014e808a
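
A rough sketch of that composition outside the wrapper (not part of this commit; assumes a CUDA device, and the toy layer and tensor shapes are made up for illustration):

    import torch
    import torch.nn as nn
    from torch.autograd.graph import save_on_cpu
    from torch.utils.checkpoint import checkpoint

    module = nn.Linear(3, 3).cuda()  # hypothetical toy layer
    inp = torch.randn(10, 3, device="cuda", requires_grad=True)

    # save_on_cpu moves tensors saved for backward to (pinned) host memory;
    # checkpoint handles recomputation. Nesting them composes offload with
    # checkpointing without touching torch.utils.checkpoint itself.
    with save_on_cpu(pin_memory=True):
        out = checkpoint(module, inp)
    out.sum().backward()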
rohan-varma authored and facebook-github-bot committed Dec 21, 2021
1 parent e428a90 commit a197f3f
Showing 2 changed files with 111 additions and 40 deletions.
117 changes: 85 additions & 32 deletions test/distributed/fsdp/test_fsdp_checkpoint.py
@@ -1,11 +1,11 @@
# Owner(s): ["oncall: distributed"]

import contextlib
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.distributed._fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
CPUOffload,
@@ -25,12 +25,19 @@
parametrize,
instantiate_parametrized_tests,
)
from torch.utils.checkpoint import checkpoint


class TestFSDPCheckpoint(FSDPTest):

class SequentialModule(nn.Module):
def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
def __init__(
self,
checkpoint_layer=False,
offload_activations=False,
wrap_fsdp=False,
*fsdp_args,
**fsdp_kwargs,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
super().__init__()
@@ -39,15 +46,16 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_k
l3 = nn.Linear(3, 3).cuda()

if checkpoint_layer:
l1 = checkpoint_wrapper(l1)
l2 = checkpoint_wrapper(l2)
l3 = checkpoint_wrapper(l3)
ckpt_wrapper = partial(
checkpoint_wrapper, offload_to_cpu=offload_activations
)

l1 = ckpt_wrapper(l1)
l2 = ckpt_wrapper(l2)
l3 = ckpt_wrapper(l3)

fsdp_wrapper = partial(
_maybe_wrap_fsdp,
wrap_fsdp=wrap_fsdp,
*fsdp_args,
**fsdp_kwargs
_maybe_wrap_fsdp, wrap_fsdp=wrap_fsdp, *fsdp_args, **fsdp_kwargs
)
self.ffn = nn.Sequential(
fsdp_wrapper(l1),
@@ -58,7 +66,6 @@ def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_k
def forward(self, x):
return self.ffn(x)


def _verify_parity(self, losses, outputs, models):
assert losses
assert outputs
@@ -79,18 +86,23 @@ def _verify_parity(self, losses, outputs, models):
@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
)
def test_checkpoint_fsdp_wrapping(self, cpu_offload):
@parametrize("offload_activations", [True, False])
def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
# Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
TestFSDPCheckpoint.SequentialModule(
wrap_fsdp=True, cpu_offload=cpu_offload
)
),
offload_to_cpu=offload_activations,
)
# Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
inner_ckpt = TestFSDPCheckpoint.SequentialModule(
checkpoint_layer=True, wrap_fsdp=True, cpu_offload=cpu_offload
checkpoint_layer=True,
offload_activations=offload_activations,
wrap_fsdp=True,
cpu_offload=cpu_offload,
)

baseline = TestFSDPCheckpoint.SequentialModule(
@@ -101,17 +113,29 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
# flag set.
inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

models = [
ckpt_sequential_wrapped_fsdp,
inner_ckpt,
baseline
]
models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]

for _ in range(2):
offload_to_cpu_event = "Memcpy DtoH"

for i in range(2):
losses = []
outputs = []
for m in models:
out = m(inp)
check_offload = m != baseline and i == 0 and offload_activations
profiler_ctx = (
torch.profiler.profile(use_cuda=True)
if check_offload
else contextlib.suppress()
)
with profiler_ctx as prof:
out = m(inp)

if check_offload:
event_names = [event.name for event in prof.events()]
offload_occured = any(
offload_to_cpu_event in name for name in event_names
)
self.assertTrue(offload_occured)
loss = out.sum()
loss.backward()
losses.append(loss)
@@ -122,16 +146,23 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload):
@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
)
def test_basic_checkpoint_end_to_end(self, cpu_offload):
@parametrize("offload_activations", [True, False])
def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
# Runs FSDP with no checkpointing
fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
# Runs checkpoint-wrapped FSDP
checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
checkpointed_fsdp = checkpoint_wrapper(
FSDP(deepcopy(seq), cpu_offload=cpu_offload),
offload_to_cpu=offload_activations,
)
# Runs FSDP-wrapped checkpointed module
fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
fsdp_wrapped_checkpoint = FSDP(
checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
cpu_offload=cpu_offload,
)
# Runs FSDP with manual calls to checkpoint.
fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
# note that reentrant-based checkpointing requires inputs to have grad
@@ -143,17 +174,39 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload):
fsdp_only_seq,
checkpointed_fsdp,
fsdp_wrapped_checkpoint,
fsdp_call_checkpoint
fsdp_call_checkpoint,
]

for _ in range(6):
offload_to_cpu_event = "Memcpy DtoH"

for i in range(6):
losses = []
outputs = []
for m in models:
if m == fsdp_call_checkpoint:
out = checkpoint(m, inp)
else:
out = m(inp)
check_offload = m != fsdp_only_seq and i == 0 and offload_activations
profiler_ctx = (
torch.profiler.profile(use_cuda=True)
if check_offload
else contextlib.suppress()
)
with profiler_ctx as prof:
if m == fsdp_call_checkpoint:
offload_ctx = (
torch.autograd.graph.save_on_cpu(pin_memory=True)
if offload_activations
else contextlib.suppress()
)
with offload_ctx:
out = checkpoint(m, inp)
else:
out = m(inp)

if check_offload:
event_names = [event.name for event in prof.events()]
offload_occured = any(
offload_to_cpu_event in name for name in event_names
)
self.assertTrue(offload_occured)
loss = out.sum()
loss.backward()
losses.append(loss)
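
For reference, the offload check used in the tests above, distilled into a standalone sketch (not part of this diff; the import path follows the file path in this commit and a CUDA device is assumed):

    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint._checkpoint_wrapper import checkpoint_wrapper

    model = checkpoint_wrapper(nn.Linear(3, 3).cuda(), offload_to_cpu=True)
    inp = torch.randn(10, 3, device="cuda", requires_grad=True)

    # A device-to-host copy ("Memcpy DtoH") in the profiler trace indicates that
    # activations saved for backward were moved to CPU during the forward pass.
    with torch.profiler.profile(use_cuda=True) as prof:
        out = model(inp)
    assert any("Memcpy DtoH" in event.name for event in prof.events())
    out.sum().backward()
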
34 changes: 26 additions & 8 deletions torch/distributed/algorithms/_checkpoint/_checkpoint_wrapper.py
@@ -1,6 +1,8 @@
from enum import Enum, auto
from contextlib import suppress

import torch
from torch.autograd.graph import save_on_cpu
from torch.utils.checkpoint import checkpoint


@@ -17,22 +19,28 @@ def __init__(
self,
mod: torch.nn.Module,
checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
offload_to_cpu: bool = False,
):
super().__init__()
self.mod = mod
self.checkpoint_impl = checkpoint_impl
self.offload_to_cpu = offload_to_cpu

def forward(self, *args, **kwargs):
return checkpoint(
self.mod,
use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
*args,
**kwargs,
)
offload_mgr = save_on_cpu(pin_memory=True) if self.offload_to_cpu else suppress()
with offload_mgr: # type: ignore[attr-defined]
return checkpoint(
self.mod,
use_reentrant=(self.checkpoint_impl == CheckpointImpl.REENTRANT),
*args,
**kwargs,
)


def checkpoint_wrapper(
module: torch.nn.Module, checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT
module: torch.nn.Module,
checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
offload_to_cpu: bool = False,
) -> torch.nn.Module:
"""
A convenience wrapper for activation checkpointing. If the module is wrapped
@@ -48,6 +56,10 @@ def checkpoint_wrapper(
checkpoint_impl (Optional[CheckpointImpl]):
The checkpointing implementation to use. Currently only
CheckpointImpl.REENTRANT is supported.
offload_to_cpu (Optional[bool]):
Whether to offload outer activations to CPU. Note that this
currently only works with CheckpointImpl.REENTRANT.
Returns:
(nn.Module):
Wrapped module
@@ -58,4 +70,10 @@ def checkpoint_wrapper(
"No support for non-reentrant based checkpoint implementation."
)

return _CheckpointWrapper(module, checkpoint_impl)
if offload_to_cpu and checkpoint_impl != CheckpointImpl.REENTRANT:
raise ValueError(
"No support for CPU offload activations and non-reentrant based "
"checkpoint implementation."
)

return _CheckpointWrapper(module, checkpoint_impl, offload_to_cpu)
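
A hedged sketch of the FSDP composition exercised by the tests (not part of this diff; assumes a single GPU, NCCL, and a free MASTER_PORT, with a single-rank process group purely for illustration):

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._fsdp.fully_sharded_data_parallel import (
        FullyShardedDataParallel as FSDP,
    )
    from torch.distributed.algorithms._checkpoint._checkpoint_wrapper import checkpoint_wrapper

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    # FSDP(checkpoint_wrapper(...)): the layer is checkpointed with CPU
    # activation offload, then sharded by FSDP, mirroring fsdp_wrapped_checkpoint
    # in the tests above.
    layer = checkpoint_wrapper(nn.Linear(3, 3).cuda(), offload_to_cpu=True)
    model = FSDP(layer)

    inp = torch.randn(10, 3, device="cuda", requires_grad=True)
    model(inp).sum().backward()
    dist.destroy_process_group()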
