General fixes for ShardedTensor op framework.
Pull Request resolved: pytorch#77191

1) Add more basic validation to all ops.
2) Ensure `_register_sharded_op_on_local_shards` uses the appropriate sharding_spec.

Differential Revision: [D36292159](https://our.internmc.facebook.com/intern/diff/D36292159/)

Approved by: https://github.com/wanchaol, https://github.com/fduwjj
pritamdamania87 authored and pytorchmergebot committed May 11, 2022
1 parent 420b49c commit b91a149
Showing 4 changed files with 69 additions and 15 deletions.
torch/distributed/_shard/sharded_tensor/__init__.py: 9 changes (6 additions, 3 deletions)
@@ -441,11 +441,14 @@ def sharded_op_impl(func):
             implementation (ex: torch.nn.functional.linear)
     """
     def decorator_sharded_func(wrapped_func):
-        _register_sharded_op(func, wrapped_func)
+        from torch.distributed._shard.sharded_tensor._ops._common import _basic_validation
 
         @functools.wraps(wrapped_func)
-        def wrapper(*args, **kwargs):
-            return wrapped_func(*args, **kwargs)
+        def wrapper(types, args, kwargs, process_group):
+            _basic_validation(func, args, kwargs)
+            return wrapped_func(types, args, kwargs, process_group)
+
+        _register_sharded_op(func, wrapper)
         return wrapper
     return decorator_sharded_func

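For context, a minimal sketch of how a handler registered through `sharded_op_impl` picks up the new behavior. The choice of `torch.flip` and the handler body are hypothetical and not part of this commit; only the decorator and the wrapper signature come from the diff above.

```python
import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor, sharded_op_impl

# Hypothetical registration for illustration only; torch.flip is not implied
# to be handled this way in the real codebase.
@sharded_op_impl(torch.flip)
def sharded_flip(types, args, kwargs, process_group):
    # By the time this body runs, the decorator has already called
    # _basic_validation(torch.flip, args, kwargs), so at least one argument is
    # a distributed tensor and all of them share one ProcessGroup.
    st = args[0]
    assert isinstance(st, ShardedTensor)
    raise NotImplementedError("illustrative stub only")
```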
torch/distributed/_shard/sharded_tensor/_ops/_common.py: 64 changes (56 additions, 8 deletions)
@@ -4,6 +4,56 @@
     Shard,
     ShardedTensor,
 )
+from torch.distributed._shard.partial_tensor import _PartialTensor
+from torch.distributed._shard.replicated_tensor import ReplicatedTensor
+from torch.utils._pytree import tree_map
+
+def _basic_validation(op, args=(), kwargs=None):
+    """
+    Common validation across all ops go in here.
+    """
+    if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
+        raise ValueError(f" No input for '{op.__name__}'!")
+
+    # Validate types
+    has_distributed_tensor = False
+
+    def is_distributed_tensor(e):
+        nonlocal has_distributed_tensor
+        if isinstance(e, ReplicatedTensor) or isinstance(e, _PartialTensor) or isinstance(e, ShardedTensor):
+            has_distributed_tensor = True
+
+    tree_map(is_distributed_tensor, args)
+    tree_map(is_distributed_tensor, kwargs)
+
+    if not has_distributed_tensor:
+        raise TypeError(
+            f"torch function '{op.__name__}', with args: {args} and "
+            f"kwargs: {kwargs} are called without any distributed tensor!"
+        )
+
+    # Validate all DistributedTensors use the same PG.
+    cur_pg = None
+
+    def validate_pg(e):
+        nonlocal cur_pg
+        if isinstance(e, ReplicatedTensor) or isinstance(e, _PartialTensor):
+            if cur_pg is not None and e.process_group is not cur_pg:
+                raise RuntimeError(
+                    'All distributed tensors should use the '
+                    'same ProcessGroup if used together in an op.'
+                )
+            cur_pg = e.process_group
+        elif isinstance(e, ShardedTensor):
+            if cur_pg is not None and e._process_group is not cur_pg:
+                raise RuntimeError(
+                    'All distributed tensors should use the '
+                    'same ProcessGroup if used together in an op.'
+                )
+            cur_pg = e._process_group
+
+    tree_map(validate_pg, args)
+    tree_map(validate_pg, kwargs)
 
 def _sharded_op_common(op, early_stop_func, extra_check):
     """
@@ -35,15 +85,9 @@ def _sharded_op_common(op, early_stop_func, extra_check):
     def decorator_sharded_func(wrapped_func):
         @functools.wraps(wrapped_func)
         def wrapper(types, args=(), kwargs=None, pg=None):
-            if len(args) == 0:
-                raise ValueError(f" No input for '{op.__name__}'!")
-            # Validate types
+            _basic_validation(op, args, kwargs)
+
             st = args[0]
-            if not isinstance(st, ShardedTensor):
-                raise TypeError(
-                    f"torch function '{op.__name__}', with args: {args} and "
-                    f"kwargs: {kwargs} are called for non ShardedTensor!"
-                )
             if kwargs is None:
                 kwargs = {}
             if extra_check:
@@ -69,6 +113,9 @@ def _register_sharded_op_on_local_shards(
     For more complicated ops, a customized func can be used to generate
     the new shards and sharded tensor size.
+
+    This function expects that the original ShardingSpec for the ShardedTensor
+    is preserved irrespective of whether or not a customized function is used.
     Args:
         op: The op to be registered and applied to all shards of the st.
         early_stop_func (Callable, optional): the func for early stop.
@@ -104,4 +151,5 @@ def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
             st_metadata,
             process_group=pg,
             init_rrefs=st._init_rrefs,
+            sharding_spec=st.sharding_spec()
         )
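The `sharding_spec=st.sharding_spec()` argument is the second half of the commit message: results built from local shards keep the input's original ShardingSpec. A rough sanity check, assuming an already-initialized 4-rank process group and assuming `torch.nn.functional.gelu` is one of the ops registered through `_register_sharded_op_on_local_shards`:

```python
import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes torch.distributed.init_process_group() has already run on 4 ranks
# (e.g. under torchrun); the placements below are illustrative.
spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
)
st = sharded_tensor.rand(spec, 16, 8)

# gelu is assumed here to dispatch through sharded_tensor_op_on_local_shards;
# the output should now carry the same spec as its input.
out = torch.nn.functional.gelu(st)
assert out.sharding_spec() == st.sharding_spec()
```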
torch/distributed/_shard/sharded_tensor/_ops/math_ops.py: 1 change (0 additions, 1 deletion)
@@ -5,7 +5,6 @@
     sharded_op_impl
 )
 from torch.distributed._shard.replicated_tensor import ReplicatedTensor
-
 from torch.distributed._shard._utils import narrow_tensor
 
 def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
torch/distributed/_shard/sharding_spec/api.py: 10 changes (7 additions, 3 deletions)
@@ -135,12 +135,16 @@ def custom_sharding_spec_op(sharding_spec_class, func):
         func(Callable): The op to override (ex: torch.bmm)
     """
     def decorator_sharded_func(wrapped_func):
-        _register_custom_op(sharding_spec_class, func, wrapped_func)
+        from torch.distributed._shard.sharded_tensor._ops._common import _basic_validation
 
         @functools.wraps(wrapped_func)
-        def wrapper(*args, **kwargs):
-            return wrapped_func(*args, **kwargs)
+        def wrapper(types, args, kwargs):
+            _basic_validation(func, args, kwargs)
+            return wrapped_func(types, args, kwargs)
+
+        _register_custom_op(sharding_spec_class, func, wrapper)
         return wrapper
+
     return decorator_sharded_func
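`custom_sharding_spec_op` gets the same treatment: the registered wrapper now validates before dispatching to the spec-specific implementation. A hypothetical registration, reusing the `torch.bmm` example from the docstring (whether such an override actually exists is not implied by this commit):

```python
import torch
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

# Hypothetical override used only to illustrate the new wrapper signature.
@custom_sharding_spec_op(ChunkShardingSpec, torch.bmm)
def chunk_sharded_bmm(types, args, kwargs):
    # _basic_validation(torch.bmm, args, kwargs) has already run at this point,
    # so the inputs are known to contain distributed tensors that all share a
    # single ProcessGroup.
    raise NotImplementedError("illustrative stub only")
```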


