Skip to content

Commit

Permalink
[Opt] [ir] [refactor] Remove exceptions from demote_atomics pass (tai…
Browse files Browse the repository at this point in the history
…chi-dev#1272)

* remove exceptions from demote_atomics

* [skip ci] enforce code format

* trigger CI

* fix

* add assert in dtor

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
TH3CHARLie and taichi-gardener authored Jun 19, 2020
1 parent 265e414 commit 4aee4e1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 13 deletions.
15 changes: 14 additions & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,9 @@ std::unique_ptr<Block> Block::clone() const {

DelayedIRModifier::~DelayedIRModifier() {
TI_ASSERT(to_insert_before.empty());
TI_ASSERT(to_insert_after.empty());
TI_ASSERT(to_erase.empty());
TI_ASSERT(to_replace_with.empty());
}

void DelayedIRModifier::erase(Stmt *stmt) {
Expand Down Expand Up @@ -847,8 +849,15 @@ void DelayedIRModifier::insert_after(Stmt *old_statement,
to_insert_after.emplace_back(old_statement, std::move(new_statements));
}

void DelayedIRModifier::replace_with(Stmt *stmt,
VecStatement &&new_statements,
bool replace_usages) {
to_replace_with.emplace_back(stmt, std::move(new_statements), replace_usages);
}

bool DelayedIRModifier::modify_ir() {
if (to_insert_before.empty() && to_insert_after.empty() && to_erase.empty())
if (to_insert_before.empty() && to_insert_after.empty() && to_erase.empty() &&
to_replace_with.empty())
return false;
for (auto &i : to_insert_before) {
i.first->parent->insert_before(i.first, std::move(i.second));
Expand All @@ -862,6 +871,10 @@ bool DelayedIRModifier::modify_ir() {
stmt->parent->erase(stmt);
}
to_erase.clear();
for (auto &i : to_replace_with) {
std::get<0>(i)->replace_with(std::move(std::get<1>(i)), std::get<2>(i));
}
to_replace_with.clear();
return true;
}

Expand Down
5 changes: 5 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <unordered_set>
#include <unordered_map>
#include <variant>
#include <tuple>
#include "taichi/common/core.h"
#include "taichi/util/bit.h"
#include "taichi/lang_util.h"
Expand Down Expand Up @@ -894,6 +895,7 @@ class DelayedIRModifier {
private:
std::vector<std::pair<Stmt *, VecStatement>> to_insert_before;
std::vector<std::pair<Stmt *, VecStatement>> to_insert_after;
std::vector<std::tuple<Stmt *, VecStatement, bool>> to_replace_with;
std::vector<Stmt *> to_erase;

public:
Expand All @@ -903,6 +905,9 @@ class DelayedIRModifier {
void insert_before(Stmt *old_statement, VecStatement &&new_statements);
void insert_after(Stmt *old_statement, std::unique_ptr<Stmt> new_statement);
void insert_after(Stmt *old_statement, VecStatement &&new_statements);
void replace_with(Stmt *stmt,
VecStatement &&new_statements,
bool replace_usages = true);
bool modify_ir();
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void replace_statements_with(IRNode *root,
std::function<bool(Stmt *)> filter,
std::function<std::unique_ptr<Stmt>()> generator);
void demote_dense_struct_fors(IRNode *root);
void demote_atomics(IRNode *root);
bool demote_atomics(IRNode *root);
void reverse_segments(IRNode *root); // for autograd
std::unique_ptr<ScratchPads> initialize_scratch_pad(StructForStmt *root);
void compile_to_offloads(IRNode *ir,
Expand Down
25 changes: 14 additions & 11 deletions taichi/transforms/demote_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class DemoteAtomics : public BasicStmtVisitor {
using BasicStmtVisitor::visit;

OffloadedStmt *current_offloaded;
DelayedIRModifier modifier;

DemoteAtomics() : BasicStmtVisitor() {
current_offloaded = nullptr;
Expand Down Expand Up @@ -68,9 +69,8 @@ class DemoteAtomics : public BasicStmtVisitor {
// old value $d'.
// See also: https://github.com/taichi-dev/taichi/issues/332
stmt->replace_with(load);
stmt->parent->replace_with(stmt, std::move(new_stmts),
/*replace_usages=*/false);
throw IRModified();
modifier.replace_with(stmt, std::move(new_stmts),
/*replace_usages=*/false);
}
}
}
Expand All @@ -83,25 +83,28 @@ class DemoteAtomics : public BasicStmtVisitor {
current_offloaded = nullptr;
}

static void run(IRNode *node) {
static bool run(IRNode *node) {
DemoteAtomics demoter;
bool modified = false;
while (true) {
try {
node->accept(&demoter);
} catch (IRModified) {
continue;
node->accept(&demoter);
if (demoter.modifier.modify_ir()) {
modified = true;
} else {
break;
}
break;
}
return modified;
}
};

namespace irpass {

void demote_atomics(IRNode *root) {
bool demote_atomics(IRNode *root) {
TI_AUTO_PROF;
DemoteAtomics::run(root);
bool modified = DemoteAtomics::run(root);
typecheck(root);
return modified;
}

} // namespace irpass
Expand Down

0 comments on commit 4aee4e1

Please sign in to comment.