Fuse input dist splits all2all (pytorch#1111)
Summary:
Pull Request resolved: pytorch#1111

Fuses the splits all-to-all part of the input dist across pipelined modules. We do so by overriding the KJTAllToAll forward after the input dists are initialized, so fusion only takes effect after the first batch executes.

Reviewed By: bigning

Differential Revision: D44850123

fbshipit-source-id: d11709bdb7e42c9bb0a53e1852cd953ccf14485d
joshuadeng authored and facebook-github-bot committed Apr 13, 2023
1 parent 9969cd1 commit a31aa97
Showing 3 changed files with 229 additions and 37 deletions.
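
To make the summary concrete: instead of each pipelined module's input dist issuing its own all-to-all to exchange KJT splits, the splits tensors from every module that shares a process group are stacked and exchanged in a single collective, then handed back out per module. Below is a minimal, self-contained sketch of that fusion idea; it is not part of this commit and not the torchrec API. The function name, argument shapes, and layout assumptions are hypothetical, and it assumes torch.distributed is already initialized.

# Hypothetical sketch only: one fused collective for all modules' splits tensors.
from typing import List

import torch
import torch.distributed as dist


def fused_splits_all_to_all(
    per_module_splits: List[List[torch.Tensor]],  # each inner tensor has shape [world_size]
    pg: dist.ProcessGroup,
) -> List[List[torch.Tensor]]:
    world_size = dist.get_world_size(pg)
    # Flatten every module's splits tensors into one input for a single collective.
    flat: List[torch.Tensor] = [t for module in per_module_splits for t in module]
    # stack(dim=1) groups the values destined for each rank together, so the
    # flattened tensor is laid out as [rank 0's len(flat) values, rank 1's, ...].
    input_tensor = torch.stack(flat, dim=1).flatten()
    output_tensor = torch.empty(
        world_size * len(flat), device=input_tensor.device, dtype=input_tensor.dtype
    )
    # One all-to-all for all modules instead of one per module.
    dist.all_to_all_single(output_tensor, input_tensor, group=pg)
    # Row r now holds the len(flat) values received from rank r.
    received = output_tensor.view(world_size, len(flat))
    # Un-flatten back into per-module groups of columns.
    outputs: List[List[torch.Tensor]] = []
    offset = 0
    for module in per_module_splits:
        outputs.append([received[:, offset + i] for i in range(len(module))])
        offset += len(module)
    return outputs
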
15 changes: 4 additions & 11 deletions torchrec/distributed/dist_data.py
@@ -142,17 +142,15 @@ class SplitsAllToAllAwaitable(Awaitable[List[List[int]]]):
def __init__(
self,
input_tensors: List[torch.Tensor],
num_workers: int,
device: torch.device,
pg: dist.ProcessGroup,
) -> None:
super().__init__()
self.num_workers: int = num_workers
self.num_workers: int = pg.size()

with record_function("## all2all_data:kjt splits ##"):
self._output_tensor: torch.Tensor = torch.empty(
[num_workers * len(input_tensors)],
device=device,
[self.num_workers * len(input_tensors)],
device=input_tensors[0].device,
dtype=input_tensors[0].dtype,
)
input_tensor = torch.stack(input_tensors, dim=1).flatten()
@@ -331,9 +329,7 @@ def __init__(
)
input_tensors.append(batch_size_tensor)

self._splits_awaitable = SplitsAllToAllAwaitable(
input_tensors, self._workers, self._device, self._pg
)
self._splits_awaitable = SplitsAllToAllAwaitable(input_tensors, self._pg)

def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
"""
@@ -430,9 +426,7 @@ def __init__(
super().__init__()
assert len(splits) == pg.size()
self._pg: dist.ProcessGroup = pg
self._workers: int = pg.size()
self._splits = splits
self._no_dist: bool = all(s == 0 for s in splits)
self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
self._stagger = stagger

@@ -453,7 +447,6 @@ def forward(
"""

device = input.values().device

with torch.no_grad():
assert len(input.keys()) == sum(self._splits)
rank = dist.get_rank(self._pg)
140 changes: 122 additions & 18 deletions torchrec/distributed/embedding_sharding.py
@@ -7,11 +7,14 @@

import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import torch
from torch import nn
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torch import distributed as dist, nn
from torchrec.distributed.dist_data import (
KJTAllToAllTensorsAwaitable,
SplitsAllToAllAwaitable,
)
from torchrec.distributed.embedding_types import (
BaseEmbeddingLookup,
BaseGroupedFeatureProcessor,
@@ -233,6 +236,32 @@ def _wait_impl(self) -> KJTList:


C = TypeVar("C", bound=Multistreamable)
T = TypeVar("T")


def _set_sharding_context(
tensors_awaitables: List[Awaitable[KeyedJaggedTensor]],
ctx: C,
) -> None:
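"""
Propagates all2all metadata from each KJTAllToAllTensorsAwaitable (batch size per
rank, input/output value splits, recat tensor) onto the corresponding sharding
context, when the context defines those fields.
"""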
for awaitable, sharding_context in zip(
tensors_awaitables,
getattr(ctx, "sharding_contexts", []),
):
if isinstance(awaitable, KJTAllToAllTensorsAwaitable):
if hasattr(sharding_context, "batch_size_per_rank"):
sharding_context.batch_size_per_rank = awaitable._batch_size_per_rank
if hasattr(sharding_context, "input_splits"):
sharding_context.input_splits = awaitable._input_splits["values"]
if hasattr(sharding_context, "output_splits"):
sharding_context.output_splits = awaitable._output_splits["values"]
if hasattr(sharding_context, "sparse_features_recat"):
sharding_context.sparse_features_recat = awaitable._recat


def _split(flat_list: List[T], splits: List[int]) -> List[List[T]]:
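"""Splits a flat list into consecutive chunks of the given sizes, e.g. _split([a, b, c, d], [1, 3]) == [[a], [b, c, d]]."""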
return [
flat_list[sum(splits[:i]) : sum(splits[:i]) + n] for i, n in enumerate(splits)
]


class KJTListSplitsAwaitable(Awaitable[Awaitable[KJTList]], Generic[C]):
@@ -267,24 +296,99 @@ def _wait_impl(self) -> KJTListAwaitable:
KJTListAwaitable: awaitables for tensors of the sparse features.
"""
tensors_awaitables = [w.wait() for w in self.awaitables]
for awaitable, sharding_context in zip(
tensors_awaitables,
getattr(self.ctx, "sharding_contexts", []),
):
if isinstance(awaitable, KJTAllToAllTensorsAwaitable):
if hasattr(sharding_context, "batch_size_per_rank"):
sharding_context.batch_size_per_rank = (
awaitable._batch_size_per_rank
)
if hasattr(sharding_context, "input_splits"):
sharding_context.input_splits = awaitable._input_splits["values"]
if hasattr(sharding_context, "output_splits"):
sharding_context.output_splits = awaitable._output_splits["values"]
if hasattr(sharding_context, "sparse_features_recat"):
sharding_context.sparse_features_recat = awaitable._recat
_set_sharding_context(tensors_awaitables, self.ctx)
return KJTListAwaitable(tensors_awaitables)


@dataclass
class KJTSplitsAllToAllMeta:
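"""
Returned by the overridden KJTAllToAll forward in place of launching the splits
all2all itself; carries everything needed to run the fused splits collective and
then construct the KJTAllToAllTensorsAwaitable.
"""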
pg: dist.ProcessGroup
input: KeyedJaggedTensor
splits: List[int]
splits_tensors: List[torch.Tensor]
input_splits: List[List[int]]
input_tensors: List[torch.Tensor]
labels: List[str]
keys: List[str]
device: torch.device
stagger: int


class FusedKJTListSplitsAwaitable(Awaitable[List[KJTListAwaitable]]):
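"""
Fuses the splits all2all of several input dist requests into one
SplitsAllToAllAwaitable; on wait(), the exchanged splits are handed back to each
request and one KJTListAwaitable is returned per request.
"""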
def __init__(
self,
requests: List[KJTListSplitsAwaitable[C]],
contexts: List[C],
pg: Optional[dist.ProcessGroup],
) -> None:
super().__init__()
self._contexts = contexts
self._awaitables: List[
Union[KJTSplitsAllToAllMeta, Awaitable[Awaitable[KeyedJaggedTensor]]]
] = [awaitable for request in requests for awaitable in request.awaitables]
self._output_lengths: List[int] = [
len(request.awaitables) for request in requests
]
self._lengths: List[int] = [
len(awaitable.splits_tensors)
if isinstance(awaitable, KJTSplitsAllToAllMeta)
else 0
for awaitable in self._awaitables
]
splits_tensors = [
splits_tensor
for awaitable in self._awaitables
for splits_tensor in (
awaitable.splits_tensors
if isinstance(awaitable, KJTSplitsAllToAllMeta)
else []
)
]
self._splits_awaitable: Optional[SplitsAllToAllAwaitable] = (
SplitsAllToAllAwaitable(
input_tensors=splits_tensors,
pg=pg,
)
if splits_tensors and pg
else None
)

def _wait_impl(self) -> List[KJTListAwaitable]:
if self._splits_awaitable:
splits_list = self._splits_awaitable.wait()
splits_per_awaitable = _split(splits_list, self._lengths)
else:
splits_per_awaitable = [[] for _ in range(len(self._lengths))]
tensors_awaitables = []
for splits, awaitable in zip(splits_per_awaitable, self._awaitables):
if not splits: # NoWait
tensors_awaitables.append(awaitable.wait())
continue
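# The last splits tensor is the batch size tensor appended by KJTAllToAllForward;
# the rest are the output splits for each dist label.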
output_splits = splits[:-1]
batch_size_per_rank = splits[-1]
tensors_awaitables.append(
KJTAllToAllTensorsAwaitable(
pg=awaitable.pg,
input=awaitable.input,
splits=awaitable.splits,
input_splits=awaitable.input_splits,
output_splits=output_splits,
input_tensors=awaitable.input_tensors,
labels=awaitable.labels,
batch_size_per_rank=batch_size_per_rank,
keys=awaitable.keys,
device=awaitable.device,
stagger=awaitable.stagger,
)
)
output = []
awaitables_per_output = _split(tensors_awaitables, self._output_lengths)
for awaitables, ctx in zip(awaitables_per_output, self._contexts):
_set_sharding_context(awaitables, ctx)
output.append(KJTListAwaitable(awaitables))
return output


class ListOfKJTListAwaitable(Awaitable[ListOfKJTList]):
"""
This module handles the tables-wise sharding input features distribution for
111 changes: 103 additions & 8 deletions torchrec/distributed/train_pipeline.py
@@ -6,7 +6,9 @@
# LICENSE file in the root directory of this source tree.

import abc
import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Any,
@@ -22,12 +24,20 @@
)

import torch
from torch import distributed as dist
from torch.autograd.profiler import record_function
from torch.cuda import Event
from torch.fx.node import Node
from torchrec.distributed.dist_data import KJTAllToAll
from torchrec.distributed.embedding_sharding import (
FusedKJTListSplitsAwaitable,
KJTListSplitsAwaitable,
KJTSplitsAllToAllMeta,
)
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
from torchrec.distributed.types import Awaitable
from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable, Pipelineable

logger: logging.Logger = logging.getLogger(__name__)
@@ -249,6 +259,45 @@ def args(self) -> List[ArgInfo]:
return self._args


class KJTAllToAllForward:
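"""
Drop-in replacement for KJTAllToAll.forward that performs only the local
preprocessing (splits tensors, batch size tensor, dist tensors) and returns a
KJTSplitsAllToAllMeta, deferring the splits all2all so it can be fused across
modules.
"""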
def __init__(
self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1
) -> None:
self._pg = pg
self._splits = splits
self._stagger = stagger
self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))

def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta:
with torch.no_grad():
assert len(input.keys()) == sum(self._splits)
rank = dist.get_rank(self._pg)
local_keys = input.keys()[
self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
]
input_splits = input.dist_splits(self._splits)
device = input.values().device
splits_tensors = [
torch.tensor(split, device=device) for split in input_splits
]
batch_size_tensor = torch.tensor(
[input.stride()] * self._pg.size(), device=device
)
splits_tensors.append(batch_size_tensor)
return KJTSplitsAllToAllMeta(
pg=self._pg,
input=input,
splits=self._splits,
splits_tensors=splits_tensors,
input_splits=input_splits,
input_tensors=input.dist_tensors(),
labels=input.dist_labels(),
keys=local_keys,
device=device,
stagger=self._stagger,
)


def _start_data_dist(
pipelined_modules: List[ShardedModule],
batch: In,
@@ -287,10 +336,29 @@ def _start_data_dist(
context.input_dist_requests[forward.name] = module.input_dist(
module_ctx, *args, **kwargs
)

# Call wait on the first awaitable in the input dist for the tensor splits
for key, awaitable in context.input_dist_requests.items():
context.input_dist_requests[key] = awaitable.wait()
_fuse_input_dist_splits(context)


def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
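"""
Groups pending input dist requests by process group and issues one fused splits
all2all per group, replacing each request with the resulting awaitable.
"""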
names_per_pg = defaultdict(list)
for name, request in context.input_dist_requests.items():
pg = None
if isinstance(request, KJTListSplitsAwaitable):
for awaitable in request.awaitables:
if isinstance(awaitable, KJTSplitsAllToAllMeta):
pg = awaitable.pg
break
names_per_pg[pg].append(name)

for pg, names in names_per_pg.items():
requests = FusedKJTListSplitsAwaitable(
# pyre-ignore[6]
requests=[context.input_dist_requests[name] for name in names],
contexts=[context.module_contexts[name] for name in names],
pg=pg,
).wait()
for name, request in zip(names, requests):
context.input_dist_requests[name] = request


def _get_node_args_helper(
@@ -464,6 +532,24 @@ def _rewrite_model( # noqa C901
return ret


def _override_input_dist_forwards(pipelined_modules: List[ShardedModule]) -> None:
"""
Overrides each input dist forward to support fusing the splits collective.
NOTE: this can only be called after the input dists are initialized.
"""
for module in pipelined_modules:
assert not module._has_uninitialized_input_dist
# pyre-ignore[29]
for input_dist in module._input_dists:
if hasattr(input_dist, "_dist"):
assert isinstance(input_dist._dist, KJTAllToAll)
input_dist._dist.forward = KJTAllToAllForward(
pg=input_dist._dist._pg,
splits=input_dist._dist._splits,
stagger=input_dist._dist._stagger,
)


class TrainPipelineSparseDist(TrainPipeline[In, Out]):
"""
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with
@@ -525,12 +611,12 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# executes last batch in pipeline
if self._batch_i and self._execute_all_batches:
return
self._init_pipelined_modules()

# batch 1
self._batch_i = self._copy_batch_to_gpu(dataloader_iter)
if self._batch_i is None:
raise StopIteration
self._init_pipelined_modules(self._batch_i)
self._sparse_data_dist(self._batch_i)

# batch 2
@@ -569,19 +655,28 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

return output

def _init_pipelined_modules(self) -> None:
def _init_pipelined_modules(self, batch: In) -> None:
"""
Retrieves the pipelined modules after overriding their forwards, initializes the
modules' input dists, and overrides the input dist forwards to support fusing
the splits collective in the input dist.
"""
if self._pipelined_modules:
return
self._pipelined_modules = _rewrite_model(
self._model, self._context, self._data_dist_stream
)
# initializes input dist, so we can override input dist forwards
self._sparse_data_dist(self._batch_i)
_override_input_dist_forwards(self._pipelined_modules)

def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
"""
Retrieves batch from dataloader and moves it to the provided device.
Raises StopIteration when dataloader iterator is exhausted; unless
execute_all_batches=True, then returns None.
Raises:
StopIteration: if the dataloader iterator is exhausted; unless
`self._execute_all_batches=True`, then returns None.
"""
with record_function("## copy_batch_to_gpu ##"):
with torch.cuda.stream(self._memcpy_stream):
