[AutoParallel] Fix problem of expand_as. (PaddlePaddle#62460)
* [AutoParallel] Fix problem of expand_as. It needs to calculate the local shape in auto-parallel dynamic graph mode.

* Remove useless print.

* Polish code according to comments.
GhostScreaming authored Mar 8, 2024
1 parent 7b1540a commit 3646da6
Showing 4 changed files with 55 additions and 42 deletions.
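
As the commit message notes, ops such as expand_as have to derive a local shape for the DenseTensor each rank holds, because the shape attribute they carry describes the global DistTensor. A minimal Python sketch of that per-dimension calculation (illustrative only; the helper name and arguments are not Paddle APIs, and the 0/-1 placeholder handling done by the generated code is omitted):

def local_shape_from_global(global_shape, dims_mapping, mesh_shape):
    # dims_mapping[i] is the process-mesh axis that shards tensor dim i, or -1 if replicated.
    local_shape = []
    for i, extent in enumerate(global_shape):
        mesh_axis = dims_mapping[i]
        if mesh_axis >= 0:
            mesh_dim = mesh_shape[mesh_axis]
            # Aliquant (non-divisible) sharding is not supported yet.
            assert extent % mesh_dim == 0
            local_shape.append(extent // mesh_dim)
        else:
            local_shape.append(extent)
    return local_shape

# Example: a global target_shape of [8, 6] sharded on dim 0 over a 2-way mesh axis -> [4, 6].
print(local_shape_from_global([8, 6], [0, -1], [2]))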
2 changes: 1 addition & 1 deletion paddle/fluid/operators/generator/parse_utils.py
@@ -369,7 +369,7 @@ def check_op_config(op_entry, op_name):
'traits',
'interfaces',
)
-infer_meta_key_set = ('func', 'param', 'spmd_rule')
+infer_meta_key_set = ('func', 'param', 'spmd_rule', 'local_shape')
kernel_key_set = (
'func',
'param',
93 changes: 52 additions & 41 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -483,53 +483,56 @@
// API `{}` does not need to set DistAttr for output."""

# TODO(GhostScreaming): Support aliquant condition.
-# Specialized Code, for example, reshape needs to calculate local_shape
-RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """
+# Operators like `reshape`, `expand_as` need to calculate local_shape
+# for their local `DenseTensor`, as the given shape in their attribute
+# is global_shape for `DistTensor`.
+CALCULATE_LOCAL_SHAPE_TEMPLATE = """
// The dist_input_x is a dist tensor, the dims() func return the global dims.
auto x_shape = dist_input_x->dims();
auto x_numel = dist_input_x->numel();
bool visit_negative = false;
-std::vector<int64_t> local_shape;
-for (size_t i = 0; i < shape.GetData().size(); i++) {
+auto global_shape = {shape};
+std::vector<{dtype}> local_shape;
+for (size_t i = 0; i < global_shape.size(); i++) {{
auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);
-if (out_dist_attr.dims_mapping()[i] >= 0) {
-int64_t shape_i = shape.GetData()[i];
-if (shape_i == 0) {
+if (out_dist_attr.dims_mapping()[i] >= 0) {{
+{dtype} shape_i = global_shape[i];
+if (shape_i == 0) {{
shape_i = x_shape[i];
-} else if (shape_i == -1) {
+}} else if (shape_i == -1) {{
PADDLE_ENFORCE(not visit_negative,
phi::errors::InvalidArgument(
-"Reshape can only have one -1 in the shape."));
+"{op_name} can only have one -1 in the {shape_name}."));
visit_negative = true;
int64_t non_negative_product = 1;
-for (size_t j = 0; j < shape.GetData().size(); j++) {
-if (i == j) {
+for (size_t j = 0; j < global_shape.size(); j++) {{
+if (i == j) {{
continue;
-}
-int64_t tmp_j = shape.GetData()[j];
-if (tmp_j == 0) {
+}}
+int64_t tmp_j = global_shape[j];
+if (tmp_j == 0) {{
tmp_j = x_shape[j];
-}
+}}
non_negative_product *= tmp_j;
-}
+}}
PADDLE_ENFORCE(x_numel % non_negative_product == 0,
phi::errors::InvalidArgument("Cannot infer real shape for -1."));
shape_i = x_numel / non_negative_product;
-}
+}}
int64_t dim = out_dist_attr.dims_mapping()[i];
int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim];
// TODO: Support aliquant condition.
PADDLE_ENFORCE(shape_i % mesh_dim == 0,
phi::errors::InvalidArgument(
-"Reshape only support local shape dim is divisible "
+"{op_name} only support local shape dim is divisible "
"by the mesh dim, however local_shape[%lld] is %lld "
"and shard mesh dims is %lld.", i, shape_i, mesh_dim));
local_shape.push_back(shape_i / mesh_dim);
-} else {
-local_shape.push_back(shape.GetData()[i]);
-}
-}
+}} else {{
+local_shape.push_back({shape}[i]);
+}}
+}}
"""

# BaseAPI members:
@@ -590,7 +593,11 @@ def parse_infer_meta(self, infer_meta_config):
infer_meta['param'] = None
if 'spmd_rule' not in infer_meta_config:
infer_meta['spmd_rule'] = None

# Operators like `reshape`, `expand_as` need to calculate local_shape
# for their local `DenseTensor`, as the given shape in their attribute
# is global_shape for `DistTensor`.
if 'local_shape' not in infer_meta_config:
infer_meta['local_shape'] = None
return infer_meta

def need_to_generate_code_for_inplace_impl(self, i):
@@ -613,17 +620,6 @@ def need_to_generate_code_for_inplace_or_view_impl(self, i):
i
) or self.need_to_generate_code_for_view_impl(i)

-# # view output is also inlace, such case still needs
-# # to create an empty DenseTensor for inplace output in pp
-# def need_to_set_inplace_output_for_pp_impl(self, i):
-# return (not self.need_to_generate_code_for_view_impl(i)) and self.is_inplace_output(i)
-
-def is_reshape_kernel(self):
-return (
-"reshape" in self.kernel['func'][0]
-and 'grad' not in self.kernel['func'][0]
-)

def is_inplace_output(self, i):
return self.outputs['names'][i] in self.inplace_map

@@ -1548,8 +1544,8 @@ def generate_infer_meta_code(self) -> str:
f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
-# TODO(GhostScreaming): reshape kernel need specialized process
-if self.is_reshape_kernel() and param == "shape":
+# TODO(GhostScreaming): kernel like reshape need calculate local_shape
+if self.infer_meta['local_shape'] is not None:
input_args_code = input_args_code + "local_shape" + ", "
else:
input_args_code = input_args_code + param + ", "
@@ -1582,9 +1578,24 @@ def generate_infer_meta_code(self) -> str:
output_args_code = output_args_code[:-2]

infer_meta_code = ""
-# TODO(GhostScreaming): reshape kernel need specialized process
-if self.is_reshape_kernel():
-infer_meta_code = RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE
+# TODO(GhostScreaming): kernel like reshape need calculate local_shape
+if self.infer_meta['local_shape'] is not None:
+shape_name = self.infer_meta['local_shape']
+assert (
+shape_name in self.attrs['names']
+), f"Auto Parallel will calculate local_shape {shape_name} for"
+"operator {self.kernel['func'][0]}, but {shape_name} is not"
+"found in its attributes."
+shape_type = self.attrs['attr_info'][shape_name][0]
+
+infer_meta_code = CALCULATE_LOCAL_SHAPE_TEMPLATE.format(
+shape=f"{shape_name}.GetData()"
+if shape_type == "IntArray"
+else f"{shape_name}",
+dtype="int64_t" if shape_type == "IntArray" else "int",
+op_name=self.kernel['func'][0],
+shape_name=shape_name,
+)
infer_meta_code = infer_meta_code + INFER_META_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code
)
@@ -1637,8 +1648,8 @@ def generate_kernel_call_code(self) -> str:
elif arg in attr_names:
if 'IntArray' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append('const phi::IntArray&')
-# TODO(GhostScreaming): reshape kernel need specialized process
-if self.is_reshape_kernel() and arg == "shape":
+# TODO(GhostScreaming): kernel like reshape need calculate local_shape
+if self.infer_meta['local_shape'] is not None:
arg = 'phi::IntArray(local_shape)'
else:
arg = 'phi::IntArray(' + arg + ')'
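
For reference, a simplified sketch of how the new infer_meta `local_shape` entry parameterizes CALCULATE_LOCAL_SHAPE_TEMPLATE above: an IntArray attribute (reshape's shape) is read through GetData() with int64_t elements, while a plain integer-vector attribute (expand_as's target_shape) is substituted directly with int elements. The function below is a stand-in for the branch in generate_infer_meta_code, not part of the generator, and the "int[]" type string is only an assumed example of a non-IntArray attribute type.

def calculate_local_shape_args(shape_name, shape_type, op_name):
    # Choose the shape expression and element type substituted into the template.
    if shape_type == "IntArray":
        return dict(shape=f"{shape_name}.GetData()", dtype="int64_t",
                    op_name=op_name, shape_name=shape_name)
    return dict(shape=shape_name, dtype="int",
                op_name=op_name, shape_name=shape_name)

# reshape   -> {'shape': 'shape.GetData()', 'dtype': 'int64_t', ...}
# expand_as -> {'shape': 'target_shape', 'dtype': 'int', ...}
print(calculate_local_shape_args("shape", "IntArray", "reshape"))
print(calculate_local_shape_args("target_shape", "int[]", "expand_as"))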
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -1005,6 +1005,7 @@
infer_meta :
func : ReshapeWithXShapeInferMeta
spmd_rule : ReshapeInferSpmdDynamic
local_shape: shape
kernel :
func : reshape
inplace : (x -> out)
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -946,6 +946,7 @@
output : Tensor(out)
infer_meta :
func : ExpandAsInferMeta
local_shape: target_shape
kernel :
func : expand_as
data_type : x
