-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN]fix symbol arg binding in bc optimize #70193
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,52 +101,68 @@ void UnifyBroadcastGroupFuncArgs( | |
std::vector<GroupCompilationContext>* contexts, | ||
pir::OpLoweringGroupPtr origin_group, | ||
std::unordered_map<int, ir::Var>* symbolic_shape_var_index) { | ||
std::unordered_map<ir::Var, pir::CINNKernelInfo::SymbolArgBindInfo> | ||
new_args_map; | ||
std::vector<ir::Argument> new_args_vec; | ||
int total_args_num = 0; | ||
int cur_arg_idx = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. cur_arg_idx加点注释说明 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done,直接改造其构造方法并放入AddSymbolArgs函数中,语义已经明确 |
||
|
||
const auto& AddTensorArgs = [&](GroupCompilationContext& context) { | ||
const auto& func_args = context.lowered_funcs_[0]->args; | ||
const auto& origin_symbol_args = context.group_->symbol_args_map(); | ||
const auto& AddTensorArgs = [&]() { | ||
const auto& func_args = (*contexts)[0].lowered_funcs_[0]->args; | ||
for (size_t arg_idx = 0; arg_idx < func_args.size(); ++arg_idx) { | ||
if (func_args[arg_idx].is_var()) { | ||
new_args_map[func_args[arg_idx].var_arg()] = | ||
origin_symbol_args.at(arg_idx); | ||
} else { | ||
new_args_vec.emplace_back(func_args[arg_idx]); | ||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 为什么是跳过var? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 因为当前func_arg是按照(tensor_arg1, tensor_arg2, .. tensor_argn, ... var_arg1, var_arg2, .. var_argn)组织的,这一步只收集TensorArg,var_arg(也就是SymbolArg)后面根据原始shape_or_data生成 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 那是不是写成tensor_arg的判断代码更容易理解些? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
} | ||
for (ir::LoweredFunc& func : context.lowered_funcs_) { | ||
func->args = new_args_vec; | ||
new_args_vec.emplace_back(func_args[arg_idx]); | ||
cur_arg_idx++; | ||
} | ||
}; | ||
for (size_t i = 0; i < contexts->size(); ++i) { | ||
AddTensorArgs((*contexts)[i]); | ||
if (i == 0) total_args_num += new_args_vec.size(); | ||
new_args_vec.clear(); | ||
} | ||
|
||
origin_group->mut_symbol_args_map().clear(); | ||
const auto& new_symbol_args_vec = [&]() -> std::vector<ir::Argument> { | ||
std::vector<ir::Argument> res; | ||
for (const auto& [arg, idx_info] : new_args_map) { | ||
symbolic_shape_var_index->insert({total_args_num, arg}); | ||
origin_group->mut_symbol_args_map()[total_args_num++] = idx_info; | ||
res.emplace_back(ir::Argument{arg}); | ||
} | ||
return res; | ||
}(); | ||
std::unordered_set<std::string> symbol_args_set; | ||
const auto& AddSymbolArgs = [&](::pir::Value input, const int& input_idx) { | ||
enum ArgType { Dim, Value }; | ||
const auto& AddSymbolArgFromDimExprVec = | ||
[&](ArgType arg_type, const std::vector<symbol::DimExpr>& expr_vec) { | ||
int vec_size = expr_vec.size(); | ||
for (int idx = 0; idx < vec_size; idx++) { | ||
if (expr_vec[idx].isa<std::string>()) { | ||
const std::string& symbol_name = | ||
expr_vec[idx].dyn_cast<std::string>(); | ||
if (symbol_args_set.count(symbol_name) != 0) { | ||
continue; | ||
} | ||
symbol_args_set.insert(symbol_name); | ||
const auto& arg = ir::Var(symbol_name, cinn::common::Int(64)); | ||
new_args_vec.emplace_back(ir::Argument{arg}); | ||
symbolic_shape_var_index->insert({cur_arg_idx, arg}); | ||
if (arg_type == Dim) { | ||
origin_group->mut_symbol_args_map()[cur_arg_idx++] = | ||
pir::CINNKernelInfo::ArgDimIdx{input_idx, idx}; | ||
} else { | ||
origin_group->mut_symbol_args_map()[cur_arg_idx++] = | ||
pir::CINNKernelInfo::ArgValueIdx{input_idx, idx}; | ||
} | ||
} | ||
} | ||
}; | ||
const auto& shape_or_data = origin_group->GetShapeOrDataExprs(input); | ||
// Add dim symbol args | ||
AddSymbolArgFromDimExprVec(ArgType::Dim, shape_or_data.shape()); | ||
// Add value symbol args | ||
if (shape_or_data.data()) | ||
AddSymbolArgFromDimExprVec(ArgType::Value, shape_or_data.data().value()); | ||
}; | ||
|
||
const auto& AddUnifiedSymbolArgs = [&](GroupCompilationContext& context) { | ||
const auto& UpdateAllFuncArgs = [&](GroupCompilationContext& context) { | ||
for (ir::LoweredFunc& func : context.lowered_funcs_) { | ||
func->args.insert(func->args.end(), | ||
new_symbol_args_vec.begin(), | ||
new_symbol_args_vec.end()); | ||
func->args = new_args_vec; | ||
} | ||
}; | ||
|
||
AddTensorArgs(); | ||
origin_group->mut_symbol_args_map().clear(); | ||
const auto& group_inputs = pir::GetBlockOutsideInput(origin_group->ops()); | ||
for (size_t input_idx = 0; input_idx < group_inputs.size(); ++input_idx) | ||
AddSymbolArgs(group_inputs[input_idx], input_idx); | ||
for (int i = 0; i < contexts->size(); ++i) { | ||
AddUnifiedSymbolArgs((*contexts)[i]); | ||
UpdateAllFuncArgs((*contexts)[i]); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
什么样的case会同时命中这两个条件?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段逻辑是想说如果被替换之后出现了new_symbol_set或者base_dim_expr_set里都没有的符号,那么它就是一个不能被子集替换的,很多case都有,比如BC(S0, S1),new_symbol_set里没有,输入符号里也没有S0和S1,那么它就can't BeRepresentedBySubset,需要被替换成新符号