[Dist Dialect] Fix the bug of PIR infer_spmd when op inputs are inconsistent (PaddlePaddle#67052)
pkuzyc authored Aug 7, 2024
1 parent e6ebdc4 commit 9f7c60d
Showing 5 changed files with 127 additions and 12 deletions.
13 changes: 13 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc
@@ -297,6 +297,19 @@ pir::Attribute CvtToPirAttr(const phi::distributed::ArgDistAttr& dist_attr) {
}
}

pir::Attribute GetTensorDistAttrArray(pir::VectorType x_vec_type) {
  std::vector<pir::Attribute> x_arr_attr;
  for (size_t i = 0; i < x_vec_type.size(); i++) {
    auto dist_type = x_vec_type[i].dyn_cast<DistTypeInterface>();
    if (!dist_type) {
      x_arr_attr.push_back(nullptr);
    } else {
      x_arr_attr.push_back(dist_type.tensor_dist_attr());
    }
  }
  return pir::ArrayAttribute::get(pir::IrContext::Instance(), x_arr_attr);
}

pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
                                        ProcessMeshAttribute mesh) {
  auto ctx = pir::IrContext::Instance();
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_tools.h
@@ -41,6 +41,10 @@ std::vector<phi::distributed::DistMetaTensor> CvtToDistMetaTensor(
pir::VectorType type);
pir::Attribute CvtToPirAttr(const phi::distributed::ArgDistAttr& dist_attr);

// When the input is a vector of Value, get the dist attributes of all its
// elements and convert them into a single `pir::Attribute`.
pir::Attribute GetTensorDistAttrArray(pir::VectorType x_vec_type);

pir::Attribute CreateReplicatedDistAttr(pir::Type prim_type,
ProcessMeshAttribute mesh);

55 changes: 43 additions & 12 deletions paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py
@@ -793,7 +793,8 @@ def GenDistBranch(args, op_info):
{}
CvtAllInputsToDist(input_values, op_mesh);
auto ctx = pir::IrContext::Instance();
std::vector<pir::Attribute> dist_operand_attrs, dist_result_attrs;"""
std::vector<pir::Attribute> dist_operand_attrs, dist_result_attrs;
"""

extra_call = ""
for name in op_info.spmd_params:
@@ -810,8 +811,10 @@ def GenDistBranch(args, op_info):
extra_call = "CopyLeafOpToMesh(scale_, op_mesh);"
dist_branch_str = TEMPLATE.format(merge_input_meshes, extra_call)
infer_spmd_args_list = []
spmd_input_value_num = 0
map_input_idx = {}
# Prepare inputs_meta_tensor & attributes for infer spmd
for name in op_info.spmd_params:
for spmd_arg_idx, name in enumerate(op_info.spmd_params):
# is input
if name in op_info.input_name_list:
input_index = op_info.input_name_list.index(name)
@@ -838,6 +841,8 @@ def GenDistBranch(args, op_info):
auto dist_meta_{name} = CvtToDistMetaTensor({name}_.type().dyn_cast<DistDenseTensorType>());"""
dist_branch_str += TEMPLATE.format(name=name)
infer_spmd_args_list.append("dist_meta_" + name)
spmd_input_value_num += 1
map_input_idx[input_index] = spmd_arg_idx
else:
attr_index = op_info.attribute_name_list.index(name)
param_type = op_info.attribute_gen_arg_type_list[attr_index]
@@ -856,26 +861,52 @@ def GenDistBranch(args, op_info):
DebugInfoForInferSpmd("{op_name}", spmd_info);
PADDLE_ENFORCE_EQ(spmd_info.first.size(), {input_size}u, common::errors::Unavailable(
"Size of spmd_info.first for op[{op_name}] is unexpected."));
for(auto& arg_dist : spmd_info.first) {{
dist_operand_attrs.push_back(CvtToPirAttr(arg_dist));
}}
"""
dist_branch_str += TEMPLATE.format(
spmd_func=spmd_rule_func,
args=', '.join(infer_spmd_args_list),
input_size=len(op_info.input_name_list),
input_size=spmd_input_value_num,
op_name=op_info.class_name,
)

if spmd_input_value_num == len(op_info.input_name_list):
TEMPLATE = """
for(auto& arg_dist : spmd_info.first) {
dist_operand_attrs.push_back(CvtToPirAttr(arg_dist));
}
"""
dist_branch_str += TEMPLATE
else:
for i in range(len(op_info.input_name_list)):
if i in map_input_idx:
spmd_idx = map_input_idx[i]
TEMPLATE = """
dist_operand_attrs.push_back(CvtToPirAttr(spmd_info.first[{idx}]));"""
dist_branch_str += TEMPLATE.format(idx=spmd_idx)
# vector<Tensor> input
elif "pir::VectorType" in op_info.input_type_list[i]:
TEMPLATE = """
dist_operand_attrs.push_back(GetTensorDistAttrArray({name}));"""
dist_branch_str += TEMPLATE.format(
name=op_info.input_name_list[i]
)
# Tensor input
else:
TEMPLATE = """
dist_operand_attrs.push_back({name}.type().dyn_cast<DistTypeInterface>().tensor_dist_attr());"""
dist_branch_str += TEMPLATE.format(
name=op_info.input_name_list[i]
)
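
The three-way dispatch above is the heart of the fix: the old code assumed spmd_info.first holds one entry per op input, which breaks when the SPMD rule consumes only a subset of the inputs (as appears to be the case for stack_grad in the new test below). The following stand-alone sketch — plain Python, not part of the commit, with a hypothetical op shaped like stack_grad — makes the three cases visible:

# Stand-alone sketch (not the generator itself) of the dispatch the generated
# code now performs. Hypothetical op: a vector<Tensor> input `x` plus a Tensor
# input `out_grad`, where only `out_grad` is an argument of the SPMD rule.
input_name_list = ["x", "out_grad"]
input_type_list = ["pir::VectorType", "paddle::dialect::DenseTensorType"]
spmd_params = ["out_grad", "axis"]  # "axis" is an attribute, not an input

# Same bookkeeping as above: op-input index -> index of the rule argument.
map_input_idx = {
    input_name_list.index(name): spmd_idx
    for spmd_idx, name in enumerate(spmd_params)
    if name in input_name_list
}

for i, name in enumerate(input_name_list):
    if i in map_input_idx:
        # Covered by the rule: reuse the dist attr it inferred.
        print(f"{name}: CvtToPirAttr(spmd_info.first[{map_input_idx[i]}])")
    elif "pir::VectorType" in input_type_list[i]:
        # Not covered, vector input: collect per-element dist attrs.
        print(f"{name}: GetTensorDistAttrArray({name})")
    else:
        # Not covered, single tensor: keep the value's existing dist attr.
        print(f"{name}: {name}.type().dyn_cast<DistTypeInterface>().tensor_dist_attr()")

Running it prints GetTensorDistAttrArray(x) for the uncovered vector input and CvtToPirAttr(spmd_info.first[0]) for out_grad; when a rule covers every input, map_input_idx contains every index and the first branch reduces to the old behavior of converting spmd_info.first in order.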

if len(op_info.mutable_attribute_name_list) > 0:
TEMPLATE = """
for(int i = {input_size}; i < {all_input_size}; ++i) {{
if(auto dist_type = input_values[i].type().dyn_cast<DistTypeInterface>()) {{
dist_operand_attrs.push_back(dist_type.tensor_dist_attr());
}}
else {{
dist_operand_attrs.push_back(nullptr);
}}
if(auto dist_type = input_values[i].type().dyn_cast<DistTypeInterface>()) {{
dist_operand_attrs.push_back(dist_type.tensor_dist_attr());
}}
else {{
dist_operand_attrs.push_back(nullptr);
}}
}}
"""
dist_branch_str += TEMPLATE.format(
2 changes: 2 additions & 0 deletions test/auto_parallel/pir/CMakeLists.txt
@@ -36,4 +36,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_eliminate_transpose_pass MODULES test_eliminate_transpose_pass ENVS
FLAGS_enable_pir_in_executor=1)
py_test_modules(test_moe_api MODULES test_moe_api ENVS FLAGS_enable_pir_api=1)
py_test_modules(test_pir_stack_grad_spmd_rule MODULES
test_stack_grad_spmd_rule ENVS FLAGS_enable_pir_api=1)
endif()
65 changes: 65 additions & 0 deletions test/auto_parallel/pir/test_stack_grad_spmd_rule.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.mix_to_dist_pass import (
    apply_mix2dist_pass,
)

paddle.enable_static()


class TestStackGradSpmdRule(unittest.TestCase):
    def test_build_replicated_program(self):
        main_program = paddle.base.Program()
        with paddle.base.program_guard(main_program):
            mesh = dist.ProcessMesh([0, 1])
            x0 = paddle.static.data(name='x0', shape=[64, 36])
            x1 = paddle.static.data(name='x1', shape=[64, 36])
            x0.stop_gradient = False
            x1.stop_gradient = False
            y = paddle.static.data(name='y', shape=[2, 64, 36])
            dist_x0 = dist.shard_tensor(x0, mesh, [dist.Shard(0)])
            dist_x1 = dist.shard_tensor(x1, mesh, [dist.Shard(0)])
            dist_out = paddle.stack([dist_x0, dist_x1], axis=0)
            loss = paddle.mean(dist_out - y)

        dist_program = main_program.clone()
        apply_mix2dist_pass(dist_program)
        dist_loss_value = dist_program.global_block().ops[-1].result(0)

        with paddle.static.program_guard(dist_program):
            params_grads = paddle.autograd.ir_backward.append_backward(
                dist_loss_value
            )

        stack_grad_op = [
            op
            for op in dist_program.global_block().ops
            if op.name() == "pd_op.stack_grad"
        ]
        stack_grad_op = stack_grad_op[0]
        out_grad = stack_grad_op.operand_source(1)
        x0_grad = dist_program.global_block().ops[-1].result(0)
        x1_grad = dist_program.global_block().ops[-1].result(1)
        self.assertEqual(out_grad.dist_attr().dims_mapping, [-1, 0, -1])
        self.assertEqual(x0_grad.dist_attr().dims_mapping, [0, -1])
        self.assertEqual(x1_grad.dist_attr().dims_mapping, [0, -1])


if __name__ == "__main__":
    unittest.main()
