Commit
[AutoParallel] Fix problems of sequence parallel in dynamic mode. (PaddlePaddle#59766)

* [AutoParallel] Fix problems of sequence parallel in dynamic mode.

* Polish code.

* Remove TODO in transpose.cc

* Polish code.

* Remove useless modification.

* Polish code.

* Polish code.

* Remove useless modification.

* Allow partial status flow
GhostScreaming authored Dec 8, 2023
1 parent 7c44a2b commit a168173
Showing 7 changed files with 19 additions and 13 deletions.
10 changes: 6 additions & 4 deletions paddle/phi/api/yaml/generator/dist_api_gen.py

@@ -1645,10 +1645,12 @@ def generate_output_dist_attr_setting(self) -> str:
             # TODO(GhostScreaming): for inplace view operators like reshape,
             # input and output may have different shape. If they have no specified
             # InferSPMD rules, just set replicated dist_attr for them.
-            if (
-                self.need_to_generate_code_for_inplace_impl(i)
-                and self.outputs['names'][i] not in self.view_map
-            ):
+            if self.need_to_generate_code_for_inplace_impl(i):
+                if (
+                    self.generate_general_infer_spmd
+                    and self.outputs['names'][i] in self.view_map
+                ):
+                    continue
                 need_reshard = (
                     "true" if self.generate_general_infer_spmd else "false"
                 )
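
In effect, the inplace branch now splits into two cases: with general InferSPMD enabled, an inplace output that is also a view of its input (e.g. reshape's) is skipped entirely, since it shares the input's storage but may carry a different shape; every other inplace output still has its dist_attr set, with need_reshard reflecting whether a general InferSPMD rule inferred the placement.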
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/generator/dist_bw_api_gen.py

@@ -177,6 +177,9 @@
 RESHARD_VECTOR_OUTPUT_TEMPLATE = """
     ReshardKernelOutputToApiOutput(dev_ctx, shared_dist_out, {});"""
 
+NONEED_TO_RESHARD_OUTPUT_TEMPLATE = """
+    // API `{}` does not need to reshard output."""
+
 
 class DistBackwardAPI(DistForwardAPI, BackwardAPI):
     def __init__(self, backward_item_yaml):
@@ -344,6 +347,9 @@ def generate_reshard_output_code(self):
                 f"{self.api} : Output error: the output should not be empty."
             )
         else:
+            reshard_output_code += NONEED_TO_RESHARD_OUTPUT_TEMPLATE.format(
+                self.kernel['func'][0]
+            )
             # do nothing
             pass
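
The new template only injects a marker comment into the generated C++ API body, so the previously silent branch becomes visible in the output. For a kernel named, say, matmul_grad (a hypothetical example), the generated source would gain the single line:

    // API `matmul_grad` does not need to reshard output.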
@@ -132,6 +132,7 @@ void PToSReshardFunctionCrossMesh::Eval(DeviceContext* dev_ctx,
     p_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);
   } else {
     SetDistProps(&tmp_result, in.dims(), in_dist_attr_shard);
+    SetValue(&tmp_result, in.value());
   }
 
   SameStatusReshardFunction same_status_func;
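
On ranks outside the source process mesh, tmp_result previously received its dims and dist_attr but never an underlying allocation, so the SameStatusReshardFunction that follows consumed an empty tensor. A minimal self-contained sketch of that failure mode, with illustrative stand-in types rather than Paddle's real DistTensor internals:

    #include <cassert>
    #include <memory>
    #include <vector>

    // Illustrative stand-ins for DistTensor internals (not Paddle's types).
    struct DenseValue {
      std::shared_ptr<std::vector<float>> holder;  // underlying allocation
      bool initialized() const { return holder != nullptr; }
    };

    struct DistTensorSketch {
      std::vector<int64_t> dims;
      DenseValue value;
    };

    // SetDistProps-like step: fixes shape/placement metadata only.
    void SetDistPropsSketch(DistTensorSketch* t, std::vector<int64_t> dims) {
      t->dims = std::move(dims);
    }

    // SetValue-like step: shares the input's storage with the temporary.
    void SetValueSketch(DistTensorSketch* t, const DenseValue& in) {
      t->value = in;  // shallow copy: shares the holder
    }

    int main() {
      DenseValue in{std::make_shared<std::vector<float>>(8, 1.0f)};
      DistTensorSketch tmp_result;
      SetDistPropsSketch(&tmp_result, {8});
      // Without the SetValue step, tmp_result.value.initialized() stays
      // false and the follow-up same-status reshard reads an empty tensor.
      SetValueSketch(&tmp_result, in);
      assert(tmp_result.value.initialized());
      return 0;
    }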
@@ -129,8 +129,8 @@ void RToSReshardFunctionCrossMesh::Eval(phi::DeviceContext* dev_ctx,
     r_to_s_func.Eval(dev_ctx, in, in_dist_attr_shard, &tmp_result);
   } else {
     SetDistProps(&tmp_result, in.dims(), in_dist_attr_shard);
+    SetValue(&tmp_result, in.value());
   }
-
   SameStatusReshardFunction same_status_func;
   PADDLE_ENFORCE(
       same_status_func.IsSuitable(tmp_result, out_dist_attr),
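
RToSReshardFunctionCrossMesh::Eval had the same gap on ranks outside the source mesh, and the identical one-line SetValue addition closes it; the sketch above applies unchanged.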
@@ -128,11 +128,9 @@ bool SToRReshardFunctionCrossMesh::IsSuitable(
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& out_process_mesh = out_dist_attr.process_mesh();
 
-  int64_t cur_global_rank = GetCurGlobalRank();
-  if (in_process_mesh.contains(cur_global_rank)) {
-    int split_axis =
-        GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
-    int64_t num_of_process = in_process_mesh.size();
+  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
+  int64_t num_of_process = in_process_mesh.size();
+  if (in.initialized()) {
     RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
                                   num_of_process ==
                               in.dims()[static_cast<int>(split_axis)]);
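
The suitability check now keys on whether the input tensor actually holds data (in.initialized()) instead of whether the current rank sits in the source mesh; the invariant it enforces is unchanged: along the split axis, the local shard extent times the process count must reproduce the global extent. A standalone illustration, with names of our own choosing:

    #include <cstdint>
    #include <iostream>

    // S-to-R precondition: the shard axis must be evenly reconstructible.
    bool ShardSizesConsistent(int64_t local_dim, int64_t global_dim,
                              int64_t num_of_process) {
      return local_dim * num_of_process == global_dim;
    }

    int main() {
      // Example: global dim 1024 split over 4 ranks -> local dim 256.
      std::cout << ShardSizesConsistent(256, 1024, 4) << "\n";  // prints 1
      // An uneven split (e.g. local dim 255) fails the check.
      std::cout << ShardSizesConsistent(255, 1024, 4) << "\n";  // prints 0
      return 0;
    }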
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/spmd_rules/layer_norm.cc

@@ -384,7 +384,7 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
   for (int i = 0; i < begin_norm_axis; ++i) {
     auto mapping = dim_mapping[i];
     if (mapping != -1) {
-      partial_on_dims.push_back(i);
+      partial_on_dims.push_back(mapping);
     }
   }
   scale_grad_dist_attr.set_partial_status(partial_on_dims);
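
set_partial_status expects mesh-dimension indices, but the old loop pushed the tensor axis i; the two only coincide when tensor axis k happens to map to mesh axis k. A standalone worked example of the corrected collection, mirroring the fixed loop with minimal types of our own:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // dims_mapping of x: tensor axis -> mesh axis (-1 means not sharded).
      // Here axis 0 is sharded over mesh dim 1 and axis 1 over mesh dim 0.
      std::vector<int64_t> dim_mapping = {1, 0, -1};
      int begin_norm_axis = 2;  // axes [0, begin_norm_axis) are batch-like

      // scale_grad is summed over the batch-like axes, so it becomes
      // partial on every mesh dim those axes are sharded over: the mapping
      // values, not the tensor-axis indices.
      std::vector<int64_t> partial_on_dims;
      for (int i = 0; i < begin_norm_axis; ++i) {
        auto mapping = dim_mapping[i];
        if (mapping != -1) {
          partial_on_dims.push_back(mapping);  // mesh axis (fixed behavior)
        }
      }
      for (auto d : partial_on_dims) std::cout << d << " ";  // prints: 1 0
      std::cout << "\n";
      return 0;
    }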
3 changes: 1 addition & 2 deletions paddle/phi/infermeta/spmd_rules/transpose.cc

@@ -90,8 +90,7 @@ SpmdInfo TransposeInferSpmd(const DistMetaTensor& x,
   // input dist_attr.
   TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
   out_dist_attr.set_dims_mapping(out_dims_mapping);
-
-  // Step3 Handle Partial (TODO)
+  out_dist_attr.set_partial_status(x_dist_attr_src.partial_status());
 
   VLOG(4) << "TransposeInferSpmd:";
   VLOG(4) << "Input: shape: [" << str_join(x_shape) << "] "
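
Partial status records the mesh dims on which a tensor still awaits a reduction; transpose permutes tensor axes but touches no mesh dims, so the input's partial status can flow to the output unchanged, which is exactly what replaces the old TODO. A standalone sketch of the two propagation rules, using minimal types of our own:

    #include <cstdint>
    #include <iostream>
    #include <set>
    #include <vector>

    int main() {
      // Input: dims_mapping (tensor axis -> mesh axis) and partial dims.
      std::vector<int64_t> x_dims_mapping = {0, -1, 1};
      std::set<int64_t> x_partial_dims = {2};  // pending reduce, mesh dim 2
      std::vector<int> perm = {2, 0, 1};       // transpose permutation

      // dims_mapping is permuted along with the tensor axes...
      std::vector<int64_t> out_dims_mapping(perm.size());
      for (size_t i = 0; i < perm.size(); ++i) {
        out_dims_mapping[i] = x_dims_mapping[perm[i]];
      }
      // ...while partial status, indexed by mesh dims, flows through as-is.
      std::set<int64_t> out_partial_dims = x_partial_dims;

      for (auto m : out_dims_mapping) std::cout << m << " ";  // 1 0 -1
      std::cout << "| partial on mesh dim: ";
      for (auto d : out_partial_dims) std::cout << d;         // 2
      std::cout << "\n";
      return 0;
    }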
