[Dynamo][FSDP] Migrate to ModuleWrapPolicy (pytorch#88453)
Hello @wconstab! As you saw, `transformer_auto_wrap_policy()` is a misnomer and actually works for arbitrary module classes. The PR before this one adds a class `ModuleWrapPolicy` that takes the module classes in its constructor and works just like `transformer_auto_wrap_policy()` without requiring the `functools.partial()`. I hope you do not mind if we update the dynamo benchmarks util file with this migration.

The PR before this one might require some back and forth among FSDP devs, so I apologize for any subsequent updates to this PR, which in itself is an easy change. I will request review once we know the previous PR is good to land.
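
For context, a minimal sketch of the before/after usage (the `nn.TransformerEncoderLayer` class and the `module_classes` variable here are illustrative stand-ins for the benchmark's wrapped block classes):

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy

# Illustrative stand-in for the set of module classes to wrap.
module_classes = {nn.TransformerEncoderLayer}

# Before: transformer_auto_wrap_policy must be bound via functools.partial.
old_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls=module_classes
)

# After: ModuleWrapPolicy takes the module classes directly in its constructor.
new_policy = ModuleWrapPolicy(module_classes)

# Either policy is passed to FSDP the same way, e.g.:
# model = FSDP(model, auto_wrap_policy=new_policy, use_orig_params=True)
```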

Pull Request resolved: pytorch#88453
Approved by: https://github.com/wconstab
awgu authored and pytorchmergebot committed Nov 13, 2022
1 parent bca75fd commit 4284862
Showing 1 changed file with 2 additions and 5 deletions.
benchmarks/dynamo/dist_util.py
```diff
@@ -13,7 +13,7 @@
     CheckpointImpl,
 )
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 
 try:
     from .torchbench import setup_torchbench_cwd
@@ -138,10 +138,7 @@ def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
         "toy_model" if model.__class__ is ToyModel else args.torchbench_model
     ]
     if use_wrap_policy:
-        # transformer policy is really a generic policy that wraps modules of specified classes
-        wrap_policy = functools.partial(
-            transformer_auto_wrap_policy, transformer_layer_cls=blocks
-        )
+        wrap_policy = ModuleWrapPolicy(blocks)
 
     model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
     if use_checkpointing:
```
