Format DTensor dispatch.py and _meta_registrations.py (pytorch#98114)
Format-only changes with black and lintrunner to prepare for the commit on top.

Differential Revision: [D44603809](https://our.internmc.facebook.com/intern/diff/D44603809)
Pull Request resolved: pytorch#98114
Approved by: https://github.com/yifuwang, https://github.com/fegin
mrshenli authored and pytorchmergebot committed Apr 1, 2023
1 parent 64077ce commit bccf2ef
Showing 2 changed files with 115 additions and 33 deletions.
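The two files below were reformatted mechanically rather than edited by hand. As an aside (not part of this commit), the sketch below shows the kind of transformation black applies throughout this diff, driven from Python on a hypothetical over-long signature; the function name `too_long_signature` and the `line_length` value are illustrative assumptions, and the exact limit used by PyTorch's lintrunner configuration may differ.

```python
# Illustrative sketch only -- not part of this commit. A signature that no
# longer fits within the configured line length is exploded onto one argument
# per line, with a trailing comma after the last argument, which is the
# pattern repeated throughout the diff below.
import black

# Hypothetical over-long signature, similar in shape to the ones in this diff.
src = (
    "def too_long_signature(low, high, size, scale_factor, dtype=None, "
    "layout=None, device=None, pin_memory=None, memory_format=None):\n"
    "    return None\n"
)

# black.Mode(line_length=88) is black's default limit; PyTorch's lintrunner
# configuration may use a different value.
print(black.format_str(src, mode=black.Mode(line_length=88)))
```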
116 changes: 94 additions & 22 deletions torch/_meta_registrations.py
@@ -89,7 +89,14 @@ def meta_randint(

@register_meta(aten.randint.low)
def meta_randint_low(
low, high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
low,
high,
size,
*,
dtype=torch.long,
layout=None,
device=None,
pin_memory=None,
):
return torch.empty(
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
@@ -187,7 +194,8 @@ def meta_angle(self):
result_dtype = corresponding_real_dtype(self.dtype)
else:
_, result_dtype = elementwise_dtypes(
self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
self,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
return torch.empty_like(self, dtype=result_dtype)

@@ -287,7 +295,10 @@ def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
A: Tensor, full_matrices: bool = False, compute_uv: bool = True, driver: str = None
A: Tensor,
full_matrices: bool = False,
compute_uv: bool = True,
driver: str = None,
):
checkIsMatrix(A, "linalg.svd")
checkFloatingOrComplex(A, "linalg.svd")
@@ -333,7 +344,10 @@ def _linalg_det_meta(A):

# From aten/src/ATen/native/ReflectionPad.cpp
@register_meta(
[aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default]
[
aten.reflection_pad2d_backward.default,
aten.replication_pad2d_backward.default,
]
)
def meta_pad2d_backward(grad_output, self, padding):
dim_w = 2
@@ -831,7 +845,10 @@ def unpack(name, val):
else:
size = [nbatch, nInputPlane, outputHeight, outputWidth]
return torch.empty(
size, dtype=input.dtype, device=input.device, memory_format=memory_format
size,
dtype=input.dtype,
device=input.device,
memory_format=memory_format,
)


@@ -946,7 +963,10 @@ def meta_avg_pool2d_backward(
)

return torch.empty(
input_size, dtype=input.dtype, device=input.device, memory_format=mem_format
input_size,
dtype=input.dtype,
device=input.device,
memory_format=mem_format,
)


@@ -961,7 +981,10 @@ def meta_adaptive_avg_pool2d(self, output_size):
# need to set memory_format to preserve the memory format of the input
# channel last input should have channel last output
return torch.empty(
output_shape, dtype=self.dtype, device=self.device, memory_format=memory_format
output_shape,
dtype=self.dtype,
device=self.device,
memory_format=memory_format,
)


@@ -1196,7 +1219,10 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
]
)
def meta__foreach_unaop_(self):
check(isinstance(self, List), lambda: f"Expect List[Tensor] but got {type(self)}")
check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
)


@register_meta(
@@ -1207,7 +1233,10 @@ def meta__foreach_unaop_(self):
]
)
def meta__foreach_unaop(self):
check(isinstance(self, List), lambda: f"Expect List[Tensor] but got {type(self)}")
check(
isinstance(self, List),
lambda: f"Expect List[Tensor] but got {type(self)}",
)
return [torch.empty_like(s) for s in self]


@@ -1425,7 +1454,8 @@ def meta_embedding_bag(
num_bags = offsets.size(0)
if include_last_offset:
check(
num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
num_bags >= 1,
lambda: "include_last_offset: numBags should be at least 1",
)
num_bags -= 1

@@ -1867,9 +1897,20 @@ def unpack(name, val):

@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices
grad_output,
self,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
indices,
):
nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
(
nInputPlane,
outputHeight,
outputWidth,
) = max_pool2d_checks_and_compute_shape(
self, kernel_size, stride, padding, dilation, ceil_mode
)

@@ -1891,15 +1932,22 @@ def _check_dim_size(t):

memory_format = utils.suggest_memory_format(self)
return torch.empty(
self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format
self.shape,
dtype=self.dtype,
device=self.device,
memory_format=memory_format,
)


@register_meta(aten.max_pool2d_with_indices.default)
def meta_max_pool2d_with_indices(
input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
):
nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
(
nInputPlane,
outputHeight,
outputWidth,
) = max_pool2d_checks_and_compute_shape(
input, kernel_size, stride, padding, dilation, ceil_mode
)

@@ -1911,10 +1959,16 @@ def meta_max_pool2d_with_indices(
size = [nbatch, nInputPlane, outputHeight, outputWidth]
return (
torch.empty(
size, dtype=input.dtype, device=input.device, memory_format=memory_format
size,
dtype=input.dtype,
device=input.device,
memory_format=memory_format,
),
torch.empty(
size, dtype=torch.int64, device=input.device, memory_format=memory_format
size,
dtype=torch.int64,
device=input.device,
memory_format=memory_format,
),
)

@@ -1960,7 +2014,12 @@ def meta_like(self, *args, **kwargs):
# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
self,
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
):
if layout == torch.sparse_coo:
check(
@@ -1999,7 +2058,9 @@ def zeros_like(
def meta_select(self, dim, index):
ndim = self.dim()
check(
ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
ndim != 0,
lambda: "select() cannot be applied to a 0-dim tensor.",
IndexError,
)

dim = dim if dim >= 0 else dim + ndim
@@ -2438,7 +2499,11 @@ def meta__scaled_dot_product_efficient_backward(
if grad_kv_needs_init
else torch.empty(value.shape, dtype=value.dtype, device=value.device)
)
return grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)
return (
grad_q.transpose(1, 2),
grad_k.transpose(1, 2),
grad_v.transpose(1, 2),
)


@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@@ -2565,7 +2630,12 @@ def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_


@register_meta(
[aten.sort.default, aten.sort.stable, aten.sort.values, aten.sort.values_stable]
[
aten.sort.default,
aten.sort.stable,
aten.sort.values,
aten.sort.values_stable,
]
)
def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
@@ -2913,10 +2983,12 @@ def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
check(found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor.")
check(inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor.")
check(
found_inf.dtype.is_floating_point, lambda: "found_inf must be a float tensor."
found_inf.dtype.is_floating_point,
lambda: "found_inf must be a float tensor.",
)
check(
inv_scale.dtype.is_floating_point, lambda: "inv_scale must be a float tensor."
inv_scale.dtype.is_floating_point,
lambda: "inv_scale must be a float tensor.",
)


32 changes: 21 additions & 11 deletions torch/distributed/_tensor/dispatch.py
@@ -57,7 +57,7 @@ def wrap(res: object, spec: OutputSpecType) -> object:
shape=s.tensor_meta.shape,
dtype=s.tensor_meta.dtype,
requires_grad=s.tensor_meta.requires_grad,
stride=s.tensor_meta.stride
stride=s.tensor_meta.stride,
)
else:
res_dt = None
@@ -97,7 +97,9 @@ def _reshape_alias(
return torch.ops.aten.view(x, shape)


_CURRENT_DECOMPOSITION_TABLE: Dict[Callable[..., object], Callable[..., object]] = {
_CURRENT_DECOMPOSITION_TABLE: Dict[
Callable[..., object], Callable[..., object]
] = {
torch.ops.aten._reshape_alias.default: _reshape_alias,
}

@@ -112,7 +114,9 @@ def operator_dispatch(
arg_list, _ = tree_flatten(args)
mesh = None
for arg in arg_list:
if isinstance(arg, torch.Tensor) and not isinstance(arg, dtensor.DTensor):
if isinstance(arg, torch.Tensor) and not isinstance(
arg, dtensor.DTensor
):
raise RuntimeError(
f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
" torch.Tensor to DTensor before calling distributed operators!"
@@ -134,7 +138,9 @@ def operator_dispatch(
# unwrap the args/kwargs schema
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)

output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
output_sharding = sharding_propagator.propagate_op_sharding(
op_call, op_schema
)

# if the schema suggestion from sharding prop is not the same instance as the
# input op_schema, it indicates a reshard, we need to redistribute the input
@@ -171,12 +177,16 @@ def operator_dispatch(
ret_type = str(ret_list[0].type)
if ret_type == "bool":
import operator
local_results: object = functools.reduce(operator.and_, obj_list, True)

local_results: object = functools.reduce(
operator.and_, obj_list, True
)
else:
raise NotImplementedError(
f"return type {ret_type} in DTensor op is not supported"
)
else:

def default_tensor(spec: DTensorSpec) -> torch.Tensor:
if spec.tensor_meta is not None:
shape = spec.tensor_meta.shape
@@ -188,16 +198,16 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
# non-scalar tensor
return torch.tensor([], dtype=dtype)
else:
raise RuntimeError(
f"{spec} has no tensor metadata."
)
raise RuntimeError(f"{spec} has no tensor metadata.")

if (isinstance(spec, DTensorSpec)):
if isinstance(spec, DTensorSpec):
# return a Tensor value
local_results = default_tensor(spec)
elif (isinstance(spec, Sequence)):
elif isinstance(spec, Sequence):
# return a List[Tensor] value
local_results = [default_tensor(s) if s is not None else None for s in spec]
local_results = [
default_tensor(s) if s is not None else None for s in spec
]
assert isinstance(local_results, List)
if None in local_results:
ret_type = str(ret_list[0].type)

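As a closing aside (not part of this commit or of the black/lintrunner tooling), one cheap way to sanity-check that a reformat like this is behavior-preserving is to compare the ASTs of a file before and after: line wrapping, indentation changes, and added trailing commas leave the parse tree unchanged. The helper name `is_format_only` below is illustrative, and because comments are invisible to `ast`, a passing check does not rule out comment-only edits.

```python
# Illustrative sketch only -- not part of this commit.
import ast


def is_format_only(before: str, after: str) -> bool:
    """Return True when two source strings parse to structurally identical ASTs.

    Pure formatting changes compare equal; comments are ignored by ast, so a
    True result means "no code change", not "no textual change at all".
    """
    return ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))


# The signature reformat at the top of this diff, reduced to a small example:
before = "def f(a, b, *, c=None, d=None):\n    return a\n"
after = "def f(\n    a,\n    b,\n    *,\n    c=None,\n    d=None,\n):\n    return a\n"
assert is_format_only(before, after)
```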