Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN]fix symbol arg binding in bc optimize #70193

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

namespace {
using cinn::dialect::ir::details::GetBlockOutsideInput;
using cinn::dialect::ir::details::OpLoweringGroup;
using cinn::dialect::ir::details::OpLoweringGroupPtr;
using cinn::hlir::framework::pir::GetBlockOutsideInput;

bool IsComplicatedDimExpr(const symbol::DimExpr& dim_expr) {
auto lambdas = common::Overloaded{
Expand Down Expand Up @@ -136,7 +136,8 @@ CollectSubstituteDimExprMap(
[&](const symbol::DimExpr& dim_expr) {
if (dim_expr.isa<std::string>()) return false;
for (const auto& symbol : symbol::CollectDimExprSymbols(dim_expr)) {
if (new_symbol_set.count(symbol) == 0) {
if (new_symbol_set.count(symbol) == 0 &&
base_dim_expr_set.count(symbol) == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

什么样的case会同时命中这两个条件?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段逻辑是想说如果被替换之后出现了new_symbol_set或者base_dim_expr_set里都没有的符号,那么它就是一个不能被子集替换的,很多case都有,比如BC(S0, S1),new_symbol_set里没有,输入符号里也没有S0和S1,那么它就can't BeRepresentedBySubset,需要被替换成新符号

return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/pass/pass_registry.h"

Expand Down Expand Up @@ -62,7 +63,8 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
virtual pir::Operation* ProcessGroup(
const OpLoweringGroupPtr& group,
pir::PatternRewriter& rewriter) const { // NOLINT
auto group_inputs = GetBlockOutsideInput(group->ops());
auto group_inputs =
cinn::hlir::framework::pir::GetBlockOutsideInput(group->ops());
// compile group to jit_kernel_op
std::vector<pir::Type> output_types;
const auto& group_output_values = group->output_values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,6 @@ using cinn::hlir::framework::PirCompiler;
using cinn::hlir::framework::pir::CINNKernelInfo;
using cinn::hlir::framework::pir::CompatibleInfo;

std::vector<pir::Value> GetBlockOutsideInput(
const std::vector<pir::Operation*>& op_list) {
std::vector<pir::Value> vec_res;
std::unordered_set<::pir::Value> block_inner_output;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_results(); ++i) {
block_inner_output.insert(op_list[k]->result(i));
}
}

std::unordered_set<::pir::Value> insert_value;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_operands(); ++i) {
if (!block_inner_output.count(op_list[k]->operand_source(i)) &&
!insert_value.count(op_list[k]->operand_source(i))) {
vec_res.push_back(op_list[k]->operand_source(i));
insert_value.insert(op_list[k]->operand_source(i));
}
}
}
return vec_res;
}

std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
const OpLoweringGroupPtr& group) {
const auto& CreateKernelInfo = [&]() -> CINNKernelInfo {
Expand Down
82 changes: 49 additions & 33 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,52 +101,68 @@ void UnifyBroadcastGroupFuncArgs(
std::vector<GroupCompilationContext>* contexts,
pir::OpLoweringGroupPtr origin_group,
std::unordered_map<int, ir::Var>* symbolic_shape_var_index) {
std::unordered_map<ir::Var, pir::CINNKernelInfo::SymbolArgBindInfo>
new_args_map;
std::vector<ir::Argument> new_args_vec;
int total_args_num = 0;
int cur_arg_idx = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cur_arg_idx加点注释说明

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,直接改造其构造方法并放入AddSymbolArgs函数中,语义已经明确


const auto& AddTensorArgs = [&](GroupCompilationContext& context) {
const auto& func_args = context.lowered_funcs_[0]->args;
const auto& origin_symbol_args = context.group_->symbol_args_map();
const auto& AddTensorArgs = [&]() {
const auto& func_args = (*contexts)[0].lowered_funcs_[0]->args;
for (size_t arg_idx = 0; arg_idx < func_args.size(); ++arg_idx) {
if (func_args[arg_idx].is_var()) {
new_args_map[func_args[arg_idx].var_arg()] =
origin_symbol_args.at(arg_idx);
} else {
new_args_vec.emplace_back(func_args[arg_idx]);
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么是跳过var?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为当前func_arg是按照(tensor_arg1, tensor_arg2, .. tensor_argn, ... var_arg1, var_arg2, .. var_argn)组织的,这一步只收集TensorArg,var_arg(也就是SymbolArg)后面根据原始shape_or_data生成

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那是不写成tensor_arg的判断代码更容易理解些?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
}
for (ir::LoweredFunc& func : context.lowered_funcs_) {
func->args = new_args_vec;
new_args_vec.emplace_back(func_args[arg_idx]);
cur_arg_idx++;
}
};
for (size_t i = 0; i < contexts->size(); ++i) {
AddTensorArgs((*contexts)[i]);
if (i == 0) total_args_num += new_args_vec.size();
new_args_vec.clear();
}

origin_group->mut_symbol_args_map().clear();
const auto& new_symbol_args_vec = [&]() -> std::vector<ir::Argument> {
std::vector<ir::Argument> res;
for (const auto& [arg, idx_info] : new_args_map) {
symbolic_shape_var_index->insert({total_args_num, arg});
origin_group->mut_symbol_args_map()[total_args_num++] = idx_info;
res.emplace_back(ir::Argument{arg});
}
return res;
}();
std::unordered_set<std::string> symbol_args_set;
const auto& AddSymbolArgs = [&](::pir::Value input, const int& input_idx) {
enum ArgType { Dim, Value };
const auto& AddSymbolArgFromDimExprVec =
[&](ArgType arg_type, const std::vector<symbol::DimExpr>& expr_vec) {
int vec_size = expr_vec.size();
for (int idx = 0; idx < vec_size; idx++) {
if (expr_vec[idx].isa<std::string>()) {
const std::string& symbol_name =
expr_vec[idx].dyn_cast<std::string>();
if (symbol_args_set.count(symbol_name) != 0) {
continue;
}
symbol_args_set.insert(symbol_name);
const auto& arg = ir::Var(symbol_name, cinn::common::Int(64));
new_args_vec.emplace_back(ir::Argument{arg});
symbolic_shape_var_index->insert({cur_arg_idx, arg});
if (arg_type == Dim) {
origin_group->mut_symbol_args_map()[cur_arg_idx++] =
pir::CINNKernelInfo::ArgDimIdx{input_idx, idx};
} else {
origin_group->mut_symbol_args_map()[cur_arg_idx++] =
pir::CINNKernelInfo::ArgValueIdx{input_idx, idx};
}
}
}
};
const auto& shape_or_data = origin_group->GetShapeOrDataExprs(input);
// Add dim symbol args
AddSymbolArgFromDimExprVec(ArgType::Dim, shape_or_data.shape());
// Add value symbol args
if (shape_or_data.data())
AddSymbolArgFromDimExprVec(ArgType::Value, shape_or_data.data().value());
};

const auto& AddUnifiedSymbolArgs = [&](GroupCompilationContext& context) {
const auto& UpdateAllFuncArgs = [&](GroupCompilationContext& context) {
for (ir::LoweredFunc& func : context.lowered_funcs_) {
func->args.insert(func->args.end(),
new_symbol_args_vec.begin(),
new_symbol_args_vec.end());
func->args = new_args_vec;
}
};

AddTensorArgs();
origin_group->mut_symbol_args_map().clear();
const auto& group_inputs = pir::GetBlockOutsideInput(origin_group->ops());
for (size_t input_idx = 0; input_idx < group_inputs.size(); ++input_idx)
AddSymbolArgs(group_inputs[input_idx], input_idx);
for (int i = 0; i < contexts->size(); ++i) {
AddUnifiedSymbolArgs((*contexts)[i]);
UpdateAllFuncArgs((*contexts)[i]);
}
}

Expand Down
23 changes: 23 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,29 @@ std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
return broadcast_axes;
}

std::vector<::pir::Value> GetBlockOutsideInput(
const std::vector<::pir::Operation*>& op_list) {
std::vector<::pir::Value> vec_res;
std::unordered_set<::pir::Value> block_inner_output;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_results(); ++i) {
block_inner_output.insert(op_list[k]->result(i));
}
}

std::unordered_set<::pir::Value> insert_value;
for (size_t k = 0; k < op_list.size(); ++k) {
for (size_t i = 0; i < op_list[k]->num_operands(); ++i) {
if (!block_inner_output.count(op_list[k]->operand_source(i)) &&
!insert_value.count(op_list[k]->operand_source(i))) {
vec_res.push_back(op_list[k]->operand_source(i));
insert_value.insert(op_list[k]->operand_source(i));
}
}
}
return vec_res;
}

} // namespace pir
} // namespace framework
} // namespace hlir
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/hlir/framework/pir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ struct CompatibleInfo {
std::vector<int64_t> GetBroadcastAxis(const ::common::DDim& in_shape,
const std::vector<int64_t>& out_shape);

std::vector<::pir::Value> GetBlockOutsideInput(
const std::vector<::pir::Operation*>& op_list);

class PrettyNamer {
public:
const std::string& GetOrNew(::pir::Value hash_key,
Expand Down