[DDP Checkpointing] non-reentrant checkpoint tests (pytorch#69060)
Summary:
Pull Request resolved: pytorch#69060

Saved variable hooks (non-reentrant) checkpointing was added in pytorch#69508; this PR adds tests for it with DDP.

Specifically, almost all DDP use cases are supported with this new API, such as dynamic modules with find_unused_parameters=True. One case remains unsupported: static_graph combined with non-reentrant checkpointing. The underlying reason this does not work is pytorch#58111.
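
For context, here is a minimal sketch (editorial illustration, not part of this commit) of the use case the new tests exercise: a module that checkpoints a layer with use_reentrant=False, wrapped in DDP with find_unused_parameters=True. The class and function names, batch shapes, and process-group setup are assumptions for illustration only.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Checkpoints its second linear layer, mirroring the test modules below."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(20, 20)
        self.l2 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.l1(x)
        # use_reentrant=False selects the saved-variable-hooks
        # (non-reentrant) checkpoint implementation.
        return checkpoint(self.l2, x, use_reentrant=False)


def train_step(rank):
    # Assumes torch.distributed.init_process_group("nccl", ...) has already
    # been called for this rank (hypothetical setup, elided here).
    torch.cuda.set_device(rank)
    model = DDP(
        CheckpointedBlock().cuda(rank),
        device_ids=[rank],
        find_unused_parameters=True,  # supported with non-reentrant checkpointing
    )
    out = model(torch.randn(16, 20, device=rank))
    out.sum().backward()
```
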
ghstack-source-id: 147219887

Test Plan: CI

Reviewed By: zhaojuanmao

Differential Revision: D32712126

fbshipit-source-id: ba5ae9ca77fd8929ee020c7dc97838bae9a1931b
(cherry picked from commit 9c7f93e)
rohan-varma authored and pytorchmergebot committed Jan 19, 2022
1 parent 75aaa9f commit 3b589c3
Showing 1 changed file with 172 additions and 58 deletions.
230 changes: 172 additions & 58 deletions test/distributed/test_c10d_nccl.py
@@ -9,7 +9,7 @@
import tempfile
import threading
import time
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from datetime import timedelta
from itertools import product
from unittest import mock
@@ -49,6 +49,8 @@
TEST_WITH_DEV_DBG_ASAN,
TEST_WITH_ROCM,
sandcastle_skip,
instantiate_parametrized_tests,
parametrize,
sandcastle_skip_if,
)
from torch.utils.checkpoint import checkpoint
@@ -2090,14 +2092,15 @@ class CheckpointOnceModule(nn.Module):
"""
Runs checkpoint for a single layer in the model.
"""
def __init__(self):
def __init__(self, use_reentrant=True):
super().__init__()
self.l1 = nn.Linear(20, 20)
self.l2 = nn.Linear(20, 20)
self.use_reentrant = use_reentrant

def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
return x

class CheckpointTwiceModule(CheckpointOnceModule):
@@ -2106,29 +2109,49 @@ class CheckpointTwiceModule(CheckpointOnceModule):
cases such as pipeline parallel where the same layer can be checkpointed
more than one time.
"""
def __init__(self):
super().__init__()
def __init__(self, use_reentrant=True):
super().__init__(use_reentrant=use_reentrant)

def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
return x

class CheckpointTwiceModuleWeightSharing(CheckpointTwiceModule):
"""
Similar to CheckpointTwiceModule but the weights are shared.
"""
def __init__(self):
super().__init__()
self.l1.weight = self.l2.weight
def __init__(self, use_reentrant=True):
super().__init__(use_reentrant=use_reentrant)

def forward(self, inp):
x = self.l1(inp)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x)
x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
x = checkpoint(self.l2, x, use_reentrant=self.use_reentrant)
return x


class DynamicCheckpointTwiceModule(CheckpointTwiceModule):
def __init__(self, use_reentrant=True):
super().__init__(use_reentrant=use_reentrant)
self.count = 0

def forward(self, inp):
if self.count % 2:
x = checkpoint(self.l1, inp, use_reentrant=self.use_reentrant)
else:
x = checkpoint(self.l2, inp, use_reentrant=self.use_reentrant)

self.count += 1
return x

class DynamicCheckpointTwiceModuleWeightSharing(DynamicCheckpointTwiceModule):
def __init__(self, use_reentrant=True):
super().__init__(use_reentrant=use_reentrant)
self.l1.weight = self.l2.weight


def _prepare_dummy_data(self):
ddp_bs = 16
bs = ddp_bs * self.world_size
@@ -2139,10 +2162,10 @@ def _prepare_dummy_data(self):
ddp_target = target[offset : offset + ddp_bs]
return input, ddp_input, target, ddp_target

def _train_model(self, model, input_var, target, loss, run_checkpoint=False):
def _train_model(self, model, input_var, target, loss, run_checkpoint=False, use_reentrant=True):
model.train()
if run_checkpoint:
output = checkpoint(model, input_var)
output = checkpoint(model, input_var, use_reentrant=use_reentrant)
else:
output = model(input_var)
l = loss(output, target)
Expand All @@ -2156,6 +2179,8 @@ def _test_ddp_checkpointing(
find_unused_parameters=False,
static_graph=False,
run_checkpoint=False,
use_reentrant=True,
allow_none_grads=False,
):
# to reproduce the same training results
torch.cuda.set_device(self.rank)
@@ -2177,28 +2202,32 @@
)
input, ddp_input, target, ddp_target = self._prepare_dummy_data()
loss = nn.MSELoss()
for i in range(5):
n_iters = 5
for i in range(n_iters):
model.zero_grad(set_to_none=False)
ddp_model.zero_grad(set_to_none=False)
self._train_model(model, input, target, loss, run_checkpoint=run_checkpoint)
self._train_model(model, input, target, loss, run_checkpoint=run_checkpoint, use_reentrant=use_reentrant)
self._train_model(
ddp_model, ddp_input, ddp_target, loss, run_checkpoint=run_checkpoint
ddp_model, ddp_input, ddp_target, loss, run_checkpoint=run_checkpoint, use_reentrant=use_reentrant
)
for i, j in zip(model.parameters(), ddp_model.parameters()):
self.assertTrue(i.grad is not None)
self.assertTrue(j.grad is not None)
if not allow_none_grads:
self.assertTrue(i.grad is not None)
self.assertTrue(j.grad is not None)
self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5)

# DDP works as expected when a layer is checkpointed only once,
# when find_unused_parameters=False.
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_once(self):
@parametrize("use_reentrant", [True, False])
def test_ddp_checkpointing_once(self, use_reentrant):
"""
DDP works as expected when a layer is checkpointed only once.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view, static_graph in product((False, True), (False, True)):
self._test_ddp_checkpointing(
self.CheckpointOnceModule(),
self.CheckpointOnceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=static_graph,
@@ -2214,49 +2243,70 @@ def test_ddp_checkpointing_once(self):
find_unused_parameters=True,
)

# DDP will fail when there are unused_parameters in the model and we are not
# using static graph training.
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_unused_params(self):
@parametrize("use_reentrant", [True, False])
def test_ddp_checkpointing_unused_params(self, use_reentrant):
"""
With the reentrant autograd checkpointing implementation, DDP will fail when
there are unused params in the model and static graph training is not enabled.
With the non-reentrant checkpointing implementation, this works as expected.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view in (True, False):
with self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once.",
):
err_ctx = (
suppress() if not use_reentrant else
self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once."
)
)
with err_ctx:
model = self._test_ddp_checkpointing(
self.CheckpointOnceModule(),
self.CheckpointOnceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
find_unused_parameters=True,
)
# test passes when static_graph is true
model = self._test_ddp_checkpointing(
self.CheckpointOnceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
find_unused_parameters=True,
static_graph=True,
)

# DDP will fail when the same layer is checkpointed twice, for both settings
# of find_unused_parameters, and non-static graph.
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_twice_non_static_graph(self):
@parametrize("use_reentrant", [True, False])
def test_ddp_checkpointing_twice(self, use_reentrant):
"""
Checkpointing twice fails for a non-static graph with the reentrant checkpoint
implementation, but succeeds with the non-reentrant checkpoint implementation.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view in (True, False):
error_ctx = self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once.",
err_ctx = (
suppress() if not use_reentrant else
self.assertRaisesRegex(
RuntimeError,
"Expected to mark a variable ready only once."
)
)

with error_ctx:
with err_ctx:
model = self._test_ddp_checkpointing(
self.CheckpointTwiceModule(),
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=False,
)

with error_ctx:
with err_ctx:
model = self._test_ddp_checkpointing(
self.CheckpointTwiceModule(),
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=False,
@@ -2265,23 +2315,73 @@ def test_ddp_checkpointing_twice_non_static_graph(self):

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_twice_static_graph(self):
@parametrize("use_reentrant", [True, False])
def test_ddp_checkpointing_twice_static_graph(self, use_reentrant):
"""
Regardless of reentrant or non-reentrant checkpointing impl,
checkpointing twice works with static graph enabled.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view in (True, False):
# Test passes when static_graph=True.
model = self._test_ddp_checkpointing(
self.CheckpointTwiceModule(),
self.CheckpointTwiceModule(use_reentrant=use_reentrant),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=True,
)

# DDP works as expected if there is weight sharing among layers and we
# checkpoint once.
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_weight_sharing(self):
def test_ddp_checkpointing_dynamic_module(self):
"""
A dynamic module can be checkpointed multiple times with the non-reentrant
checkpointing implementation.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view in (True, False):
model = self._test_ddp_checkpointing(
self.DynamicCheckpointTwiceModule(use_reentrant=False),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=False,
find_unused_parameters=True,
# Grads can be none sometimes due to dynamic module not using
# all params.
allow_none_grads=True
)

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpointing_dynamic_weight_sharing(self):
"""
A dynamic module can be checkpointed multiple times with weight sharing,
using the non-reentrant checkpointing implementation.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view in (True, False):
model = self._test_ddp_checkpointing(
self.DynamicCheckpointTwiceModuleWeightSharing(use_reentrant=False),
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=False,
find_unused_parameters=True,
# Grads can be none sometimes due to dynamic module not using
# all params.
allow_none_grads=True
)

# DDP works as expected if there is weight sharing among layers
@requires_nccl()
@skip_if_lt_x_gpu(2)
@parametrize("use_reentrant", [True, False])
def test_ddp_checkpointing_weight_sharing(self, use_reentrant):
"""
Test that checkpointing with weight sharing works.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
torch.cuda.set_device(self.rank)
@@ -2291,20 +2391,33 @@ def test_ddp_checkpointing_weight_sharing(self):
l2 = nn.Linear(20, 20)
l1.weight = l2.weight
model = nn.Sequential(l1, l2)
self._test_ddp_checkpointing(
model,
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=static_graph,
run_checkpoint=True,
# TODO: non-reentrant based checkpointing of DDP module with
# static_graph runs into the below issue, see
# https://github.com/pytorch/pytorch/issues/70865 and
# https://github.com/pytorch/pytorch/issues/58111 for details.
err_ctx = (
self.assertRaisesRegex(
RuntimeError,
"Your training graph has changed in this iteration"
) if static_graph and not use_reentrant else suppress()
)
with err_ctx:
self._test_ddp_checkpointing(
model,
process_group=process_group,
use_bucket_view=use_bucket_view,
static_graph=static_graph,
run_checkpoint=True,
use_reentrant=use_reentrant,
)

# Checkpointing should work with static graph
# in the case of checkpointing same layer twice and
# having weights shared across layers.
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_ddp_checkpoint_twice_weight_sharing(self):
def test_ddp_checkpointing_twice_weight_sharing(self):
"""
Checkpointing should work with static graph in the case of checkpointing
the same layer twice and having weights shared across layers.
"""
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
torch.cuda.set_device(self.rank)
@@ -2317,7 +2430,6 @@ def test_ddp_checkpoint_twice_weight_sharing(self):
)



class NcclErrorHandlingTest(MultiProcessTestCase):
def setUp(self):
super(NcclErrorHandlingTest, self).setUp()
@@ -2799,6 +2911,8 @@ def test_nccl_warn_not_in_group_debug_info(self):
def test_nccl_warn_not_in_group_debug_off(self):
self._test_warn_not_in_group(backend="nccl")

instantiate_parametrized_tests(DistributedDataParallelTest)

if __name__ == "__main__":
assert (
not torch.cuda._initialized