Move FSDPOptimizerWrapper to utils (pytorch#483)
Summary:
Pull Request resolved: pytorch#483

As the title says, this class can be used outside the framework, so it is moved to utils alongside the other utilities for wrapping a module in FSDP.
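
In practice only the import path changes; a minimal sketch of the migration (illustrative, mirroring the diff below):

# before this commit (private, framework-internal name):
# from torchtnt.framework.utils import _FSDPOptimizerWrapper
# after this commit (public name, next to the other FSDP preparation utilities):
from torchtnt.utils.prepare_module import FSDPOptimizerWrapper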

bypass-github-export-checks

Reviewed By: JKSenthil

Differential Revision: D47946052

fbshipit-source-id: 963ef53f3ac8a680d60fdfae4762afba0c09117f
ananthsub authored and facebook-github-bot committed Aug 4, 2023
1 parent dcf37e3 commit a340a09
Showing 4 changed files with 81 additions and 88 deletions.
32 changes: 2 additions & 30 deletions tests/framework/test_utils.py
@@ -14,11 +14,6 @@
import torch.distributed as dist
from torch import nn

from torchtnt.utils.version import is_torch_version_geq_2_0

if is_torch_version_geq_2_0():
from torch.distributed._composable import fully_shard

from torch.distributed import launcher
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.lr_scheduler import ExponentialLR
@@ -30,10 +25,8 @@
from torchtnt.framework.utils import (
_construct_tracked_optimizers_and_schedulers,
_find_optimizers_for_module,
_FSDPOptimizerWrapper,
_is_done,
_is_epoch_done,
_is_fsdp_module,
_maybe_set_distributed_sampler_epoch,
_reset_module_training_mode,
_set_module_training_mode,
@@ -42,6 +35,7 @@
)
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import FSDPOptimizerWrapper
from torchtnt.utils.progress import Progress
from torchtnt.utils.test_utils import get_pet_launch_config
from torchtnt.utils.timer import Timer
@@ -50,28 +44,6 @@
class UtilsTest(unittest.TestCase):
    cuda_available = torch.cuda.is_available()

    @staticmethod
    def _test_is_fsdp_module() -> None:
        dist.init_process_group("gloo")
        model = nn.Linear(1, 1)
        assert not _is_fsdp_module(model)
        model = FSDP(nn.Linear(1, 1))
        assert _is_fsdp_module(model)
        if is_torch_version_geq_2_0():
            fully_shard(model)
            assert _is_fsdp_module(model)

    @unittest.skipUnless(
        dist.is_available(), reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=torch.cuda.is_available() and torch.cuda.device_count() > 2,
        reason="This test needs 2 GPUs to run.",
    )
    def test_is_fsdp_module(self) -> None:
        config = get_pet_launch_config(2)
        dist.launcher.elastic_launch(config, entrypoint=self._test_is_fsdp_module)()

    def test_maybe_set_distributed_sampler_epoch(self) -> None:
        config = get_pet_launch_config(3)
        result = dist.launcher.elastic_launch(
@@ -277,7 +249,7 @@ def _construct_optimizers() -> None:

result = _construct_tracked_optimizers_and_schedulers(auto_unit)
tc = unittest.TestCase()
tc.assertTrue(isinstance(result["optim"], _FSDPOptimizerWrapper))
tc.assertTrue(isinstance(result["optim"], FSDPOptimizerWrapper))
tc.assertTrue(isinstance(result["optim2"], torch.optim.Optimizer))
tc.assertTrue(isinstance(result["lr_scheduler"], TLRScheduler))

31 changes: 30 additions & 1 deletion tests/utils/test_prepare_module.py
@@ -14,12 +14,17 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import (
_is_fsdp_module,
DDPStrategy,
FSDPStrategy,
prepare_ddp,
prepare_fsdp,
)
from torchtnt.utils.test_utils import get_pet_launch_config
from torchtnt.utils.version import is_torch_version_geq_2_0

if is_torch_version_geq_2_0():
from torch.distributed._composable import fully_shard


class PrepareModelTest(unittest.TestCase):
@@ -86,9 +91,33 @@ def _test_fsdp_pytorch_version() -> None:

tc = unittest.TestCase()
with patch(
"torchtnt.utils.prepare_model.is_torch_version_geq_1_12", return_value=False
"torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
return_value=False,
), tc.assertRaisesRegex(
RuntimeError,
"Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
):
_ = prepare_fsdp(module, device, FSDPStrategy())

    @staticmethod
    def _test_is_fsdp_module() -> None:
        torch.distributed.init_process_group("gloo")
        model = torch.nn.Linear(1, 1)
        assert not _is_fsdp_module(model)
        model = FSDP(torch.nn.Linear(1, 1))
        assert _is_fsdp_module(model)
        model = torch.nn.Linear(1, 1)
        if is_torch_version_geq_2_0():
            fully_shard(model)
            assert _is_fsdp_module(model)

    @unittest.skipUnless(
        torch.distributed.is_available(), reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=torch.cuda.is_available() and torch.cuda.device_count() > 2,
        reason="This test needs 2 GPUs to run.",
    )
    def test_is_fsdp_module(self) -> None:
        config = get_pet_launch_config(2)
        launcher.elastic_launch(config, entrypoint=self._test_is_fsdp_module)()
63 changes: 8 additions & 55 deletions torchtnt/framework/utils.py
@@ -14,23 +14,13 @@
import torch
import torch.nn as nn
import typing_extensions
from torch.distributed.fsdp import (
FullyShardedDataParallel,
FullyShardedDataParallel as FSDP,
)
from torch.profiler import record_function
from torchtnt.utils.version import is_torch_version_geq_2_0

if is_torch_version_geq_2_0():
from torch.distributed._composable_state import _get_module_state
from torch.distributed.fsdp._common_utils import _FSDPState

from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
from torchtnt.utils.progress import Progress


_logger: logging.Logger = logging.getLogger(__name__)


@@ -132,64 +122,27 @@ def _step_requires_iterator(step_func: Callable[[State, object], object]) -> bool:
    return typing_extensions.get_origin(annotated_type) is collections.abc.Iterator


def _is_fsdp_module(module: torch.nn.Module) -> bool:
    if isinstance(module, FSDP):
        return True

    if is_torch_version_geq_2_0():
        # Also check for composable FSDP API
        maybe_composable_state = _get_module_state(module)
        if maybe_composable_state is not None:
            return isinstance(maybe_composable_state, _FSDPState)

    return False


class _FSDPOptimizerWrapper:
    """
    Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
    """

    def __init__(
        self, module: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> None:
        self.module = module
        self.optimizer = optimizer

    def state_dict(self) -> Dict[str, Any]:
        optim_state_dict = FullyShardedDataParallel.optim_state_dict(
            self.module, self.optimizer
        )
        return optim_state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        optim_state_dict = FullyShardedDataParallel.optim_state_dict_to_load(
            self.module, self.optimizer, state_dict
        )
        self.optimizer.load_state_dict(optim_state_dict)


def _construct_tracked_optimizers(
unit: AppStateMixin,
) -> Dict[str, Union[torch.optim.Optimizer, _FSDPOptimizerWrapper]]:
) -> Dict[str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper]]:
"""
Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in _FSDPOptimizerWrapper.
Constructs tracked optimizers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
"""
fsdp_tracked_optimizers: Dict[str, _FSDPOptimizerWrapper] = {}
fsdp_tracked_optimizers: Dict[str, FSDPOptimizerWrapper] = {}
for module in unit.tracked_modules().values():
if _is_fsdp_module(module):
# find optimizers for module, if exists
optimizer_list = _find_optimizers_for_module(
module, unit.tracked_optimizers()
)
for optim_name, optimizer in optimizer_list:
fsdp_tracked_optimizers[optim_name] = _FSDPOptimizerWrapper(
fsdp_tracked_optimizers[optim_name] = FSDPOptimizerWrapper(
module, optimizer
)

# construct custom tracked optimizers with FSDP optimizers
tracked_optimizers: Dict[
str, Union[torch.optim.Optimizer, _FSDPOptimizerWrapper]
str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper]
] = {
key: value
for key, value in unit.tracked_optimizers().items()
Expand All @@ -201,9 +154,9 @@ def _construct_tracked_optimizers(

def _construct_tracked_optimizers_and_schedulers(
unit: AppStateMixin,
) -> Dict[str, Union[torch.optim.Optimizer, _FSDPOptimizerWrapper, TLRScheduler]]:
) -> Dict[str, Union[torch.optim.Optimizer, FSDPOptimizerWrapper, TLRScheduler]]:
"""
Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in _FSDPOptimizerWrapper.
Combines tracked optimizers and schedulers. Handles optimizers working on FSDP modules, wrapping them in FSDPOptimizerWrapper.
"""
# construct custom tracked optimizers with FSDP optimizers
tracked_optimizers_and_schedulers = _construct_tracked_optimizers(unit)
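For context (not part of the diff): a hedged sketch of how the renamed wrapper flows through the private helper above. The function name gather_checkpointable_optimizers is hypothetical, and the unit is any AppStateMixin (e.g. an AutoUnit) with tracked modules and optimizers:

import torch

from torchtnt.framework.unit import AppStateMixin
from torchtnt.framework.utils import _construct_tracked_optimizers_and_schedulers
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import FSDPOptimizerWrapper


def gather_checkpointable_optimizers(unit: AppStateMixin) -> dict:
    # Optimizers attached to FSDP modules come back wrapped in FSDPOptimizerWrapper,
    # so their state_dict()/load_state_dict() calls route through FSDP's
    # optimizer-state APIs; plain optimizers and LR schedulers pass through unchanged.
    tracked = _construct_tracked_optimizers_and_schedulers(unit)
    for entry in tracked.values():
        assert isinstance(entry, (FSDPOptimizerWrapper, torch.optim.Optimizer, TLRScheduler))
    return tracked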
43 changes: 41 additions & 2 deletions torchtnt/utils/prepare_module.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, dataclass
from typing import Callable, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Union

import torch
import torch.distributed as dist
@@ -20,7 +20,11 @@
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.rank_zero_log import rank_zero_warn
from torchtnt.utils.version import is_torch_version_geq_1_12
from torchtnt.utils.version import is_torch_version_geq_1_12, is_torch_version_geq_2_0

if is_torch_version_geq_2_0():
from torch.distributed._composable_state import _get_module_state
from torch.distributed.fsdp._common_utils import _FSDPState


@dataclass
@@ -181,3 +185,38 @@ def prepare_fsdp(
module, state_dict_type, state_dict_config, optim_state_dict_config
)
return module


class FSDPOptimizerWrapper:
    """
    Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs.
    """

    def __init__(
        self, module: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> None:
        self.module = module
        self.optimizer = optimizer

    def state_dict(self) -> Dict[str, Any]:
        optim_state_dict = FSDP.optim_state_dict(self.module, self.optimizer)
        return optim_state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        optim_state_dict = FSDP.optim_state_dict_to_load(
            self.module, self.optimizer, state_dict
        )
        self.optimizer.load_state_dict(optim_state_dict)


def _is_fsdp_module(module: torch.nn.Module) -> bool:
    if isinstance(module, FSDP):
        return True

    if is_torch_version_geq_2_0():
        # Also check for composable FSDP API
        maybe_composable_state = _get_module_state(module)
        if maybe_composable_state is not None:
            return isinstance(maybe_composable_state, _FSDPState)

    return False

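For reference (not part of the commit), a minimal sketch of using the relocated helpers together. It assumes a process group has already been initialized (the tests above use a gloo backend) and uses an illustrative one-layer model:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper

# assumes torch.distributed.init_process_group(...) has already been called
module = FSDP(torch.nn.Linear(1, 1))  # illustrative model
assert _is_fsdp_module(module)  # also detects the composable fully_shard API on torch >= 2.0

optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
wrapper = FSDPOptimizerWrapper(module, optimizer)

sharded_osd = wrapper.state_dict()  # routes through FSDP.optim_state_dict
# ... persist / transfer sharded_osd as needed ...
wrapper.load_state_dict(sharded_osd)  # routes through FSDP.optim_state_dict_to_load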