[FSDP] ufmt flat_param.py, flatten_params_wrapper.py (pytorch#83664)
I think we can move the FSDP code to ufmt (https://ufmt.omnilib.dev/en/stable/) to unify formatting across developers. ufmt is the recommended formatter for PyTorch's Python code. If we have consensus, I can ufmt all of the FSDP code in follow-ups.
Pull Request resolved: pytorch#83664
Approved by: https://github.com/rohan-varma
awgu authored and pytorchmergebot committed Aug 31, 2022
1 parent 040263d commit 84ceebe
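
For context: ufmt layers the black code formatter on top of the usort import sorter, so a diff like this one can be reproduced locally with the tool's own CLI, e.g. `ufmt format torch/distributed/fsdp/flat_param.py` (and verified without rewriting via `ufmt check` or `ufmt diff`). The commit message does not record the exact command used, so treat that invocation as an assumption.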
Showing 2 changed files with 104 additions and 41 deletions.
139 changes: 100 additions & 39 deletions torch/distributed/fsdp/flat_param.py
@@ -19,13 +19,17 @@
 from torch import Tensor

 __all__ = [
-    "FlatParameter", "FlatParamHandle", "FlatParamShardMetadata",
-    "ParamInfo", "SharedParamInfo",
+    "FlatParameter",
+    "FlatParamHandle",
+    "FlatParamShardMetadata",
+    "ParamInfo",
+    "SharedParamInfo",
 ]


 class ParamInfo(NamedTuple):
     """Information for an original module parameter."""
+
     param_name: str  # unprefixed
     module: nn.Module
     module_name: str
@@ -40,6 +44,7 @@ class SharedParamInfo(NamedTuple):
     in the parameter walk. These are prefixed with "prim". The primary module
     and parameter do not have their own :class:`SharedParamInfo` instance.
     """
+
     param_name: str  # unprefixed
     module: nn.Module
     module_name: str
@@ -64,6 +69,7 @@ class FlatParamShardMetadata(NamedTuple):
             units of numels) giving this rank's part of each flattened
             original module parameter.
     """
+
     param_names: Tuple[str, ...]
     param_shapes: Tuple[torch.Size, ...]
     param_numels: Tuple[int, ...]
@@ -183,6 +189,7 @@ class FlatParamHandle:
             be the top-level module, while for recursive wrapping, this may not
             necessarily be the top-level module.
     """
+
     def __init__(
         self,
         params: Sequence[nn.Parameter],
@@ -212,8 +219,9 @@ def _init_flat_param(
         """
         params_set = set(params)
         params_set.discard(None)
-        assert len(params_set) > 0, \
-            "Cannot initialize a `FlatParameter` from an empty parameter list"
+        assert (
+            len(params_set) > 0
+        ), "Cannot initialize a `FlatParameter` from an empty parameter list"
         param_infos: List[ParamInfo] = []
         numels: List[int] = []
         shapes: List[torch.Size] = []
@@ -228,11 +236,19 @@ def _init_flat_param(
                 if param not in params_set:
                     continue
                 if param in shared_param_memo:
-                    prim_module, prim_module_name, prim_param_name = shared_param_memo[param]
-                    shared_param_infos.append(SharedParamInfo(
-                        param_name, submodule, submodule_name, prim_param_name,
-                        prim_module, prim_module_name,
-                    ))
+                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
+                        param
+                    ]
+                    shared_param_infos.append(
+                        SharedParamInfo(
+                            param_name,
+                            submodule,
+                            submodule_name,
+                            prim_param_name,
+                            prim_module,
+                            prim_module_name,
+                        )
+                    )
                 else:
                     if isinstance(param, FlatParameter):
                         raise ValueError("`FlatParameter` does not support nesting")
@@ -241,22 +257,36 @@ def _init_flat_param(
                             "`FlatParameter` requires uniform dtype but got "
                             f"{dtype} and {param.dtype}"
                         )
-                    if requires_grad is not None and param.requires_grad != requires_grad:
-                        raise ValueError("`FlatParameter` requires uniform `requires_grad`")
+                    if (
+                        requires_grad is not None
+                        and param.requires_grad != requires_grad
+                    ):
+                        raise ValueError(
+                            "`FlatParameter` requires uniform `requires_grad`"
+                        )
                     dtype = param.dtype
                     requires_grad = param.requires_grad
                     shared_param_memo[param] = (submodule, submodule_name, param_name)
                     params_to_flatten.append(param)
                     param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                     numels.append(param.numel())
                     shapes.append(param.shape)
-                    prefixed_param_name = submodule_name + "." + param_name \
-                        if submodule_name else param_name
+                    prefixed_param_name = (
+                        submodule_name + "." + param_name
+                        if submodule_name
+                        else param_name
+                    )
                     prefixed_param_names.append(prefixed_param_name)
         assert requires_grad is not None
-        self.flat_param = FlatParamHandle.flatten_params(params_to_flatten, requires_grad)
+        self.flat_param = FlatParamHandle.flatten_params(
+            params_to_flatten, requires_grad
+        )
         self.flat_param.init_metadata(
-            param_infos, numels, shapes, prefixed_param_names, shared_param_infos,
+            param_infos,
+            numels,
+            shapes,
+            prefixed_param_names,
+            shared_param_infos,
         )

     @staticmethod
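
Aside: the `_init_flat_param` logic reformatted above deduplicates shared (tied) parameters before flattening. A minimal, self-contained sketch of that idea — `flatten_unique` is a hypothetical helper, not the FSDP API:

import torch
import torch.nn as nn

def flatten_unique(params):
    # Flatten each parameter once, skipping repeat visits to shared/tied
    # parameters; set membership hashes tensors by identity, playing the
    # same role as the `shared_param_memo` dict above.
    seen, flats = set(), []
    for p in params:
        if p in seen:
            continue
        seen.add(p)
        flats.append(p.detach().reshape(-1))
    return torch.cat(flats, dim=0)

linear = nn.Linear(4, 4)
params = [linear.weight, linear.bias, linear.weight]  # weight listed twice
flat = flatten_unique(params)
assert flat.numel() == linear.weight.numel() + linear.bias.numel()  # 16 + 4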
Expand All @@ -275,8 +305,8 @@ def flatten_params(
"""
with torch.no_grad():
flat_params = [
p.detach().reshape(-1) if isinstance(p, nn.Parameter)
else p.reshape(-1) for p in params
p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1)
for p in params
]
flat_param_data = torch.cat(flat_params, dim=0)
flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad)
@@ -298,12 +328,15 @@ def _get_unflat_views(
         """
         if tensor is None:
             tensor = flat_param
-        assert tensor.numel() == flat_param._unsharded_size.numel(), \
-            f"Expects {flat_param._unsharded_size.numel()} numel but got " \
+        assert tensor.numel() == flat_param._unsharded_size.numel(), (
+            f"Expects {flat_param._unsharded_size.numel()} numel but got "
             f"{tensor.numel()} numel"
+        )
         views = (
-            subtensor.view(shape) for (subtensor, shape) in
-            zip(torch.split(tensor, flat_param._numels, dim=0), flat_param._shapes)  # type: ignore[arg-type]
+            subtensor.view(shape)
+            for (subtensor, shape) in zip(
+                torch.split(tensor, flat_param._numels, dim=0), flat_param._shapes  # type: ignore[arg-type]
+            )
         )
         return views

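Aside: the generator above restores per-parameter views from the flat tensor. The same mechanics with concrete numbers (a sketch, not `FlatParameter` internals):

import torch

flat = torch.arange(10.0)                 # stands in for a flattened parameter
numels, shapes = [6, 4], [(2, 3), (4,)]   # metadata recorded at flatten time
views = [
    subtensor.view(shape)
    for subtensor, shape in zip(torch.split(flat, numels, dim=0), shapes)
]
assert views[0].shape == (2, 3) and views[1].shape == (4,)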
@@ -327,7 +360,14 @@ def _unflatten(self, as_params: bool) -> None:
                 module.register_parameter(param_name, nn.Parameter(view))
             else:
                 setattr(module, param_name, view)
-        for (param_name, module, _, prim_param_name, prim_module, _) in self.flat_param._shared_param_infos:
+        for (
+            param_name,
+            module,
+            _,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in self.flat_param._shared_param_infos:
             if hasattr(module, param_name):
                 delattr(module, param_name)
             assert hasattr(prim_module, prim_param_name)
@@ -378,8 +418,11 @@ def init_shard_metadata(
         )
         start = sharded_flat_param_numel * rank
         end = sharded_flat_param_numel * (rank + 1) - 1  # inclusive
-        self.flat_param._shard_param_offsets, self.flat_param._shard_indices = (  # type: ignore[attr-defined]
-            self._get_shard_metadata(start, end)
-        )
+        (
+            self.flat_param._shard_param_offsets,  # type: ignore[attr-defined]
+            self.flat_param._shard_indices,  # type: ignore[attr-defined]
+        ) = self._get_shard_metadata(
+            start, end
+        )
         self.flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]

@@ -421,16 +464,21 @@ def _get_shard_metadata(
                 intra_param_start = start - param_start
             intra_param_end = min(param_end, end) - param_start
             shard_param_indices_range.append(i)
-            shard_param_offsets.append((intra_param_start, intra_param_end))  # both inclusive
+            shard_param_offsets.append(
+                (intra_param_start, intra_param_end)
+            )  # both inclusive
         if len(shard_param_indices_range) == 0:
             shard_param_indices = (0, 0)
             assert len(shard_param_offsets) == 0
         else:
             shard_param_indices = (
-                shard_param_indices_range[0], shard_param_indices_range[-1],
+                shard_param_indices_range[0],
+                shard_param_indices_range[-1],
             )
-            assert len(shard_param_offsets) == \
-                shard_param_indices[-1] - shard_param_indices[0] + 1
+            assert (
+                len(shard_param_offsets)
+                == shard_param_indices[-1] - shard_param_indices[0] + 1
+            )
         return tuple(shard_param_offsets), shard_param_indices

     @staticmethod
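
Aside: a worked example of the inclusive offset arithmetic in `_get_shard_metadata`, with hypothetical sizes — three flattened parameters of 4, 6, and 2 elements, and a rank owning flat indices 5..9:

numels = [4, 6, 2]   # params cover flat indices 0..3, 4..9, 10..11
start, end = 5, 9    # this rank's shard, both endpoints inclusive
offsets, indices = [], []
param_start = 0
for i, numel in enumerate(numels):
    param_end = param_start + numel - 1            # inclusive
    if param_start <= end and param_end >= start:  # param overlaps the shard
        indices.append(i)
        offsets.append(
            (max(start, param_start) - param_start,
             min(end, param_end) - param_start)
        )
    param_start += numel
assert indices == [1] and offsets == [(1, 5)]  # only the middle param overlaps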
@@ -455,7 +503,9 @@ def _get_unpadded_shard(
         else:
             chunk = chunks[rank]
         numel_to_pad = chunks[0].numel() - chunk.numel()
-        assert numel_to_pad >= 0, "Chunk's size should be at most the first chunk's size"
+        assert (
+            numel_to_pad >= 0
+        ), "Chunk's size should be at most the first chunk's size"
         return chunk, numel_to_pad

     @staticmethod
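
Aside: the chunk-and-pad arithmetic in `_get_unpadded_shard`, as a runnable miniature — 10 elements across a hypothetical world size of 3:

import torch
import torch.nn.functional as F

flat = torch.arange(10.0)
chunks = list(torch.chunk(flat, 3))               # sizes 4, 4, 2
rank = 2
chunk = chunks[rank]
numel_to_pad = chunks[0].numel() - chunk.numel()  # pad up to the first chunk
shard = F.pad(chunk.clone(), [0, numel_to_pad])
assert shard.numel() == 4 and numel_to_pad == 2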
@@ -471,7 +521,9 @@ def _get_shard(
         This method allocates new memory (via :meth:`clone`) since the
         unsharded ``tensor`` may be deallocated after this method returns.
         """
-        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(tensor, rank, world_size)
+        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+            tensor, rank, world_size
+        )
         shard = chunk.clone()
         if numel_to_pad > 0:
             shard = F.pad(shard, [0, numel_to_pad])
@@ -485,8 +537,8 @@ def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
         shape is 1D.
         """
         assert len(tensor.shape) == 1, f"{tensor.shape}"
-        unpadded_sharded_tensor, numel_to_pad = (
-            FlatParamHandle._get_unpadded_shard(tensor, rank, world_size)
+        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
+            tensor, rank, world_size
         )
         unpadded_sharded_size = unpadded_sharded_tensor.size()
         assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
Expand All @@ -506,13 +558,16 @@ def shard_metadata(
) -> FlatParamShardMetadata:
"""Returns shard-related metadata specific to this rank's shard of the
flattened parameter."""
assert hasattr(self.flat_param, "_shard_indices") and \
hasattr(self.flat_param, "_shard_param_offsets"), \
"Shard metadata has not been initialized"
assert hasattr(self.flat_param, "_shard_indices") and hasattr(
self.flat_param, "_shard_param_offsets"
), "Shard metadata has not been initialized"
shard_param_start_index = self.flat_param._shard_indices[0] # type: ignore[attr-defined]
shard_param_end_index = self.flat_param._shard_indices[1] # type: ignore[attr-defined]
sl = slice(shard_param_start_index, shard_param_end_index + 1) \
if shard_param_start_index <= shard_param_end_index else slice(0, 0)
sl = (
slice(shard_param_start_index, shard_param_end_index + 1)
if shard_param_start_index <= shard_param_end_index
else slice(0, 0)
)
return FlatParamShardMetadata(
self.flat_param._prefixed_param_names[sl],
self.flat_param._shapes[sl],
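
Aside: the conditional above turns the inclusive [start, end] pair from `_shard_indices` into a half-open Python slice, with start > end encoding an empty shard:

start_index, end_index = 2, 4  # inclusive bounds
sl = (
    slice(start_index, end_index + 1)
    if start_index <= end_index
    else slice(0, 0)  # start > end means this shard holds no parameters
)
assert list(range(6))[sl] == [2, 3, 4]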
@@ -534,8 +589,14 @@ def _get_modules(self) -> Set[nn.Module]:
     def parameter_module_names(self) -> Iterator[Tuple[str, str]]:
         shared_param_infos = [
             ParamInfo(param_name, module, module_name)
-            for (param_name, module, module_name, _, _, _)
-            in self.flat_param._shared_param_infos
+            for (
+                param_name,
+                module,
+                module_name,
+                _,
+                _,
+                _,
+            ) in self.flat_param._shared_param_infos
         ]
         for param_name, _, module_name in chain(
             self.flat_param._param_infos, shared_param_infos
6 changes: 4 additions & 2 deletions torch/distributed/fsdp/flatten_params_wrapper.py
@@ -77,6 +77,7 @@ class FlattenParamsWrapper(nn.Module):
         _flat_param_handle (FlatParamHandle): A handle for the flattened
             parameter; only present if this wrapper manages parameters.
     """
+
     def __init__(
         self,
         module: nn.Module,
@@ -106,9 +107,10 @@ def has_params(self) -> bool:

     @property
     def handle(self) -> FlatParamHandle:
-        assert hasattr(self, "_flat_param_handle"), \
-            "Accessing the handle of a `FlattenParamsWrapper` that does not " \
+        assert hasattr(self, "_flat_param_handle"), (
+            "Accessing the handle of a `FlattenParamsWrapper` that does not "
             "manage any parameters"
+        )
         return self._flat_param_handle

     @property
