diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index 4e46db4b84e84a..9b8292cc866b80 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -1645,10 +1645,12 @@ def generate_output_dist_attr_setting(self) -> str:
                 # TODO(GhostScreaming): for inplace view operators like reshape,
                 # input and output may have different shape. If they have no specified
                 # InferSPMD rules, just set replicated dist_attr for them.
-                if (
-                    self.need_to_generate_code_for_inplace_impl(i)
-                    and self.outputs['names'][i] not in self.view_map
-                ):
+                if self.need_to_generate_code_for_inplace_impl(i):
+                    if (
+                        self.generate_general_infer_spmd
+                        and self.outputs['names'][i] in self.view_map
+                    ):
+                        continue
                     need_reshard = (
                         "true" if self.generate_general_infer_spmd else "false"
                     )
diff --git a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
index 16d00d156f15be..3769155eb27e11 100644
--- a/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_bw_api_gen.py
@@ -177,6 +177,9 @@
 RESHARD_VECTOR_OUTPUT_TEMPLATE = """
       ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out, {});"""
 
+NONEED_TO_RESHARD_OUTPUT_TEMPLATE = """
+    // API `{}` does not need to reshard output."""
+
 
 class DistBackwardAPI(DistForwardAPI, BackwardAPI):
     def __init__(self, backward_item_yaml):
@@ -344,6 +347,9 @@ def generate_reshard_output_code(self):
                 f"{self.api} : Output error: the output should not be empty."
             )
         else:
+            reshard_output_code += NONEED_TO_RESHARD_OUTPUT_TEMPLATE.format(
+                self.kernel['func'][0]
+            )
             # do nothing
             pass
 
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc
index 5627788a266e89..39780d19056b83 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc
@@ -132,6 +132,7 @@ void PToSReshardFunctionCrossMesh::Eval(DeviceContext* dev_ctx,
     p_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);
   } else {
     SetDistProps(&tmp_result, in.dims(), in_dist_attr_shard);
+    SetValue(&tmp_result, in.value());
   }
 
   SameStatusReshardFunction same_status_func;
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc
index 865873550ea063..3ca5ebb4e49b42 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/r_to_s_reshard_function.cc
@@ -129,8 +129,8 @@ void RToSReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
     r_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);
   } else {
     SetDistProps(&tmp_result, in.dims(), in_dist_attr_shard);
+    SetValue(&tmp_result, in.value());
   }
-
   SameStatusReshardFunction same_status_func;
   PADDLE_ENFORCE(
       same_status_func.IsSuitable(tmp_result, out_dist_attr),
diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
index 8d2b1bfa823140..b99862b001eb11 100644
--- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc
@@ -128,11 +128,9 @@ bool SToRReshardFunctionCrossMesh::IsSuitable(
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& out_process_mesh = out_dist_attr.process_mesh();
 
-  int64_t cur_global_rank = GetCurGlobalRank();
-  if (in_process_mesh.contains(cur_global_rank)) {
-    int split_axis =
-        GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
-    int64_t num_of_process = in_process_mesh.size();
+  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
+  int64_t num_of_process = in_process_mesh.size();
+  if (in.initialized()) {
     RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
                                   num_of_process ==
                               in.dims()[static_cast<int>(split_axis)]);
diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
index 7bd9482f4aa615..f47f381e093f08 100644
--- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc
+++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -384,7 +384,7 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
   for (int i = 0; i < begin_norm_axis; ++i) {
     auto mapping = dim_mapping[i];
     if (mapping != -1) {
-      partial_on_dims.push_back(i);
+      partial_on_dims.push_back(mapping);
     }
   }
   scale_grad_dist_attr.set_partial_status(partial_on_dims);
diff --git a/paddle/phi/infermeta/spmd_rules/transpose.cc b/paddle/phi/infermeta/spmd_rules/transpose.cc
index e4942f2e4718ef..95840f8edfe356 100644
--- a/paddle/phi/infermeta/spmd_rules/transpose.cc
+++ b/paddle/phi/infermeta/spmd_rules/transpose.cc
@@ -90,8 +90,7 @@ SpmdInfo TransposeInferSpmd(const DistMetaTensor& x,
   // input dist_attr.
   TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
   out_dist_attr.set_dims_mapping(out_dims_mapping);
-
-  // Step3 Handle Partial (TODO)
+  out_dist_attr.set_partial_status(x_dist_attr_src.partial_status());
 
   VLOG(4) << "TransposeInferSpmd:";
   VLOG(4) << "Input: shape: [" << str_join(x_shape) << "] "