Skip to content

Commit

Permalink
[refactor] [ir] IR system refactorings (taichi-dev#1058)
Browse files Browse the repository at this point in the history
* [refactor] [ir] IR system refactorings part 1

* format

* fix segfault

* refactor typecheck signature and misc

* fix

* fix build

* use utility function as

* fix init

* fix segfault to pass CI

* format again

* retrigger CI

* modify test with fake kernel

* format

* remove CompileConfig from some transform passes

* Update taichi/transforms/alg_simp.cpp

Co-authored-by: xumingkuan <[email protected]>

Co-authored-by: xumingkuan <[email protected]>
  • Loading branch information
TH3CHARLie and xumingkuan authored May 26, 2020
1 parent 387c823 commit 2bf0bee
Show file tree
Hide file tree
Showing 19 changed files with 112 additions and 58 deletions.
4 changes: 2 additions & 2 deletions taichi/analysis/clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ class IRCloner : public IRVisitor {

static std::unique_ptr<IRNode> run(IRNode *root, Kernel *kernel) {
if (kernel == nullptr) {
kernel = &get_current_program().get_current_kernel();
kernel = root->get_kernel();
}
std::unique_ptr<IRNode> new_root = root->clone();
IRCloner cloner(new_root.get());
cloner.phase = IRCloner::register_operand_map;
root->accept(&cloner);
cloner.phase = IRCloner::replace_operand;
root->accept(&cloner);
irpass::typecheck(new_root.get(), kernel);
irpass::typecheck(new_root.get());
irpass::fix_block_parents(new_root.get());
return new_root;
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ CodeGen::CodeGen(Kernel *kernel,
FunctionType CodeGen::compile() {
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_offloads(kernel_->ir, config,
irpass::compile_to_offloads(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/false, config.print_ir);

Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ FunctionType OpenglCodeGen::gen(void) {
}

void OpenglCodeGen::lower() {
auto ir = kernel_->ir;
auto ir = kernel_->ir.get();
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_offloads(ir, config,
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ TLANG_NAMESPACE_BEGIN
KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)
: prog(&kernel->program), kernel(kernel), ir(ir) {
if (ir == nullptr)
this->ir = kernel->ir;
this->ir = kernel->ir.get();

auto num_stmts = irpass::analysis::count_statements(this->ir);
if (kernel->is_evaluator)
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir)
ir(ir),
prog(&kernel->program) {
if (ir == nullptr)
this->ir = kernel->ir;
this->ir = kernel->ir.get();
initialize_context();

context_ty = get_runtime_type("Context");
Expand Down
19 changes: 19 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,14 @@ IRNode *Stmt::get_ir_root() {
return dynamic_cast<IRNode *>(block);
}

Kernel *Stmt::get_kernel() const {
if (parent) {
return parent->get_kernel();
} else {
return nullptr;
}
}

std::vector<Stmt *> Stmt::get_operands() const {
std::vector<Stmt *> ret;
for (int i = 0; i < num_operands(); i++) {
Expand Down Expand Up @@ -706,6 +714,17 @@ Stmt *Block::mask() {
}
}

Kernel *Block::get_kernel() const {
Block *parent = this->parent;
if (parent == nullptr) {
return kernel;
}
while (parent->parent) {
parent = parent->parent;
}
return parent->kernel;
}

void Block::set_statements(VecStatement &&stmts) {
statements.clear();
for (int i = 0; i < (int)stmts.size(); i++) {
Expand Down
8 changes: 8 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ class IRNode {
virtual void accept(IRVisitor *visitor) {
TI_NOT_IMPLEMENTED
}
virtual Kernel *get_kernel() const {
return nullptr;
}
virtual ~IRNode() = default;

template <typename T>
Expand Down Expand Up @@ -553,6 +556,8 @@ class Stmt : public IRNode {

IRNode *get_ir_root();

Kernel *get_kernel() const override;

virtual void repeat(int factor) {
ret_type.width *= factor;
}
Expand Down Expand Up @@ -809,6 +814,7 @@ class Block : public IRNode {
std::vector<std::unique_ptr<Stmt>> statements, trash_bin;
Stmt *mask_var;
std::vector<SNode *> stop_gradients;
Kernel *kernel;

// Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
// variables, and AllocaStmt for other variables.
Expand All @@ -817,6 +823,7 @@ class Block : public IRNode {
Block() {
mask_var = nullptr;
parent = nullptr;
kernel = nullptr;
}

bool has_container_statements();
Expand All @@ -838,6 +845,7 @@ class Block : public IRNode {
bool replace_usages = true);
Stmt *lookup_var(const Identifier &ident) const;
Stmt *mask();
Kernel *get_kernel() const override;

Stmt *back() const {
return statements.back().get();
Expand Down
9 changes: 4 additions & 5 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ void re_id(IRNode *root);
void flag_access(IRNode *root);
void die(IRNode *root);
void simplify(IRNode *root, Kernel *kernel = nullptr);
bool alg_simp(IRNode *root, const CompileConfig &config);

bool alg_simp(IRNode *root);
void whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
void extract_constant(IRNode *root);
void full_simplify(IRNode *root,
const CompileConfig &config,
Kernel *kernel = nullptr);
void full_simplify(IRNode *root, Kernel *kernel = nullptr);
void print(IRNode *root, std::string *output = nullptr);
void lower(IRNode *root);
void typecheck(IRNode *root, Kernel *kernel = nullptr);
void typecheck(IRNode *root);
void loop_vectorize(IRNode *root);
void slp_vectorize(IRNode *root);
void vector_split(IRNode *root, int max_width, bool serial_schedule);
Expand Down
6 changes: 3 additions & 3 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void ExecutionQueue::enqueue(KernelLaunchRecord &&ker) {
flag_access(stmt);
lower_access(stmt, true, kernel);
flag_access(stmt);
full_simplify(stmt, kernel->program.config, kernel);
full_simplify(stmt, kernel);
// analysis::verify(stmt);
}
auto func = CodeGenCPU(kernel, stmt).codegen();
Expand Down Expand Up @@ -108,7 +108,7 @@ ExecutionQueue::ExecutionQueue()
void AsyncEngine::launch(Kernel *kernel) {
if (!kernel->lowered)
kernel->lower(false);
auto block = dynamic_cast<Block *>(kernel->ir);
auto block = dynamic_cast<Block *>(kernel->ir.get());
TI_ASSERT(block);
auto &offloads = block->statements;
for (std::size_t i = 0; i < offloads.size(); i++) {
Expand Down Expand Up @@ -266,7 +266,7 @@ bool AsyncEngine::fuse() {
irpass::fix_block_parents(task_a);

auto kernel = task_queue[i].kernel;
irpass::full_simplify(task_a, kernel->program.config, kernel);
irpass::full_simplify(task_a, kernel);
task_queue[i].h = hash(task_a);

modified = true;
Expand Down
6 changes: 3 additions & 3 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ Kernel::Kernel(Program &program,
is_evaluator = false;
compiled = nullptr;
taichi::lang::context = std::make_unique<FrontendContext>();
ir_holder = taichi::lang::context->get_root();
ir = ir_holder.get();
ir = taichi::lang::context->get_root();

{
CurrentKernelGuard _(program, this);
program.start_function_definition(this);
func();
program.end_function_definition();
ir->as<Block>()->kernel = this;
}

arch = program.config.arch;
Expand Down Expand Up @@ -74,7 +74,7 @@ void Kernel::lower(bool lower_access) { // TODO: is a "Lowerer" class necessary
if (is_accessor && !config.print_accessor_ir)
verbose = false;
irpass::compile_to_offloads(
ir, config, /*vectorize*/ arch_is_cpu(arch), grad,
ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
/*ad_use_stack*/ true, verbose, /*lower_global_access*/ lower_access);
} else {
TI_NOT_IMPLEMENTED
Expand Down
3 changes: 1 addition & 2 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class Program;

class Kernel {
public:
std::unique_ptr<IRNode> ir_holder;
IRNode *ir;
std::unique_ptr<IRNode> ir;
Program &program;
FunctionType compiled;
std::string name;
Expand Down
3 changes: 2 additions & 1 deletion taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ class AlgSimp : public BasicStmtVisitor {

namespace irpass {

bool alg_simp(IRNode *root, const CompileConfig &config) {
bool alg_simp(IRNode *root) {
const auto &config = root->get_kernel()->program.config;
return AlgSimp::run(root, config.fast_math);
}

Expand Down
8 changes: 4 additions & 4 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ void compile_to_offloads(IRNode *ir,

if (grad) {
irpass::demote_atomics(ir);
irpass::full_simplify(ir, config);
irpass::full_simplify(ir);
irpass::make_adjoint(ir, ad_use_stack);
irpass::full_simplify(ir, config);
irpass::full_simplify(ir);
print("Adjoint");
irpass::analysis::verify(ir);
}
Expand Down Expand Up @@ -91,7 +91,7 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::full_simplify(ir, config);
irpass::full_simplify(ir);
print("Simplified II");
irpass::analysis::verify(ir);

Expand Down Expand Up @@ -122,7 +122,7 @@ void compile_to_offloads(IRNode *ir,
irpass::variable_optimization(ir, true);
print("Store forwarded II");

irpass::full_simplify(ir, config);
irpass::full_simplify(ir);
print("Simplified III");

// Final field registration correctness & type checking
Expand Down
13 changes: 8 additions & 5 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,14 @@ class ConstantFold : public BasicStmtVisitor {
rhs.dt,
true};
auto *ker = get_jit_evaluator_kernel(id);
auto &ctx = get_current_program().get_context();
auto &current_program = stmt->get_kernel()->program;
auto &ctx = current_program.get_context();
ContextArgSaveGuard _(
ctx); // save input args, prevent override current kernel
ctx.set_arg<int64_t>(0, lhs.val_i64);
ctx.set_arg<int64_t>(1, rhs.val_i64);
(*ker)();
ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}

Expand All @@ -135,12 +136,13 @@ class ConstantFold : public BasicStmtVisitor {
stmt->cast_type,
false};
auto *ker = get_jit_evaluator_kernel(id);
auto &ctx = get_current_program().get_context();
auto &current_program = stmt->get_kernel()->program;
auto &ctx = current_program.get_context();
ContextArgSaveGuard _(
ctx); // save input args, prevent override current kernel
ctx.set_arg<int64_t>(0, operand.val_i64);
(*ker)();
ret.val_i64 = get_current_program().fetch_result<int64_t>(0);
ret.val_i64 = current_program.fetch_result<int64_t>(0);
return true;
}

Expand Down Expand Up @@ -204,7 +206,8 @@ void constant_fold(IRNode *root) {
// disable constant_fold when config.debug is turned on.
// Discussion:
// https://github.com/taichi-dev/taichi/pull/839#issuecomment-626107010
if (get_current_program().config.debug) {
auto kernel = root->get_kernel();
if (kernel && kernel->program.config.debug) {
TI_TRACE("config.debug enabled, ignoring constant fold");
return;
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ namespace irpass {

void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel) {
LowerAccess::run(root, lower_atomic);
typecheck(root, kernel);
typecheck(root);
}

} // namespace irpass
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ class BasicBlockSimplify : public IRVisitor {
stmt->insert_before_me(std::move(sum));
stmt->parent->erase(stmt);
// get types of adds and muls
irpass::typecheck(stmt->parent, kernel);
irpass::typecheck(stmt->parent);
throw IRModified();
}

Expand Down Expand Up @@ -1160,10 +1160,10 @@ void simplify(IRNode *root, Kernel *kernel) {
}
}

void full_simplify(IRNode *root, const CompileConfig &config, Kernel *kernel) {
void full_simplify(IRNode *root, Kernel *kernel) {
constant_fold(root);
if (advanced_optimization) {
alg_simp(root, config);
alg_simp(root);
die(root);
whole_kernel_cse(root);
}
Expand Down
17 changes: 9 additions & 8 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ class TypeCheck : public IRVisitor {
CompileConfig config;

public:
TypeCheck(Kernel *kernel) : kernel(kernel) {
// TODO: remove dependency on get_current_program here
if (current_program != nullptr)
config = get_current_program().config;
TypeCheck(IRNode *root) {
kernel = root->get_kernel();
if (kernel != nullptr) {
config = kernel->program.config;
}
allow_undefined_visitor = true;
}

Expand Down Expand Up @@ -316,7 +317,7 @@ class TypeCheck : public IRVisitor {
void visit(ArgLoadStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
current_kernel = &get_current_program().get_current_kernel();
current_kernel = stmt->get_kernel();
}
auto &args = current_kernel->args;
TI_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
Expand All @@ -326,7 +327,7 @@ class TypeCheck : public IRVisitor {
void visit(KernelReturnStmt *stmt) {
Kernel *current_kernel = kernel;
if (current_kernel == nullptr) {
current_kernel = &get_current_program().get_current_kernel();
current_kernel = stmt->get_kernel();
}
auto &rets = current_kernel->rets;
TI_ASSERT(rets.size() >= 1);
Expand Down Expand Up @@ -416,9 +417,9 @@ class TypeCheck : public IRVisitor {

namespace irpass {

void typecheck(IRNode *root, Kernel *kernel) {
void typecheck(IRNode *root) {
analysis::check_fields_registered(root);
TypeCheck inst(kernel);
TypeCheck inst(root);
root->accept(&inst);
}

Expand Down
Loading

0 comments on commit 2bf0bee

Please sign in to comment.