Skip to content

Commit

Permalink
fix different fn_name (PaddlePaddle#64771)
Browse files Browse the repository at this point in the history
  • Loading branch information
chen2016013 authored May 31, 2024
1 parent d2117f9 commit 810fd5b
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 40 deletions.
37 changes: 20 additions & 17 deletions paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,26 @@ OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) {
: group_op_kind;
}
}

auto group = std::make_shared<OpLoweringGroup>(ops);

if (fusion_op.attributes().count("group_info")) {
auto attr = fusion_op.attribute("group_info")
.dyn_cast<cinn::dialect::GroupInfoAttribute>()
.data();

group_op_kind =
static_cast<int>(attr.op_pattern_kind) > static_cast<int>(group_op_kind)
? attr.op_pattern_kind
: group_op_kind;
group->set_loop_ranges(attr.loop_ranges);
group->set_loop_ranges_expr(attr.loop_ranges_expr);
group->set_reduce_axis(attr.reduce_axis);
group->set_alignment_schedule_info(attr.alignment_schedule_info);
}
PADDLE_ENFORCE_GT(fusion_op.attributes().count("group_info"),
0UL,
phi::errors::InvalidArgument(
"fusion_op should have group_info attribute."));

const auto attr = fusion_op.attribute("group_info")
.dyn_cast<cinn::dialect::GroupInfoAttribute>()
.data();

const auto& fn_name = attr.fn_name;
auto group = std::make_shared<OpLoweringGroup>(ops, fn_name);

group_op_kind =
static_cast<int>(attr.op_pattern_kind) > static_cast<int>(group_op_kind)
? attr.op_pattern_kind
: group_op_kind;
group->set_loop_ranges(attr.loop_ranges);
group->set_loop_ranges_expr(attr.loop_ranges_expr);
group->set_reduce_axis(attr.reduce_axis);
group->set_alignment_schedule_info(attr.alignment_schedule_info);
group->set_op_pattern_kind(group_op_kind);

// Rebuild output_ops and input_ops of the group
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/op_lowering_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(
ops_mapper[op] = new_op;
}

const auto new_fn_name = this->fn_name_ + "_cloned";
// Construct Base information for new Group
auto new_group = std::make_shared<OpLoweringGroup>(new_ops);
auto new_group = std::make_shared<OpLoweringGroup>(new_ops, new_fn_name);
for (auto* op : this->output_ops_) {
new_group->output_ops_.insert(ops_mapper.at(op));
}
Expand Down
14 changes: 6 additions & 8 deletions paddle/cinn/hlir/framework/pir/op_lowering_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ class OpLoweringGroup {
OpLoweringGroup(const OpLoweringGroup&) = delete;
OpLoweringGroup(OpLoweringGroup&&) = delete;

explicit OpLoweringGroup(const std::vector<::pir::Operation*>& group_ops)
: ops_(group_ops) {
fn_name_ = CompatibleInfo::GroupOpsName(ops_);
}
explicit OpLoweringGroup(const std::vector<::pir::Operation*>& group_ops,
const std::string& fn_name)
: ops_(group_ops), fn_name_(fn_name) {}

explicit OpLoweringGroup(std::initializer_list<::pir::Operation*> group_ops)
: ops_(group_ops) {
fn_name_ = CompatibleInfo::GroupOpsName(ops_);
}
explicit OpLoweringGroup(std::initializer_list<::pir::Operation*> group_ops,
const std::string& fn_name)
: ops_(group_ops), fn_name_(fn_name) {}

const std::string& FuncName() const { return this->fn_name_; }
::pir::Block* GetParentBlock() const;
Expand Down
7 changes: 6 additions & 1 deletion test/cpp/pir/cinn/compilation_task_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
Expand All @@ -34,6 +35,7 @@

PD_DECLARE_bool(cinn_bucket_compile);

using cinn::hlir::framework::pir::CompatibleInfo;
using cinn::hlir::framework::pir::OpLoweringGroup;
using cinn::hlir::framework::pir::OpLoweringGroupPtr;

Expand All @@ -50,8 +52,11 @@ ProgramInfo BuildProgram(std::vector<int64_t> input_shape) {
input_shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());

std::vector<OpLoweringGroupPtr> groups;
const std::string fn_name = CompatibleInfo::GroupOpsName(
std::initializer_list<::pir::Operation*>({full_op_x.operation()}));
groups.emplace_back(std::make_shared<OpLoweringGroup>(
std::initializer_list<::pir::Operation*>({full_op_x.operation()})));
std::initializer_list<::pir::Operation*>({full_op_x.operation()}),
fn_name));
groups.back()->mut_output_ops().insert(full_op_x.operation());

return {program, groups};
Expand Down
26 changes: 19 additions & 7 deletions test/cpp/pir/cinn/pir_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
Expand All @@ -38,6 +39,7 @@
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"

using cinn::hlir::framework::pir::CompatibleInfo;
using cinn::hlir::framework::pir::OpLoweringGroup;
using cinn::hlir::framework::pir::OpLoweringGroupPtr;

Expand Down Expand Up @@ -74,18 +76,26 @@ ProgramInfo BuildProgram() {
builder.Build<pir::YieldOp>(std::vector<pir::Value>{relu_op_y.result(0)});

std::vector<OpLoweringGroupPtr> groups;
const auto full_op_x_ops =
std::initializer_list<::pir::Operation*>({full_op_x.operation()});
groups.emplace_back(std::make_shared<OpLoweringGroup>(
std::initializer_list<::pir::Operation*>(
{full_op_x.operation()}))); // For coverage
full_op_x_ops,
CompatibleInfo::GroupOpsName(full_op_x_ops))); // For coverage
groups[0]->mut_output_values().push_back(groups[0]->ops().back()->result(0));

const auto full_op_y_ops =
std::initializer_list<::pir::Operation*>({full_op_x.operation()});
groups.emplace_back(std::make_shared<OpLoweringGroup>(
std::initializer_list<::pir::Operation*>({full_op_y.operation()})));
full_op_y_ops, CompatibleInfo::GroupOpsName(full_op_y_ops)));

groups[1]->mut_output_values().push_back(groups[1]->ops().back()->result(0));
groups.emplace_back(std::make_shared<OpLoweringGroup>(
const auto vector_ops =
std::vector<::pir::Operation*>({tan_op_x.operation(),
relu_op_x.operation(),
tan_op_y.operation(),
relu_op_y.operation()})));
relu_op_y.operation()});
groups.emplace_back(std::make_shared<OpLoweringGroup>(
vector_ops, CompatibleInfo::GroupOpsName(vector_ops)));
groups[2]->mut_output_values().push_back(groups[2]->ops().back()->result(0));

return {program, groups};
Expand Down Expand Up @@ -127,14 +137,16 @@ ProgramInfo BuildSoftmax() {
auto yield_op = builder.Build<pir::YieldOp>(std::vector<pir::Value>{divide});

std::vector<OpLoweringGroupPtr> groups;
groups.emplace_back(std::make_shared<OpLoweringGroup>(
const auto vector_ops =
std::initializer_list<::pir::Operation*>({max.defining_op(),
broadcast_1.defining_op(),
sub.defining_op(),
exp.defining_op(),
sum.defining_op(),
broadcast_2.defining_op(),
divide.defining_op()})));
divide.defining_op()});
groups.emplace_back(std::make_shared<OpLoweringGroup>(
vector_ops, CompatibleInfo::GroupOpsName(vector_ops)));
groups[0]->mut_output_values().push_back(groups[0]->ops().back()->result(0));
groups[0]->set_op_pattern_kind(cinn::hlir::framework::kReduction);

Expand Down
18 changes: 12 additions & 6 deletions test/cpp/pir/cinn/symbolic_lower_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/hlir/framework/pir_compiler.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
Expand All @@ -39,6 +40,7 @@

PD_DECLARE_bool(cinn_bucket_compile);

using cinn::hlir::framework::pir::CompatibleInfo;
using cinn::hlir::framework::pir::OpLoweringGroup;
using cinn::hlir::framework::pir::OpLoweringGroupPtr;

Expand Down Expand Up @@ -88,9 +90,11 @@ BuildGroupProgramForLowering() {
builder.Build<paddle::dialect::FetchOp>(group_op->result(0), "out", 0);

std::vector<OpLoweringGroupPtr> groups;
groups.emplace_back(
std::make_shared<OpLoweringGroup>(std::vector<::pir::Operation*>(
{exp.operation(), reshape.operation(), sub.operation()})));
groups.emplace_back(std::make_shared<OpLoweringGroup>(
std::vector<::pir::Operation*>(
{exp.operation(), reshape.operation(), sub.operation()}),
CompatibleInfo::GroupOpsName(std::vector<::pir::Operation*>(
{exp.operation(), reshape.operation(), sub.operation()}))));
groups[0]->mut_output_ops().insert(groups[0]->ops().back());
std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>
value_to_shape_data;
Expand Down Expand Up @@ -176,9 +180,11 @@ BuildBroadcastGroupProgramForLowering() {
builder.Build<paddle::dialect::FetchOp>(group_op->result(0), "out", 0);

std::vector<OpLoweringGroupPtr> groups;
groups.emplace_back(
std::make_shared<OpLoweringGroup>(std::vector<::pir::Operation*>(
{x_broadcast.operation(), sub.operation()})));
groups.emplace_back(std::make_shared<OpLoweringGroup>(
std::vector<::pir::Operation*>(
{x_broadcast.operation(), sub.operation()}),
CompatibleInfo::GroupOpsName(std::vector<::pir::Operation*>(
{x_broadcast.operation(), sub.operation()}))));
groups[0]->mut_output_ops().insert(groups[0]->ops().back());

std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>
Expand Down

0 comments on commit 810fd5b

Please sign in to comment.