
Commit

Add support for FSDP + looped bfs
ghstack-source-id: 1d054962ec0090f494c699d3bdc5ca1121f15467
Pull Request resolved: #1068
wconstab committed Apr 18, 2024
1 parent 289db25 commit 59a3fc6
Showing 2 changed files with 97 additions and 75 deletions.
3 changes: 3 additions & 0 deletions pippy/PipelineSchedule.py
@@ -544,6 +544,9 @@ def step_microbatches(

for stage in reversed(self._stages):
for i in range(self._n_microbatches):
stage._configure_data_parallel_mode(
i == self._n_microbatches - 1
)
with record_function(f"Stage {stage.stage_index} Backward"):
ops = stage.get_bwd_recv_ops()
if ops:
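Note on the hunk above: before each microbatch's backward, the schedule now calls stage._configure_data_parallel_mode(i == self._n_microbatches - 1), so that data-parallel gradient synchronization only fires on the final microbatch while earlier backwards just accumulate local gradients. The implementation of _configure_data_parallel_mode is not shown in this diff; the snippet below is a minimal sketch of the idea only, assuming an FSDP2 module produced by fully_shard that exposes set_requires_gradient_sync (the exact API may vary across torch versions).

# Hypothetical sketch, not pippy's actual implementation (which this diff does not show).
# Idea: accumulate gradients locally for all but the last microbatch and let FSDP
# reduce-scatter only on the final backward.
def _configure_data_parallel_mode(fsdp_module, last_backward: bool) -> None:
    # FSDP2 modules created by fully_shard are assumed to expose
    # set_requires_gradient_sync; guard with hasattr in case the API differs.
    if hasattr(fsdp_module, "set_requires_gradient_sync"):
        fsdp_module.set_requires_gradient_sync(last_backward)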
169 changes: 94 additions & 75 deletions test/test_composability.py
@@ -6,7 +6,11 @@
import torch.distributed as dist
import torch.nn as nn
from pippy.ManualPipelineStage import ManualPipelineStage
from pippy.PipelineSchedule import Schedule1F1B, ScheduleGPipe
from pippy.PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
ScheduleLoopedBFS,
)
from torch.distributed._composable.fsdp.fully_shard import (
fully_shard,
MixedPrecisionPolicy,
@@ -193,7 +197,7 @@ def test_manual_pipeline_with_manual_allreduce(self):
ddp_pp_model.all_reduce(num_microbatches)
print(f"{self.rank} finished all_reduce")

@parametrize("schedule_name", ["gpipe", "1f1b"])
@parametrize("schedule_name", ["gpipe", "1f1b", "looped_bfs"])
def test_manual_pipeline_with_fsdp(self, schedule_name):
device_mesh, device = self._init_device_mesh(
mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
@@ -204,46 +208,51 @@ def test_manual_pipeline_with_fsdp(self, schedule_name):
assert type(dp_mesh) == DeviceMesh

# create "entire model"
pp_group_size = pp_group.size()

# 8 layers
layers_per_model = 4
total_layers = 8
dim = 10
full_model = nn.ModuleList(
[
nn.Linear(dim, dim)
for _ in range(pp_group_size * layers_per_model)
]
[nn.Linear(dim, dim) for _ in range(total_layers)]
)
ref_model = nn.Sequential(*copy.deepcopy(full_model))
ref_model.to(device)

# divide the model (8 layers) by the number of ranks (2)
partial_model = nn.Sequential(
*full_model[
pp_group.rank()
* layers_per_model : (pp_group.rank() + 1)
* layers_per_model
]
)
partial_model.to(device)

# apply FSDP
mp_policy = MixedPrecisionPolicy(
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
reduce_dtype=torch.float32,
)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer in partial_model.children():
fully_shard(
layer,
**fsdp_config,
reshard_after_forward=False,
def build_stage(stage_idx, num_stages):
layers_per_model = total_layers // num_stages
assert layers_per_model * num_stages == total_layers
# return offset so validation code can match partial layer back to orig model
offset = stage_idx * layers_per_model
partial_model = nn.Sequential(
*full_model[offset : (stage_idx + 1) * layers_per_model]
)
partial_model.to(device)

# apply FSDP
mp_policy = MixedPrecisionPolicy(
# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
reduce_dtype=torch.float32,
)
fsdp_model = fully_shard(partial_model, **fsdp_config)
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer in partial_model.children():
fully_shard(
layer,
**fsdp_config,
reshard_after_forward=False,
)
fsdp_model = fully_shard(partial_model, **fsdp_config)

stage = self._create_manual_pipeline_stage(
fsdp_model,
stage_idx,
num_stages,
device,
pp_group,
input_mb[0],
num_microbatches,
)
return stage, offset

# apply PP
num_microbatches = 8
@@ -255,33 +264,52 @@ def test_manual_pipeline_with_fsdp(self, schedule_name):
input_mb = [
[input[i].reshape((1, dim))] for i in range(num_microbatches)
]
pipeline_stage = self._create_manual_pipeline_stage(
fsdp_model,
pp_group.rank(),
pp_group.size(),
device,
pp_group,
input_mb[0],
num_microbatches,
)

# dummy loss needed just to force backwards to run in schedule step
loss_fn = lambda y, t: y.sum()

if schedule_name == "gpipe":
pipeline_schedule = ScheduleGPipe(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
)
elif schedule_name == "1f1b":
pipeline_schedule = Schedule1F1B(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
)
# divide the model (8 layers) by the number of ranks (2)
if schedule_name in {"looped_bfs", "looped_dfs"}:
n_virtual = 2
num_stages = pp_group.size() * n_virtual
stages = []
offsets = []
for i in range(n_virtual):
stage, offset = build_stage(
pp_group.rank() + n_virtual * i, num_stages
)
stages.append(stage)
offsets.append(offset)
partial_models = [
pipeline_stage.submod for pipeline_stage in stages
]

if schedule_name == "looped_bfs":
pipeline_schedule = ScheduleLoopedBFS(
stages,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
)
else:
raise RuntimeError(f"unsupported schedule {schedule_name}")
pipeline_stage, offset = build_stage(
pp_group.rank(), pp_group.size()
)
partial_models = [pipeline_stage.submod]
offsets = [offset]

if schedule_name == "gpipe":
pipeline_schedule = ScheduleGPipe(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
)
elif schedule_name == "1f1b":
pipeline_schedule = Schedule1F1B(
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
)
else:
raise RuntimeError(f"unsupported schedule {schedule_name}")

pipeline_schedule.step_microbatches(
arg_mbs=input_mb, target_mbs=input_mb
@@ -297,26 +325,17 @@ def test_manual_pipeline_with_fsdp(self, schedule_name):
for p in ref_model.parameters():
p.grad /= dp_mesh.size()

# Pretty ugly way to deal with using the sequential container to hold slices of the model.
# (on pp-stage 1, ref-model's 4th layer gets re-indexed as the 0th layer in the submodule)
fqn_map = {
"0": "4",
"1": "5",
"2": "6",
"3": "7",
}

# Validate that whichever weights we have locally match that part of our local/full ref model
# (we force FSDP's grads to be all-gathered (.full_tensor) to make it simpler)
ref_parameters = dict(ref_model.named_parameters())
for name, p in partial_model.named_parameters():
parts = name.split(".")
if pp_group.rank() == 1:
parts[0] = fqn_map[parts[0]]
name = ".".join(parts)
ref_p = ref_parameters[name]
self.assertTrue(isinstance(p.grad, DTensor))
self.assertEqual(ref_p.grad, p.grad.full_tensor())
for partial_model, offset in zip(partial_models, offsets):
for name, p in partial_model.named_parameters():
parts = name.split(".")
parts[0] = str(int(parts[0]) + offset)
name = ".".join(parts)
ref_p = ref_parameters[name]
self.assertTrue(isinstance(p.grad, DTensor))
self.assertEqual(ref_p.grad, p.grad.full_tensor())


instantiate_parametrized_tests(TestPipelineComposability)
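For reference, a standalone illustration (not part of the commit) of the mapping exercised by the looped-BFS branch of the test: with pp_size = 2 pipeline ranks, n_virtual = 2 stages per rank, and total_layers = 8 (the values used in the test), rank r builds the stages r + n_virtual * i, and each stage owns total_layers // num_stages consecutive layers starting at offset = stage_idx * layers_per_model.

# Standalone illustration of the test's looped-BFS stage/layer assignment.
pp_size, n_virtual, total_layers = 2, 2, 8
num_stages = pp_size * n_virtual                # 4 stages in total
layers_per_model = total_layers // num_stages   # 2 layers per stage
for pp_rank in range(pp_size):
    for i in range(n_virtual):
        stage_idx = pp_rank + n_virtual * i
        offset = stage_idx * layers_per_model
        layers = list(range(offset, offset + layers_per_model))
        print(f"rank {pp_rank} stage {stage_idx}: layers {layers}")
# rank 0 stage 0: layers [0, 1]
# rank 0 stage 2: layers [4, 5]
# rank 1 stage 1: layers [2, 3]
# rank 1 stage 3: layers [6, 7]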
