Skip to content

Commit

Permalink
constant folding basics
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Oct 22, 2019
1 parent a965d07 commit a529b83
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 8 deletions.
1 change: 1 addition & 0 deletions lang/src/backends/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ void GPUCodeGen::lower() {
if (prog->config.print_ir) {
irpass::print(ir);
}
irpass::constant_fold(ir);
if (prog->config.simplify_before_lower_access) {
irpass::simplify(ir);
irpass::re_id(ir);
Expand Down
1 change: 1 addition & 0 deletions lang/src/backends/llvm_ptx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class CodeGenLLVMGPU : public CodeGenLLVM {
create_naive_range_for(for_stmt);
} else {
offloaded = true;
TC_P(for_stmt->begin->type());
auto loop_begin = for_stmt->begin->as<ConstStmt>()->val[0].val_int32();
auto loop_end = for_stmt->end->as<ConstStmt>()->val[0].val_int32();
auto loop_block_dim = for_stmt->block_size;
Expand Down
8 changes: 6 additions & 2 deletions lang/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ void vector_split(IRNode *root, int max_width, bool serial_schedule);
void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt);
void lower_access(IRNode *root, bool lower_atomic);
void make_adjoint(IRNode *root);
void constant_fold(IRNode *root);
std::unique_ptr<ScratchPads> initialize_scratch_pad(StructForStmt *root);

} // namespace irpass
Expand Down Expand Up @@ -248,7 +249,10 @@ class VecStatement {
public:
std::vector<pStmt> stmts;

VecStatement() {
VecStatement() {}

VecStatement(pStmt &&stmt) {
push_back(std::move(stmt));
}

VecStatement(VecStatement &&o) {
Expand Down Expand Up @@ -1377,7 +1381,7 @@ class Block : public IRNode {

void replace_with(Stmt *old_statement, std::unique_ptr<Stmt> &&new_statement);

void insert_before(Stmt *old_statement, VecStatement &new_statements) {
void insert_before(Stmt *old_statement, VecStatement &&new_statements) {
int location = -1;
for (int i = 0; i < (int)statements.size(); i++) {
if (old_statement == statements[i].get()) {
Expand Down
103 changes: 103 additions & 0 deletions lang/src/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include "../ir.h"
#include <deque>
#include <set>

TLANG_NAMESPACE_BEGIN

// Visits all non-containing statements
class BasicStmtVisitor : public IRVisitor {
public:
StructForStmt *current_struct_for;

BasicStmtVisitor() {
current_struct_for = nullptr;
allow_undefined_visitor = true;
}

void visit(Block *stmt_list) override {
auto backup_block = current_block;
current_block = stmt_list;
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
}
current_block = backup_block;
}

void visit(IfStmt *if_stmt) override {
if (if_stmt->true_statements)
if_stmt->true_statements->accept(this);
if (if_stmt->false_statements) {
if_stmt->false_statements->accept(this);
}
}

void visit(WhileStmt *stmt) override {
stmt->body->accept(this);
}

void visit(RangeForStmt *for_stmt) override {
for_stmt->body->accept(this);
}

void visit(StructForStmt *for_stmt) override {
current_struct_for = for_stmt;
for_stmt->body->accept(this);
current_struct_for = nullptr;
}
};

class ConstantFold : public BasicStmtVisitor {
public:
ConstantFold() : BasicStmtVisitor() {
}

void visit(UnaryOpStmt *stmt) {
if (stmt->width() == 1 && stmt->op_type == UnaryOpType::cast &&
stmt->cast_by_value && stmt->operand->is<ConstStmt>()) {
auto input = stmt->operand->as<ConstStmt>()->val[0];
auto src_type = stmt->operand->ret_type.data_type;
auto dst_type = stmt->ret_type.data_type;
TypedConstant new_constant(dst_type);
bool success = false;
if (src_type == DataType::f32) {
auto v = input.val_float32();
if (dst_type == DataType::i32) {
new_constant.val_i32 = int32(v);
success = true;
}
}

if (success) {
auto evaluated = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(new_constant));
stmt->replace_with(evaluated.get());
stmt->parent->insert_before(stmt, VecStatement(std::move(evaluated)));
stmt->parent->erase(stmt);
throw IRModified();
}
}
}

static void run(IRNode *node) {
ConstantFold folder;
while (true) {
bool modified = false;
try {
node->accept(&folder);
} catch (IRModified) {
modified = true;
}
if (!modified)
break;
}
}
};

namespace irpass {

void constant_fold(IRNode *root) {
return ConstantFold::run(root);
}

} // namespace irpass

TLANG_NAMESPACE_END
6 changes: 3 additions & 3 deletions lang/src/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class LowerAccess : public IRVisitor {
if (stmt->ptr->is<GlobalPtrStmt>()) {
auto lowered = lower_vector_ptr(stmt->ptr->as<GlobalPtrStmt>(), false);
stmt->ptr = lowered.back().get();
stmt->parent->insert_before(stmt, lowered);
stmt->parent->insert_before(stmt, std::move(lowered));
throw IRModified();
}
}
Expand All @@ -160,7 +160,7 @@ class LowerAccess : public IRVisitor {
if (stmt->ptr->is<GlobalPtrStmt>()) {
auto lowered = lower_vector_ptr(stmt->ptr->as<GlobalPtrStmt>(), true);
stmt->ptr = lowered.back().get();
stmt->parent->insert_before(stmt, lowered);
stmt->parent->insert_before(stmt, std::move(lowered));
throw IRModified();
}
}
Expand All @@ -171,7 +171,7 @@ class LowerAccess : public IRVisitor {
if (stmt->dest->is<GlobalPtrStmt>()) {
auto lowered = lower_vector_ptr(stmt->dest->as<GlobalPtrStmt>(), true);
stmt->dest = lowered.back().get();
stmt->parent->insert_before(stmt, lowered);
stmt->parent->insert_before(stmt, std::move(lowered));
throw IRModified();
}
}
Expand Down
7 changes: 4 additions & 3 deletions tests/python/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def place():

@ti.kernel
def func():
for i in range(N // 2 + 3, N):
for i in range(ti.static(N // 2 + 3), N):
x[i] = ti.abs(y[i])

func()
Expand All @@ -34,6 +34,7 @@ def test_numpy_loops():
for arch in [ti.x86_64, ti.cuda]:
ti.reset()
ti.cfg.arch = arch
ti.cfg.print_ir = True
x = ti.var(ti.f32)
y = ti.var(ti.f32)

Expand All @@ -48,8 +49,8 @@ def place():
y[i] = i - 300

import numpy as np
begin = np.ones(1) * (N // 2 + 3)
end = np.ones(1) * N
begin = (np.ones(1) * (N // 2 + 3)).astype(np.int32)
end = (np.ones(1) * N).astype(np.int32)

@ti.kernel
def func():
Expand Down

0 comments on commit a529b83

Please sign in to comment.