Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN][Backend Pass Update No.7] Update merge_block_utils #70406

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions paddle/cinn/optim/merge_block_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,60 +17,68 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/stmt.h"
#include "paddle/common/enforce.h"

namespace cinn {
namespace optim {

namespace {
using ir::stmt::BlockRef;
using ir::stmt::For;
using ir::stmt::StmtRef;

struct ForInfoAnalyzer : public ir::IRMutator<Expr*> {
struct ForHash {
std::size_t operator()(const For& stmt) const {
return std::hash<const Object*>()(stmt.get());
}
};

struct ForInfoAnalyzer {
public:
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
void operator()(const For& for_stmt) { Visit(for_stmt); }

ForTreeNode BuildTreeNode(const ir::For* node) {
ForTreeNode BuildTreeNode(const For& node) {
ForTreeNode tree_node = {node, std::vector<ForTreeNode>()};
for (const auto for_node : for_to_children_[node]) {
tree_node.children.push_back(BuildTreeNode(for_node));
for (const For& stmt : for_to_children_[node]) {
tree_node.children.push_back(BuildTreeNode(stmt));
}
return tree_node;
}

ForTreeNode GetRootTreeNode() { return BuildTreeNode(root_node_); }

private:
void Visit(const ir::For* node, ir::Expr* expr) override {
auto old_last_node = last_node_;
if (last_node_ == nullptr) {
void Visit(const For& node) {
if (root_node_ == nullptr) {
root_node_ = node;
} else {
for_to_children_[last_node_].push_back(node);
}
last_node_ = const_cast<ir::For*>(node);
ir::IRMutator<>::Visit(node, expr);
last_node_ = old_last_node;
const BlockRef& body = node->body();
for (const StmtRef& stmt : body->stmts()) {
if (stmt.isa<For>()) {
for_to_children_[node].push_back(stmt.as<For>());
Visit(stmt.as<For>());
}
}
}

ir::For* last_node_ = nullptr;
const ir::For* root_node_ = nullptr;
std::unordered_map<const ir::For*, std::vector<const ir::For*>>
for_to_children_;
private:
For root_node_{nullptr};
std::unordered_map<For, std::vector<For>, ForHash> for_to_children_;
};

} // namespace

bool CanMergeBlocks(const ir::For* first,
const ir::For* second,
bool CanMergeBlocks(const For first,
const For second,
const ForEqualFunc& IsEqual) {
auto Get = [&](ir::Expr* expr) -> ForTreeNode {
auto Get = [&](const For for_stmt) -> ForTreeNode {
ForInfoAnalyzer for_info_analyzer;
for_info_analyzer(expr);
for_info_analyzer(for_stmt);
return for_info_analyzer.GetRootTreeNode();
};
ir::Expr first_expr = Expr(const_cast<ir::For*>(first));
ir::Expr second_expr = Expr(const_cast<ir::For*>(second));
const auto first_inner_for_list = Get(&first_expr);
const auto second_inner_for_list = Get(&second_expr);
const auto first_inner_for_list = Get(first);
const auto second_inner_for_list = Get(second);
return IsEqual(first_inner_for_list, second_inner_for_list);
}

Expand Down
89 changes: 43 additions & 46 deletions paddle/cinn/optim/merge_block_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,71 +14,68 @@

#pragma once

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/stmt.h"

namespace cinn {
namespace optim {

struct ForTreeNode {
const ir::For* val;
const ir::stmt::For val;
std::vector<ForTreeNode> children;
};

using ForEqualFunc =
std::function<bool(const ForTreeNode&, const ForTreeNode&)>;

/**
/*
* Determines if two blocks of code with nested for-loops have identical loop
extents and can be merged.

* extents and can be merged.
*
* This pass is applicable in scenarios where there are multiple code blocks
with nested for-loops,
* and we need to determine if these blocks can be consolidated to simplify the
code structure.

* with nested for-loops, and we need to determine if these blocks can be
* consolidated to simplify the code structure.
*
* When applied, this pass will not directly modify the IR but serves as a
prerequisite check
* to ensure that loop extents match. If they do, a separate merging process can
be safely conducted
* to combine the blocks into a single block with shared loop structures.

* prerequisite check to ensure that loop extents match. If they do, a separate
* merging process can be safely conducted to combine the blocks into a single
* block with shared loop structures.
*
* Performance impact: This pass itself does not directly impact performance but
enables further
* optimizations by identifying mergeable loop structures, which can reduce code
size and potentially
* improve cache efficiency by consolidating similar data processing tasks.

* Examples:
* 1. Simple identical loops:
* Input IR:
* block(var_B)
* for(i, 0, 10)
* for(j, 0, 10)
* B[i,j] = A[i,j]
* enables further optimizations by identifying mergeable loop structures, which
* can reduce code size and potentially improve cache efficiency by
* consolidating similar data processing tasks.
*
* block(var_C)
* for(i, 0, 10)
* for(j, 0, 10)
* C[i,j] = A[i,j]
* Output IR:
* Can be merged since loop extents are identical.
* Examples:
*
* 2. Different loop extents:
* Input IR:
* block(var_B)
* for(i, 0, 10)
* for(j, 0, 10)
* B[i,j] = A[i,j]
* Simple identical loops:
* Input IR:
* block(var_B)
* for(i, 0, 10)
* for(j, 0, 10)
* B[i,j] = A[i,j]
* block(var_C)
* for(i, 0, 10)
* for(j, 0, 10)
* C[i,j] = A[i,j]
* Output IR:
* Can be merged since loop extents are identical.
*
* block(var_C)
* for(i, 0, 3)
* for(j, 0, 4)
* C[i,j] = A[i,j]
* Output IR:
* Cannot be merged due to differing loop extents.
* Different loop extents:
* Input IR:
* block(var_B)
* for(i, 0, 10)
* for(j, 0, 10)
* B[i,j] = A[i,j]
* block(var_C)
* for(i, 0, 3)
* for(j, 0, 4)
* C[i,j] = A[i,j]
* Output IR:
* Cannot be merged due to differing loop extents.
*/
bool CanMergeBlocks(const ir::For* first,
const ir::For* second,

bool CanMergeBlocks(const ir::stmt::For first,
const ir::stmt::For second,
const ForEqualFunc& IsEqual);

} // namespace optim
Expand Down
83 changes: 40 additions & 43 deletions test/cpp/pir/cinn/adt/merge_block_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace {
bool IsBlockForAllEqual(const ForTreeNode& first, const ForTreeNode& second) {
auto ForVarExtentEqual = [&](const ForTreeNode& first,
const ForTreeNode& second) -> bool {
const ir::Expr lhs = first.val->extent;
const ir::Expr rhs = second.val->extent;
const ir::Expr lhs = first.val->extent();
const ir::Expr rhs = second.val->extent();
if (cinn::common::AutoSimplify(ir::Sub::Make(lhs, rhs)) != ir::Expr(0)) {
return false;
}
Expand All @@ -46,74 +46,71 @@ bool IsBlockForAllEqual(const ForTreeNode& first, const ForTreeNode& second) {
return true;
}

ir::Expr MakeForLoops(const std::vector<int> extents, int index) {
if (index >= extents.size()) {
ir::Expr sb = ir::ScheduleBlock::Make(std::vector<Var>(),
std::vector<Expr>(),
std::vector<Expr>(),
"block",
ir::Expr(0));
return sb;
ir::stmt::For MakeForLoops(const std::vector<int> extents, int index) {
ir::stmt::StmtRef body_stmt;
if (index == extents.size() - 1) {
body_stmt = ir::stmt::Schedule(std::vector<Var>(),
std::vector<Expr>(),
std::vector<Expr>(),
std::vector<Expr>(),
"block",
ir::stmt::BlockRef(0));
} else {
body_stmt = MakeForLoops(extents, index + 1);
}

ir::Expr extent = ir::Expr(extents.at(index));
ir::Expr for_expr = ir::For::Make(ir::Var("i"),
ir::Expr(0),
extent,
ir::ForType::Serial,
ir::DeviceAPI::CUDA,
MakeForLoops(extents, index + 1),
ir::VectorizeInfo(),
ir::BindInfo());

return for_expr;
std::vector<ir::stmt::StmtRef> body = {body_stmt};
return ir::stmt::For(ir::Var("i"),
ir::Expr(0),
ir::Expr(extents[index]),
ir::ForType::Serial,
ir::DeviceAPI::CUDA,
ir::stmt::BlockRef(body),
ir::VectorizeInfo(),
ir::BindInfo());
}

void TestHelper(const std::vector<int>& extents1,
const std::vector<int>& extents2,
bool is_same) {
auto for_loop1 = MakeForLoops(extents1, 0);
auto for_loop2 = MakeForLoops(extents2, 0);
auto f1 = for_loop1.As<ir::For>();
auto f2 = for_loop2.As<ir::For>();

if (is_same) {
EXPECT_TRUE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
EXPECT_TRUE(CanMergeBlocks(for_loop1, for_loop2, IsBlockForAllEqual));
} else {
EXPECT_FALSE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
EXPECT_FALSE(CanMergeBlocks(for_loop1, for_loop2, IsBlockForAllEqual));
}
}

void TestHelper2(const std::vector<std::vector<int>>& extents1,
const std::vector<std::vector<int>>& extents2,
bool is_same) {
auto MakeNestLoops =
[&](const std::vector<std::vector<int>>& extents) -> ir::Expr {
std::vector<ir::Expr> for_loops;
[&](const std::vector<std::vector<int>>& extents) -> ir::stmt::For {
std::vector<ir::stmt::StmtRef> for_loops;
for (size_t i = 0; i < extents.size(); ++i) {
for_loops.push_back(MakeForLoops(extents[i], 0));
}
ir::Expr block = ir::Block::Make(for_loops);
ir::Expr for_expr = ir::For::Make(ir::Var("i"),
ir::Expr(0),
ir::Expr(1),
ir::ForType::Serial,
ir::DeviceAPI::CUDA,
block,
ir::VectorizeInfo(),
ir::BindInfo());
return for_expr;
ir::stmt::BlockRef block(for_loops);
ir::stmt::For for_stmt = ir::stmt::For(ir::Var("i"),
ir::Expr(0),
ir::Expr(1),
ir::ForType::Serial,
ir::DeviceAPI::CUDA,
block,
ir::VectorizeInfo(),
ir::BindInfo());
return for_stmt;
};

auto for_expr1 = MakeNestLoops(extents1);
auto for_expr2 = MakeNestLoops(extents2);
auto f1 = for_expr1.As<ir::For>();
auto f2 = for_expr2.As<ir::For>();
auto for_stmt1 = MakeNestLoops(extents1);
auto for_stmt2 = MakeNestLoops(extents2);

if (is_same) {
EXPECT_TRUE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
EXPECT_TRUE(CanMergeBlocks(for_stmt1, for_stmt2, IsBlockForAllEqual));
} else {
EXPECT_FALSE(CanMergeBlocks(f1, f2, IsBlockForAllEqual));
EXPECT_FALSE(CanMergeBlocks(for_stmt1, for_stmt2, IsBlockForAllEqual));
}
}

Expand Down