From 4c2599307e62c522723cf9fd08e113221ed35274 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 22 Jun 2020 17:21:36 -0400 Subject: [PATCH] [ir] [refactor] Simplify the "re_id" pass (#1304) --- taichi/transforms/re_id.cpp | 68 +++++-------------------------------- 1 file changed, 8 insertions(+), 60 deletions(-) diff --git a/taichi/transforms/re_id.cpp b/taichi/transforms/re_id.cpp index 0e2fdf63a318c..63b9870be7d42 100644 --- a/taichi/transforms/re_id.cpp +++ b/taichi/transforms/re_id.cpp @@ -1,94 +1,42 @@ #include "taichi/ir/ir.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" -#include "taichi/ir/frontend_ir.h" TLANG_NAMESPACE_BEGIN // This pass manipulates the id of statements so that they are successive values // starting from 0 -class ReId : public IRVisitor { +class ReId : public BasicStmtVisitor { public: int id_counter; - ReId(IRNode *node) { + ReId() : id_counter(0) { allow_undefined_visitor = true; invoke_default_visitor = true; - id_counter = 0; - node->accept(this); } void re_id(Stmt *stmt) { stmt->id = id_counter++; } - void visit(Stmt *stmt) { + void visit(Stmt *stmt) override { re_id(stmt); } - void visit(Block *stmt_list) { // block itself has no id - for (auto &stmt : stmt_list->statements) { - stmt->accept(this); - } - } - - void visit(IfStmt *if_stmt) { - re_id(if_stmt); - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(FrontendIfStmt *if_stmt) { - re_id(if_stmt); - if (if_stmt->true_statements) - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(WhileStmt *stmt) { - re_id(stmt); - stmt->body->accept(this); - } - - void visit(FrontendWhileStmt *stmt) { + void preprocess_container_stmt(Stmt *stmt) override { re_id(stmt); - stmt->body->accept(this); } - void visit(FrontendForStmt *for_stmt) { - re_id(for_stmt); - for_stmt->body->accept(this); - } - - void visit(RangeForStmt *for_stmt) { - re_id(for_stmt); - for_stmt->body->accept(this); - } - - void visit(StructForStmt *for_stmt) { - re_id(for_stmt); - for_stmt->body->accept(this); - } - - void visit(OffloadedStmt *stmt) { - re_id(stmt); - if (stmt->body) - stmt->body->accept(this); + static void run(IRNode *node) { + ReId instance; + node->accept(&instance); } }; namespace irpass { - void re_id(IRNode *root) { - ReId instance(root); + ReId::run(root); } - } // namespace irpass TLANG_NAMESPACE_END