Skip to content

Commit

Permalink
[Opt] Add binary ops optimization pass (taichi-dev#1226)
Browse files Browse the repository at this point in the history
* add optimization for associative binary ops

* fix bugs

* ti format

* add print_evaluator_ir to limit JIT IR output, fix bugs in binary_op_simplify

* fix rearrange associative op check

* export print_evaluator_ir

* update comment

* update according to code review

* update comment for bit ops

* retrigger CI

* fix swapping consts

* [skip ci] remove debug tags
  • Loading branch information
TH3CHARLie authored Jun 13, 2020
1 parent afc852b commit 1e3749f
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 1 deletion.
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ void die(IRNode *root);
bool simplify(IRNode *root, Kernel *kernel = nullptr);
void cfg_optimization(IRNode *root);
bool alg_simp(IRNode *root);
bool binary_op_simplify(IRNode *root);
bool whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
void extract_constant(IRNode *root);
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ CompileConfig::CompileConfig() {
external_optimization_level = 3;
print_ir = false;
print_accessor_ir = false;
print_evaluator_ir = false;
print_benchmark_stat = false;
use_llvm = true;
print_struct_llvm_ir = false;
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct CompileConfig {
int max_vector_width;
bool print_ir;
bool print_accessor_ir;
bool print_evaluator_ir;
bool print_benchmark_stat;
bool serial_schedule;
bool simplify_before_lower_access;
Expand Down
3 changes: 2 additions & 1 deletion taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ void Kernel::lower(bool lower_access) { // TODO: is a "Lowerer" class necessary
auto codegen = KernelCodeGen::create(arch, this);
auto config = program.config;
bool verbose = config.print_ir;
if (is_accessor && !config.print_accessor_ir)
if ((is_accessor && !config.print_accessor_ir) ||
(is_evaluator && !config.print_evaluator_ir))
verbose = false;
irpass::compile_to_offloads(
ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ void export_lang(py::module &m) {
.def_readwrite("debug", &CompileConfig::debug)
.def_readwrite("check_out_of_bound", &CompileConfig::check_out_of_bound)
.def_readwrite("print_accessor_ir", &CompileConfig::print_accessor_ir)
.def_readwrite("print_evaluator_ir", &CompileConfig::print_evaluator_ir)
.def_readwrite("use_llvm", &CompileConfig::use_llvm)
.def_readwrite("print_benchmark_stat",
&CompileConfig::print_benchmark_stat)
Expand Down
105 changes: 105 additions & 0 deletions taichi/transforms/binary_op_simplify.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

class BinaryOpSimp : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;
bool fast_math;
DelayedIRModifier modifier;

explicit BinaryOpSimp(bool fast_math_)
: BasicStmtVisitor(), fast_math(fast_math_) {
}

void visit(BinaryOpStmt *stmt) override {
// swap lhs and rhs if lhs is a const and op is commutative
auto const_lhs = stmt->lhs->cast<ConstStmt>();
if (const_lhs && is_commutative(stmt->op_type) &&
!stmt->rhs->is<ConstStmt>()) {
auto rhs_stmt = stmt->rhs;
stmt->lhs = rhs_stmt;
stmt->rhs = const_lhs;
}
if (!fast_math) {
return;
}
auto binary_lhs = stmt->lhs->cast<BinaryOpStmt>();
auto const_rhs = stmt->rhs->cast<ConstStmt>();
if (!binary_lhs || !const_rhs) {
return;
}
auto const_lhs_rhs = binary_lhs->rhs->cast<ConstStmt>();
if (!const_lhs_rhs || binary_lhs->lhs->is<ConstStmt>()) {
return;
}
// original:
// stmt = (a op1 b) op2 c
// rearrange to:
// stmt = a op1 (b op2 c)
if (can_rearrange_associative(binary_lhs->op_type, stmt->op_type)) {
auto bin_op =
Stmt::make<BinaryOpStmt>(stmt->op_type, const_lhs_rhs, const_rhs);
bin_op->ret_type.data_type = stmt->ret_type.data_type;
auto new_stmt = Stmt::make<BinaryOpStmt>(binary_lhs->op_type,
binary_lhs->lhs, bin_op.get());
new_stmt->ret_type.data_type = stmt->ret_type.data_type;

modifier.insert_before(stmt, std::move(bin_op));
stmt->replace_with(new_stmt.get());
modifier.insert_before(stmt, std::move(new_stmt));
modifier.erase(stmt);
}
}

static bool can_rearrange_associative(BinaryOpType op1, BinaryOpType op2) {
if (op1 == BinaryOpType::add &&
(op2 == BinaryOpType::add || op2 == BinaryOpType::sub)) {
return true;
}
if (op1 == BinaryOpType::mul &&
(op2 == BinaryOpType::mul || op2 == BinaryOpType::div)) {
return true;
}
// for bit operations it only holds when two ops are the same
if ((op1 == BinaryOpType::bit_and || op1 == BinaryOpType::bit_or ||
op1 == BinaryOpType::bit_xor) &&
op1 == op2) {
return true;
}
return false;
}

static bool is_commutative(BinaryOpType op) {
return op == BinaryOpType::add || op == BinaryOpType::mul ||
op == BinaryOpType::bit_and || op == BinaryOpType::bit_or ||
op == BinaryOpType::bit_xor;
}

static bool run(IRNode *node, bool fast_math) {
BinaryOpSimp simplifier(fast_math);
bool modified = false;
while (true) {
node->accept(&simplifier);
if (simplifier.modifier.modify_ir()) {
modified = true;
} else
break;
}
return modified;
}
};

namespace irpass {

bool binary_op_simplify(IRNode *root) {
const auto &config = root->get_kernel()->program.config;
return BinaryOpSimp::run(root, config.fast_math);
}

} // namespace irpass

TLANG_NAMESPACE_END
2 changes: 2 additions & 0 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,8 @@ void full_simplify(IRNode *root, Kernel *kernel) {
while (true) {
bool modified = false;
extract_constant(root);
if (binary_op_simplify(root))
modified = true;
if (constant_fold(root))
modified = true;
if (alg_simp(root))
Expand Down

0 comments on commit 1e3749f

Please sign in to comment.