Allow FSDP to have ignored modules out of wrapped root (pytorch#91079)
Motivations for this change:

1. TorchRec returns inconsistent results from `m.named_parameters()`
   and `m.m1.named_parameters()` when `m1` is a `ShardedModule`:
   the `ShardedModule` appears in `m.named_modules()`, but its
   parameters are not in `m.named_parameters()`. As a result, when we
   identify such `ShardedModule` instances and pass them as
   `ignored_modules` to FSDP, FSDP hits a key error in
   `_get_ignored_params`.
2. If users manually wrap submodules with FSDP, it is easier for them
   to keep one global set of ignored parameters than to create a new
   collection for every FSDP invocation.

Given these two reasons, we allow FSDP to accept ignored modules that
live outside the wrapped root module. A minimal usage sketch of the
pattern follows below.
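
For illustration, a sketch of the pattern this change enables, mirroring
the new test below (`Model`, `layer1`, and `layer3` are the toy model from
the test file; an initialized process group and CUDA devices are assumed):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = Model().cuda()
# One global collection of ignored modules, reused for every manual wrap.
ignored_modules = list(model.layer1.children())[1:]

# model.layer1 actually contains the ignored modules.
model.layer1 = FSDP(model.layer1, ignored_modules=ignored_modules)

# model.layer3 does not contain them; previously _get_ignored_params hit a
# key error here, now FSDP warns about the out-of-root parameters and skips
# them when collecting ignored parameter names.
model.layer3 = FSDP(model.layer3, ignored_modules=ignored_modules)
```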

Differential Revision: [D42132394](https://our.internmc.facebook.com/intern/diff/D42132394)
Pull Request resolved: pytorch#91079
Approved by: https://github.com/awgu
mrshenli authored and pytorchmergebot committed Dec 19, 2022
1 parent 6686e9b commit e5a48da
Showing 2 changed files with 45 additions and 8 deletions.
32 changes: 28 additions & 4 deletions test/distributed/fsdp/test_fsdp_ignored_modules.py
@@ -5,7 +5,10 @@
 import torch
 import torch.nn as nn
 from torch import distributed as dist
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     CUDAInitMode,
@@ -89,10 +92,11 @@ def __init__(self, num_ignored: int) -> None:
 class TestFSDPIgnoredModules(FSDPTest):
     def _train_model(self, model, optim, num_iters, device=torch.device("cuda")):
         for _ in range(num_iters):
-            inp = model.module.get_input(device)
+            module = model.module if isinstance(model, FSDP) else model
+            inp = module.get_input(device)
             output = model(*inp)
-            loss = model.module.get_loss(inp, output).to(device)
-            model.module.run_backward(loss)
+            loss = module.get_loss(inp, output).to(device)
+            module.run_backward(loss)
             optim.step()
 
     @skip_if_lt_x_gpu(2)
@@ -225,6 +229,26 @@ def test_diff_ignored_modules_across_ranks(
         optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
         self._train_model(wrapped_model, optim, 3)
 
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_modules_not_under_wrapped_root(self):
+        model = Model().cuda()
+        ignored_modules = list(model.layer1.children())[1:]
+        model.layer1 = FSDP(
+            model.layer1,
+            # sharding_strategy shouldn't matter here.
+            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
+            ignored_modules=ignored_modules,
+        )
+        model.layer3 = FSDP(
+            model.layer3,
+            # ignored_modules contains submodules of model.layer1, which are
+            # outside the local FSDP root model.layer3.
+            ignored_modules=ignored_modules,
+        )
+        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+        self._train_model(model, optim, 3)
+
+
 
 instantiate_parametrized_tests(TestFSDPIgnoredModules)
 
21 changes: 17 additions & 4 deletions torch/distributed/fsdp/_init_utils.py
@@ -570,23 +570,36 @@ def _get_ignored_params(
     excluding any :class:`FlatParameter` s, and their fully prefixed names,
     both as :class:`set` s.
     """
-    ignored_params = set(
-        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
+    ignored_params_to_names = dict(
+        (p, n)
+        for m in ignored_modules
+        for n, p in m.named_parameters()
+        if not _is_fsdp_flattened(p)
     )
     # Conservatively include all shared parameters' names
     param_to_unflat_param_names = _get_param_to_fqns(
         root_module,
         dedup_shared_params=False,
     )
     ignored_param_names = set()
-    for param in ignored_params:
+    for param, name in ignored_params_to_names.items():
+        if param not in param_to_unflat_param_names:
+            # Allow users to pass parameters not under the FSDP root module.
+            # This is useful when users apply FSDP manually to different
+            # submodules with the same global set of ignored parameters.
+            warnings.warn(
+                f"Parameter {name} is in the ignored modules passed to FSDP, "
+                "but it's not under the root module wrapped by FSDP."
+            )
+            continue
+
         unflat_param_names = param_to_unflat_param_names[param]
         clean_names = []
         for k in unflat_param_names:
             # Clean any module wrapper prefixes in case of nested wrapping
             clean_names.append(clean_tensor_name(k))
         ignored_param_names.update(clean_names)
-    return ignored_params, ignored_param_names
+    return set(ignored_params_to_names.keys()), ignored_param_names
 
 
 def _get_buffer_names(root_module: nn.Module) -> Set[str]:
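For reference, the core of the new mapping can be exercised outside FSDP.
A standalone sketch with toy `nn.Linear` modules (not the real call path,
and omitting the `_is_fsdp_flattened` filter):

```python
import torch.nn as nn

# Build the same parameter -> name mapping that the updated
# _get_ignored_params constructs, so a parameter that is not under the
# wrapped root can later be reported by name in the warning.
ignored_modules = [nn.Linear(4, 4), nn.Linear(4, 4)]
ignored_params_to_names = {
    p: n for m in ignored_modules for n, p in m.named_parameters()
}
print(sorted(set(ignored_params_to_names.values())))  # ['bias', 'weight']
```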
