Skip to content

Commit

Permalink
aggressive local store forwarding attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Oct 22, 2019
1 parent 8a1d8f1 commit 4504293
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 20 deletions.
6 changes: 4 additions & 2 deletions lang/src/backends/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -991,8 +991,10 @@ void GPUCodeGen::lower() {
if (prog->config.simplify_before_lower_access) {
irpass::simplify(ir);
irpass::re_id(ir);
if (prog->config.print_ir)
if (prog->config.print_ir) {
TC_TRACE("Simplified I:");
irpass::print(ir);
}
}
if (kernel->grad) {
// irpass::re_id(ir);
Expand All @@ -1019,7 +1021,7 @@ void GPUCodeGen::lower() {
}
irpass::simplify(ir);
if (prog->config.print_ir) {
TC_TRACE("DupEliminated2:");
TC_TRACE("Simplified II:");
irpass::re_id(ir);
irpass::print(ir);
}
Expand Down
4 changes: 2 additions & 2 deletions lang/src/backends/codegen_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ void CPUCodeGen::lower() {
if (prog->config.simplify_before_lower_access) {
irpass::simplify(ir);
if (prog->config.print_ir) {
TC_TRACE("DupEliminated:");
TC_TRACE("Simplified I:");
irpass::re_id(ir);
irpass::print(ir);
}
Expand Down Expand Up @@ -677,7 +677,7 @@ void CPUCodeGen::lower() {
}
irpass::simplify(ir);
if (prog->config.print_ir) {
TC_TRACE("DupEliminated2:");
TC_TRACE("Simplified II:");
irpass::re_id(ir);
irpass::print(ir);
}
Expand Down
36 changes: 30 additions & 6 deletions lang/src/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,15 @@ class BasicBlockSimplify : public IRVisitor {
}
}
if (regular) {
// Check all previous statements in the current block before the local load
// Check all previous statements in the current block before the local
// load
auto block = stmt->parent;
Stmt *containing_statement = stmt;
bool modified = false;
while (true) {
auto stmt_id = stmt->parent->locate(stmt);
TC_ASSERT(current_stmt_id == stmt_id);
for (int i = current_stmt_id - 1; i >= 0; i--) {
auto stmt_id = block->locate(containing_statement);
TC_ASSERT(stmt_id != -1);
for (int i = stmt_id - 1; i >= 0; i--) {
auto &bstmt = block->statements[i];
// Find a previous store
if (bstmt->is<LocalStoreStmt>()) {
Expand All @@ -261,12 +264,33 @@ class BasicBlockSimplify : public IRVisitor {
}
} else if (bstmt->is_container_statement()) {
// assume this container may modify the local var
modified = true;
break;
}
}
break;
// block = block->parent; // TODO: how to find the stmt of this block?
//if ()
/*
// Note: simply checking all statements before stmt is not sufficient
// since statements after stmt may change the value of the alloca
if (modified) break;
// Go to parent level
auto parent_block = block->parent;
if (!parent_block)
break;
Stmt *parent_statement = nullptr;
for (int i = 0; i < parent_block->statements.size(); i++) {
auto s = parent_block->statements[i].get();
if (s->is<RangeForStmt>() &&
s->as<RangeForStmt>()->body.get() == block) {
parent_statement = s;
break;
}
}
if (!parent_statement)
break;
block = parent_block;
containing_statement = parent_statement;
*/
}
}
set_done(stmt);
Expand Down
13 changes: 4 additions & 9 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,12 @@ def visit_block(self, list_stmt):
list_stmt[i] = self.visit(l)

def visit_If(self, node):
old = False
if old:
with self.variable_scope():
node.test = self.visit(node.test)
with self.variable_scope():
self.generic_visit(node)
else:
self.visit_block(node.body)
with self.variable_scope():
node.test = self.visit(node.test)
with self.variable_scope():
self.visit_block(node.body)
with self.variable_scope():
self.visit_block(node.orelse)
self.visit_block(node.orelse)

is_static_if = isinstance(node.test,
ast.Call) and isinstance(node.test.func,
Expand Down
1 change: 0 additions & 1 deletion tests/python/test_loop_grad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import taichi as ti

def test_loop_grad():
return
for arch in [ti.x86_64, ti.cuda]:
ti.reset()
ti.cfg.arch = arch
Expand Down

0 comments on commit 4504293

Please sign in to comment.