[SPMD] Support SymInt with non-op call_function nodes (pytorch#99420)
Pull Request resolved: pytorch#99420
Approved by: https://github.com/fegin
mrshenli authored and pytorchmergebot committed Apr 19, 2023
1 parent 7c0c663 commit 2922961
Showing 2 changed files with 100 additions and 53 deletions.
13 changes: 13 additions & 0 deletions test/distributed/_spmd/test_tracing.py
@@ -948,6 +948,19 @@ def forward(self, x):

self._test_op_with_train_step(Model)

@skip_if_lt_x_gpu(2)
@with_comms
def test_arithmetic_ops_on_symint(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)

def forward(self, x):
return self.fc(x) + x.shape[0] * x.numel() - x.shape[0] // 2

self._test_op_with_train_step(Model)


if __name__ == "__main__":
run_tests()
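The new test exercises the case this commit targets: under symbolic tracing, SymInt arithmetic such as x.shape[0] * x.numel() appears in the graph as call_function nodes whose targets are plain Python operators (operator.add, operator.sub, operator.mul, operator.floordiv) rather than aten ops. A rough sketch of how such nodes can be observed outside of SPMD expansion, assuming make_fx with tracing_mode="symbolic" is available (the helper fn is hypothetical and only mirrors the test's forward):

import operator

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    # mirrors the new test's forward: a tensor output combined with SymInt arithmetic
    return x + x.shape[0] * x.numel() - x.shape[0] // 2


gm = make_fx(fn, tracing_mode="symbolic")(torch.randn(4, 10))
sym_arith = (operator.add, operator.sub, operator.mul, operator.floordiv)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target in sym_arith:
        # these are the non-op call_function nodes that SPMD expansion now handles
        print(node.format_node())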
140 changes: 87 additions & 53 deletions torch/distributed/_spmd/distribute.py
@@ -1,4 +1,5 @@
import logging
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import partial
@@ -70,43 +71,33 @@ class Schema:
@dataclass
class DSymInt:
"""
DSymInt represents a value retrieved by a SymInt op
DSymInt represents a value retrieved by a SymInt op from a DTensor. DSymInt
helps View and Factory ops to determine the placement and shape of the
output tensor, as those operators either do not have an input DTensor or
the input DTensor is insufficient to determine the output tensor's placement.
"""

value: int # value that the SymInt evaluates to
op: torch._ops.OpOverloadPacket # one of {sym_size, sym_numel}
tensor: DTensor # DTensor this SymInt was extracted from
dim: Optional[int] = None # dimension the SymInt was extracted from

@property
def local_value(self) -> int:
with torch.no_grad():
if self.op == aten.sym_size:
assert self.dim is not None
return self.tensor.to_local().size(self.dim)
elif self.op == aten.sym_numel:
return self.tensor.to_local().numel()
else:
raise NotImplementedError(f"Unsupported SymInt op {self.op}")
global_value: int # value that the SymInt evaluates to
local_value: int # value that this SymInt evaluates to on the local shard
mesh: DeviceMesh # device mesh of the DTensor this SymInt was retrieved from

def is_shard(self) -> bool:
return any(p.is_shard(self.dim) for p in self.tensor.placements)
return self.local_value != self.global_value

@classmethod
def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
if node.target == aten.sym_size:
dim: int = cast(int, node.args[1])
return cls(
value=dtensor.size(dim),
tensor=dtensor,
op=cast(torch._ops.OpOverloadPacket, node.target),
dim=dim,
global_value=dtensor.size(dim),
local_value=dtensor.to_local().size(dim),
mesh=dtensor.device_mesh,
)
elif node.target == aten.sym_numel:
return cls(
value=dtensor.numel(),
tensor=dtensor,
op=cast(torch._ops.OpOverloadPacket, node.target),
global_value=dtensor.numel(),
local_value=dtensor.to_local().numel(),
mesh=dtensor.device_mesh,
)
else:
raise NotImplementedError(f"DSymInt does not support {node.target}")
@@ -224,23 +215,24 @@ def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
return arg


def unpack_size_and_sharded_dims(
def unpack_sizes_and_dims(
sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
) -> Tuple[List[int], List[Placement]]:
local_sizes: List[int] = [
s.local_value if isinstance(s, DSymInt) else s for s in sizes
]
sharded_placements: List[Placement] = [
placements: List[Placement] = [
Shard(i)
for i, a in enumerate(sizes)
if (isinstance(a, DSymInt) and a.is_shard())
]
assert len(sharded_placements) == mesh.ndim, (
f"The number of sharded dimensions ({len(sharded_placements)}) must "
] or [Replicate()]

assert len(placements) == mesh.ndim, (
f"The number of sharded dimensions ({len(placements)}) must "
f"match number of dimensions in device mesh ({mesh.ndim})."
)

return local_sizes, sharded_placements
return local_sizes, placements


def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
@@ -252,17 +244,15 @@ def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor

# extract sharded dimensions in the size list, the output DTensor should
# follow these placements.
local_sizes, sharded_placements = unpack_size_and_sharded_dims(
args[1], args[0].device_mesh
)
local_sizes, placements = unpack_sizes_and_dims(args[1], args[0].device_mesh)

# set node args to real int sizes.
node.args = (node.args[0], local_sizes)
op = cast(torch._ops.OpOverload, node.target)
return DTensor.from_local(
local_tensor=op(args[0]._local_tensor, local_sizes),
device_mesh=args[0].device_mesh,
placements=sharded_placements,
placements=placements,
run_check=False,
)

@@ -280,15 +270,13 @@ def factory_with_sizes_rule(
)
assert isinstance(args[0], list), f"Expect 2nd argument as list but got {args[1]}"

local_sizes, sharded_placements = unpack_size_and_sharded_dims(
args[0], default_mesh
)
local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
node.args = (local_sizes, *args[1:])
op = cast(torch._ops.OpOverload, node.target)
return DTensor.from_local(
local_tensor=op(*node.args, **kwargs),
device_mesh=default_mesh,
placements=sharded_placements,
placements=placements,
run_check=False,
)

@@ -325,9 +313,11 @@ def default_factory_op_rule(
)


# Dispatch override for ops that consume SymInt arguments, where the output
# spec should follow dimension placement where the SymInt comes from.
SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
# Dispatch override for view and factory ops that consume SymInt arguments,
# where the output spec should follow dimension placement where the SymInt comes
# from.
VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
aten._unsafe_view.default: binop_sym_int_consumer_rule,
aten.expand.default: binop_sym_int_consumer_rule,
aten.view.default: binop_sym_int_consumer_rule,
}
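The view consumers above route their size arguments through unpack_sizes_and_dims, which splits a mixed list of DSymInt and int entries into local sizes plus output placements. A rough illustration of that contract, reusing the 2-rank, dim-0-sharded example (SimpleNamespace stands in for a 1-D DeviceMesh purely for illustration):

from types import SimpleNamespace

from torch.distributed._spmd.distribute import DSymInt, unpack_sizes_and_dims

mesh = SimpleNamespace(ndim=1)  # placeholder for a 1-D DeviceMesh
sizes = [DSymInt(global_value=8, local_value=4, mesh=mesh), 10]
local_sizes, placements = unpack_sizes_and_dims(sizes, mesh)
print(local_sizes)  # [4, 10]
print(placements)   # [Shard(dim=0)]; with no sharded DSymInt it falls back to [Replicate()]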
@@ -338,6 +328,9 @@ def default_factory_op_rule(
aten.arange.start: factory_arange_rule,
}


# Dispatch override for factory ops, as DTensor cannot propagate sharding spec
# without DTensor inputs.
FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
aten.scalar_tensor.default: default_factory_op_rule,
}
@@ -358,24 +351,24 @@ def _get_dtensor_dispatch_graph(
op_overload = cast(torch._ops.OpOverload, node.target)

if any(a.is_shard() for a in tree_flatten(args)[0] if isinstance(a, DSymInt)):
if op_overload in SYM_INT_CONSUMERS:
if op_overload in VIEW_SYM_INT_CONSUMERS:
assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
node_to_obj[node] = SYM_INT_CONSUMERS[op_overload](node, args)
# skip DTensor expansion
node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
return None
elif op_overload in FACTORY_SYM_INT_CONSUMERS:
assert default_mesh is not None, "Requires default mesh for factory ops"
node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
node, args, kwargs, default_mesh
)
return None
else:
# If an operator consumes SymInt sizes on a sharded dimension, we
# override with callables in SYM_INT_CONSUMERS or FACTORY_OPS to
# create DTensor activations.
raise NotImplementedError(
f"{op_overload} consumes SymInt args from a sharded dimension, "
"but SPMD expansion does not support this use case."
assert isinstance(logger, logging.Logger)
logger.warning(
"Assuming using local_value from SymInt for %s"
"is mathematically correct. Full args are %s.",
op_overload,
args,
)
return None

if node.target == aten.view.default:
# HACK: this is a hack to get around with the fact that some
@@ -392,6 +385,8 @@ def _get_dtensor_dispatch_graph(
)

if op_overload in FACTORY_OPS:
# Don't pass factory ops to DTensor dispatch, as DTensor cannot
# propagate sharding spec without DTensor inputs.
node_to_obj[node] = FACTORY_OPS[op_overload](
node, args, kwargs, default_mesh
)
@@ -689,6 +684,9 @@ def _convert_to_distributed(
"""
global logger
logger = get_logger("spmd_exp")
operators = {
getattr(operator, name) for name in dir(operator) if not name.startswith("_")
}
node_to_obj: Dict[fx.Node, Any] = {}
# map local op node in traced_f to its corresponding subgraph of
# DTensor ops.
@@ -740,11 +738,47 @@ def _convert_to_distributed(
output_schemas[inp_arg.name] = Schema(
obj.device_mesh, obj.placements # type: ignore[arg-type]
)

elif node.op == OP.CALL_FUNCTION:
args = tree_map(partial(_remap_arg, node_to_obj), node.args)
kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)
node_to_obj[node] = node.target(*args, **kwargs)

dsymints = list(
filter(lambda a: isinstance(a, DSymInt), args + tuple(kwargs.values()))
)

if node.target in operators and len(dsymints) > 0:
assert all(
dsymints[0].mesh == d.mesh for d in dsymints
), "all DSymInts must have the same mesh. "

local_args = tree_map(
lambda a: a.local_value if isinstance(a, DSymInt) else a, args
)
local_kwargs = tree_map(
lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
)

global_args = tree_map(
lambda a: a.global_value if isinstance(a, DSymInt) else a, args
)
global_kwargs = tree_map(
lambda a: a.global_value if isinstance(a, DSymInt) else a, kwargs
)

node.args = local_args
node.kwargs = local_kwargs

node_to_obj[node] = DSymInt(
local_value=node.target(*local_args, **local_kwargs),
global_value=node.target(*global_args, **global_kwargs),
mesh=dsymints[0].mesh,
)
else:
assert len(dsymints) == 0, (
"SPMD expansion does not support SymInt in non-operator "
f"nodes, got {node.target}."
)
node_to_obj[node] = node.target(*args, **kwargs)
else:
raise ValueError(f"Unrecognized node.op type {node.op}")
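The CALL_FUNCTION branch above is the core of this change: when a plain Python operator node consumes DSymInts, it is evaluated twice, once with local values (which also become the node's new args/kwargs) and once with global values, and the result is recorded as a new DSymInt. A standalone sketch of that bookkeeping with illustrative values from the 8 x 10, 2-rank example (mesh handling elided):

import operator

from torch.distributed._spmd.distribute import DSymInt

size0 = DSymInt(global_value=8, local_value=4, mesh=None)    # from sym_size(x, 0)
numel = DSymInt(global_value=80, local_value=40, mesh=None)  # from sym_numel(x)

# operator.mul is in the `operators` set built above, so the node's result becomes:
product = DSymInt(
    local_value=operator.mul(size0.local_value, numel.local_value),     # 160
    global_value=operator.mul(size0.global_value, numel.global_value),  # 640
    mesh=size0.mesh,
)
# The fx node's args are rewritten to the local values, so the compiled local
# program computes 4 * 40 on each rank, while downstream view/factory ops can
# still read the global value from the DSymInt.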

