diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py
index 0370d6cfba4b32..38a87efec0415c 100644
--- a/paddle/fluid/operators/generator/parse_utils.py
+++ b/paddle/fluid/operators/generator/parse_utils.py
@@ -369,7 +369,7 @@ def check_op_config(op_entry, op_name):
         'traits',
         'interfaces',
     )
-    infer_meta_key_set = ('func', 'param', 'spmd_rule')
+    infer_meta_key_set = ('func', 'param', 'spmd_rule', 'local_shape')
     kernel_key_set = (
         'func',
         'param',
diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index d0b82f3be9f70a..ad153639c4d56c 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -483,53 +483,56 @@
 // API `{}` does not need to set DistAttr for output."""

 # TODO(GhostScreaming): Support aliquant condition.
-# Specialized Code, for example, reshape needs to calculate local_shape
-RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """
+# Operators like `reshape`, `expand_as` need to calculate local_shape
+# for their local `DenseTensor`, as the given shape in their attribute
+# is the global shape for `DistTensor`.
+CALCULATE_LOCAL_SHAPE_TEMPLATE = """
   // The dist_input_x is a dist tensor, the dims() func return the global dims.
   auto x_shape = dist_input_x->dims();
   auto x_numel = dist_input_x->numel();
   bool visit_negative = false;
-  std::vector<int64_t> local_shape;
-  for (size_t i = 0; i < shape.GetData().size(); i++) {
+  auto global_shape = {shape};
+  std::vector<{dtype}> local_shape;
+  for (size_t i = 0; i < global_shape.size(); i++) {{
     auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr,
                                            spmd_info.second[0]);
-    if (out_dist_attr.dims_mapping()[i] >= 0) {
-      int64_t shape_i = shape.GetData()[i];
-      if (shape_i == 0) {
+    if (out_dist_attr.dims_mapping()[i] >= 0) {{
+      {dtype} shape_i = global_shape[i];
+      if (shape_i == 0) {{
         shape_i = x_shape[i];
-      } else if (shape_i == -1) {
+      }} else if (shape_i == -1) {{
         PADDLE_ENFORCE(not visit_negative,
                        phi::errors::InvalidArgument(
-                           "Reshape can only have one -1 in the shape."));
+                           "{op_name} can only have one -1 in the {shape_name}."));
         visit_negative = true;
         int64_t non_negative_product = 1;
-        for (size_t j = 0; j < shape.GetData().size(); j++) {
-          if (i == j) {
+        for (size_t j = 0; j < global_shape.size(); j++) {{
+          if (i == j) {{
             continue;
-          }
-          int64_t tmp_j = shape.GetData()[j];
-          if (tmp_j == 0) {
+          }}
+          int64_t tmp_j = global_shape[j];
+          if (tmp_j == 0) {{
             tmp_j = x_shape[j];
-          }
+          }}
           non_negative_product *= tmp_j;
-        }
+        }}
         PADDLE_ENFORCE(x_numel % non_negative_product == 0, phi::errors::InvalidArgument("Cannot infer real shape for -1."));
         shape_i = x_numel / non_negative_product;
-      }
+      }}
       int64_t dim = out_dist_attr.dims_mapping()[i];
       int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim];
       // TODO: Support aliquant condition.
       PADDLE_ENFORCE(shape_i % mesh_dim == 0,
             phi::errors::InvalidArgument(
-                "Reshape only support local shape dim is divisible "
+                "{op_name} only support local shape dim is divisible "
                 "by the mesh dim, however local_shape[%lld] is %lld "
                 "and shard mesh dims is %lld.",
                 i, shape_i, mesh_dim));
       local_shape.push_back(shape_i / mesh_dim);
-    } else {
-      local_shape.push_back(shape.GetData()[i]);
-    }
-  }
+    }} else {{
+      local_shape.push_back({shape}[i]);
+    }}
+  }}
 """

 # BaseAPI members:
@@ -590,7 +593,11 @@ def parse_infer_meta(self, infer_meta_config):
             infer_meta['param'] = None
         if 'spmd_rule' not in infer_meta_config:
             infer_meta['spmd_rule'] = None
-
+        # Operators like `reshape`, `expand_as` need to calculate local_shape
+        # for their local `DenseTensor`, as the given shape in their attribute
+        # is the global shape for `DistTensor`.
+        if 'local_shape' not in infer_meta_config:
+            infer_meta['local_shape'] = None
         return infer_meta

     def need_to_generate_code_for_inplace_impl(self, i):
@@ -613,17 +620,6 @@ def need_to_generate_code_for_inplace_or_view_impl(self, i):
             i
         ) or self.need_to_generate_code_for_view_impl(i)

-    # # view output is also inlace, such case still needs
-    # # to create an empty DenseTensor for inplace output in pp
-    # def need_to_set_inplace_output_for_pp_impl(self, i):
-    #     return (not self.need_to_generate_code_for_view_impl(i)) and self.is_inplace_output(i)
-
-    def is_reshape_kernel(self):
-        return (
-            "reshape" in self.kernel['func'][0]
-            and 'grad' not in self.kernel['func'][0]
-        )
-
     def is_inplace_output(self, i):
         return self.outputs['names'][i] in self.inplace_map

@@ -1548,8 +1544,8 @@ def generate_infer_meta_code(self) -> str:
                     f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
                 )
             elif param in attr_names:
-                # TODO(GhostScreaming): reshape kernel need specialized process
-                if self.is_reshape_kernel() and param == "shape":
+                # TODO(GhostScreaming): kernels like reshape need to calculate local_shape
+                if self.infer_meta['local_shape'] is not None:
                     input_args_code = input_args_code + "local_shape" + ", "
                 else:
                     input_args_code = input_args_code + param + ", "
@@ -1582,9 +1578,24 @@
             output_args_code = output_args_code[:-2]

         infer_meta_code = ""
-        # TODO(GhostScreaming): reshape kernel need specialized process
-        if self.is_reshape_kernel():
-            infer_meta_code = RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE
+        # TODO(GhostScreaming): kernels like reshape need to calculate local_shape
+        if self.infer_meta['local_shape'] is not None:
+            shape_name = self.infer_meta['local_shape']
+            assert shape_name in self.attrs['names'], (
+                f"Auto Parallel will calculate local_shape {shape_name} for "
+                f"operator {self.kernel['func'][0]}, but {shape_name} is not "
+                f"found in its attributes."
+            )
+            shape_type = self.attrs['attr_info'][shape_name][0]
+
+            infer_meta_code = CALCULATE_LOCAL_SHAPE_TEMPLATE.format(
+                shape=f"{shape_name}.GetData()"
+                if shape_type == "IntArray"
+                else f"{shape_name}",
+                dtype="int64_t" if shape_type == "IntArray" else "int",
+                op_name=self.kernel['func'][0],
+                shape_name=shape_name,
+            )
         infer_meta_code = infer_meta_code + INFER_META_TEMPLATE.format(
             infer_meta_func_code, input_args_code, output_args_code
         )
@@ -1637,8 +1648,8 @@ def generate_kernel_call_code(self) -> str:
             elif arg in attr_names:
                 if 'IntArray' in self.attrs['attr_info'][arg][0]:
                     kernel_args_type_list.append('const phi::IntArray&')
-                    # TODO(GhostScreaming): reshape kernel need specialized process
-                    if self.is_reshape_kernel() and arg == "shape":
+                    # TODO(GhostScreaming): kernels like reshape need to calculate local_shape
+                    if self.infer_meta['local_shape'] is not None:
                         arg = 'phi::IntArray(local_shape)'
                     else:
                         arg = 'phi::IntArray(' + arg + ')'
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index a629ab70cd1091..e27e5de111bc8f 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -1005,6 +1005,7 @@
   infer_meta :
     func : ReshapeWithXShapeInferMeta
     spmd_rule : ReshapeInferSpmdDynamic
+    local_shape : shape
   kernel :
     func : reshape
   inplace : (x -> out)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 35ccab6221eb6f..ce7d9e935247d0 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -946,6 +946,7 @@
   output : Tensor(out)
   infer_meta :
     func : ExpandAsInferMeta
+    local_shape : target_shape
   kernel :
     func : expand_as
     data_type : x
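---
Note (not part of the patch): a minimal sketch of how the new `local_shape`
infer_meta key drives code generation, assuming the two YAML entries above.
The helper name `render_local_shape` and the one-line template stub are
hypothetical; the `.format()` arguments mirror the ones added to
`generate_infer_meta_code` in dist_api_gen.py.

    # Hypothetical stand-in for CALCULATE_LOCAL_SHAPE_TEMPLATE; the real
    # template emits the full local-shape computation shown in the diff.
    TEMPLATE_STUB = (
        "auto global_shape = {shape}; "
        "std::vector<{dtype}> local_shape;  // {op_name}, attr {shape_name}"
    )

    def render_local_shape(shape_name: str, shape_type: str, op_name: str) -> str:
        # IntArray attributes expose their int64_t data via GetData();
        # plain std::vector<int> attributes are indexed directly.
        return TEMPLATE_STUB.format(
            shape=f"{shape_name}.GetData()" if shape_type == "IntArray" else shape_name,
            dtype="int64_t" if shape_type == "IntArray" else "int",
            op_name=op_name,
            shape_name=shape_name,
        )

    # reshape declares `local_shape : shape`, an IntArray attribute:
    print(render_local_shape("shape", "IntArray", "reshape"))
    # expand_as declares `local_shape : target_shape`, a std::vector<int> attribute:
    print(render_local_shape("target_shape", "std::vector<int>", "expand_as"))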