Implement shape ops (pytorch#324)
Implements reshape / view / transpose / permute / expand and other ops that only change the shape and layout of the tensor without changing the elements' values.

This adds support for multiple mesh dimensions; however, the operation will not succeed if it would trigger communication. That is the case when flattening a tensor that is sharded on a flattened dimension other than the first. We will need to tackle that in a subsequent change.
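As a rough illustration of the call pattern these ops cover, here is a minimal sketch. It assumes the DTensor.from_local and Shard APIs visible in the diff below; the device mesh construction is environment-specific and left as a parameter, and the function name is hypothetical.

import torch
from spmd.tensor.api import DTensor
from spmd.tensor.placement_types import Shard

def shape_ops_sketch(local_shard: torch.Tensor, mesh) -> None:
    # Illustrative only: wrap a 2-D local shard as a DTensor sharded on dim 0,
    # then apply shape-only ops. Each call should return a new DTensor whose
    # Shard placement is remapped to the matching output dimension, without
    # triggering communication.
    dt = DTensor.from_local(local_shard, mesh, [Shard(0)])
    dt.transpose(0, 1)    # sharding should move from dim 0 to dim 1
    dt.permute(1, 0)      # the same remapping, written as a permutation
    dt.view(dt.size())    # an identity view keeps Shard(0) unchanged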
aazzolini authored Aug 22, 2022
1 parent 415168a commit 18a676b
Showing 6 changed files with 1,114 additions and 91 deletions.
17 changes: 17 additions & 0 deletions spmd/tensor/dispatch.py
@@ -93,13 +93,30 @@ class OutputSharding:
failed_reason: Optional[str] = None


def _reshape_alias(
x: torch.Tensor, shape: Tuple[int, ...], strides: Tuple[int, ...]
) -> torch.Tensor:
return torch.ops.aten.view(x, shape)


_CURRENT_DECOMPOSITION_TABLE: Dict[
Callable[..., object], Callable[..., object]
] = {
torch.ops.aten._reshape_alias.default: _reshape_alias,
}


def operator_dispatch(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]],
custom_dispatch_ops: Dict[str, Callable[..., object]],
) -> object:
# first we need to lift some private aten aliases to public calls
if op_call in _CURRENT_DECOMPOSITION_TABLE:
with torch.overrides.enable_reentrant_dispatch():
return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs)

func_schema = FunctionSchema.parse(str(op_call._schema))
schema_kind = func_schema.kind()
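A quick sanity check of the decomposition above, illustrative and not part of the commit: on a contiguous tensor the private aten._reshape_alias op and the public aten.view op agree, which is why the stride argument can be ignored when lowering one to the other.

import torch

x = torch.arange(12.0)
# For a contiguous (3, 4) result the strides are (4, 1), and _reshape_alias
# produces the same values as a plain view.
assert torch.equal(
    torch.ops.aten.view(x, (3, 4)),
    torch.ops.aten._reshape_alias(x, (3, 4), (4, 1)),
)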
1 change: 1 addition & 0 deletions spmd/tensor/ops/__init__.py
@@ -5,3 +5,4 @@
from .dropout import * # noqa: F403
from .tp_sharding_ops import * # noqa: F403
from .pointwise_ops import * # noqa: F403
from .view_ops import * # noqa: F403
90 changes: 0 additions & 90 deletions spmd/tensor/ops/tp_sharding_ops.py
@@ -1,12 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
import math

import torch
import torch.utils._pytree as pytree
from typing import List
from spmd.tensor.api import DTensor
from spmd.tensor.placement_types import Shard
from spmd.tensor.utils import unwrap_local_tensor
from spmd.tensor.ops.utils import unwrap_single_placement, register_impl

@@ -18,75 +15,6 @@
"""


@register_impl("aten.view.SymInt")
@register_impl("aten.view.default")
# pyre-fixme[2]: Parameter must be annotated.
def dist_view(self: DTensor, *shape) -> DTensor:
mat_placement = pytree.tree_map(unwrap_single_placement, self)
local_mat = pytree.tree_map(unwrap_local_tensor, self)
if mat_placement.is_replicate():
return DTensor.from_local(
local_mat.view(*shape), self.device_mesh, [mat_placement]
)

elif mat_placement.is_shard():
shape = shape[0]
try:
infer_idx = shape.index(-1)
except ValueError:
infer_idx = None # type: ignore

# Infer the dim which is specified with -1.
if infer_idx is not None:
st_size = math.prod(self.size()) # type: ignore[attr-defined]
shape_size = -1 * math.prod(shape) # type: ignore[attr-defined]
# pyre-fixme[60]: Concatenation not yet support for multiple variadic
shape = (
*shape[:infer_idx],
st_size // shape_size,
*shape[infer_idx + 1 :],
)
if self.size() == shape:
return self

sharding_dim = mat_placement.dim
# When the sharding dim is negative, we need to ensure the new
# sharded tensor is still sharded by the original dimension.
if sharding_dim < 0:
sharding_dim = self.dim() + sharding_dim

world_size = self.device_mesh.size(dim=0)
if shape[sharding_dim] % world_size:
raise NotImplementedError(
f"Case when dim '({shape[sharding_dim]})' is not divisible "
"by world_size is not supported."
)
# pyre-fixme[60]: Concatenation not yet support for multiple variadic
new_local_tensor_size = (
*shape[:sharding_dim],
shape[sharding_dim] // world_size,
*shape[sharding_dim + 1 :],
)
new_local_tensor = local_mat.view(*new_local_tensor_size)
return DTensor(new_local_tensor, self.device_mesh, self.placements)
else:
raise RuntimeError("not supported!")


@register_impl("aten.transpose.int")
def dist_transpose(self: DTensor, dim0: int, dim1: int) -> DTensor:
local_mat = pytree.tree_map(unwrap_local_tensor, self)
mat_placement = pytree.tree_map(unwrap_single_placement, self)
device_mesh = self.device_mesh
new_shard_dim = (
dim1 if mat_placement.is_shard(dim=dim0) else mat_placement.dim
)
new_shard_dim = dim0 if mat_placement.is_shard(dim=dim1) else new_shard_dim
new_sharding_placement = [Shard(new_shard_dim)]
local_tensor = local_mat.transpose(dim0, dim1)
return DTensor(local_tensor, device_mesh, new_sharding_placement)


@register_impl("aten.baddbmm.default")
def dist_baddbmm(
self: DTensor,
@@ -118,24 +46,6 @@ def dist_softmax(self: DTensor, dim: int, half_to_float: bool) -> DTensor:
return DTensor(local_tensor, self.device_mesh, self.placements)


@register_impl("aten.permute.default")
def dist_permute(self: DTensor, dims: List[int]) -> DTensor:
local_mat = pytree.tree_map(unwrap_local_tensor, self)
mat_placement = pytree.tree_map(unwrap_single_placement, self)

if mat_placement.is_replicate():
local_tensor = torch.ops.aten.permute(local_mat, dims=dims)
return DTensor(local_tensor, self.device_mesh, [mat_placement])
elif mat_placement.is_shard():
sharding_dim = mat_placement.dim
new_sharding_dim = dims.index(sharding_dim)
new_sharding_placement = [Shard(new_sharding_dim)]
local_tensor = torch.ops.aten.permute(local_mat, dims=dims)
return DTensor(local_tensor, self.device_mesh, new_sharding_placement)
else:
raise RuntimeError("Not supported!")


@register_impl("aten.cat.default")
def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
local_inputs = pytree.tree_map(unwrap_local_tensor, tensor_list)
(Diffs for the remaining changed files were not loaded.)
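The removed per-op implementations above hard-code the placement remapping rule that the new view_ops propagation presumably generalizes: a Shard(d) placement follows dimension d to its position in the output. A standalone sketch of just that rule (illustrative helper names, not the new implementation):

from typing import List

def remap_shard_dim_for_permute(shard_dim: int, dims: List[int]) -> int:
    # After output = input.permute(dims), input dimension shard_dim
    # shows up as output dimension dims.index(shard_dim).
    return dims.index(shard_dim)

def remap_shard_dim_for_transpose(shard_dim: int, dim0: int, dim1: int) -> int:
    # transpose is the permute that swaps dim0 and dim1.
    if shard_dim == dim0:
        return dim1
    if shard_dim == dim1:
        return dim0
    return shard_dim

assert remap_shard_dim_for_permute(1, [2, 0, 1]) == 2
assert remap_shard_dim_for_transpose(0, 0, 1) == 1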
