Skip to content

Commit

Permalink
[example] SPIR-V AOT example in C++ (taichi-dev#3707)
Browse files Browse the repository at this point in the history
* Aot example in CPP

* fix

* compile_only=true

* rename to cpp_examples

* materialize runtime
  • Loading branch information
AmesingFlank authored Dec 5, 2021
1 parent cc74e13 commit 71d67bc
Showing 1 changed file with 76 additions and 0 deletions.
76 changes: 76 additions & 0 deletions cpp_examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,84 @@ void autograd() {
std::cout << std::endl;
}

void aot_save() {
using namespace taichi;
using namespace lang;
auto program = Program(Arch::vulkan);

program.config.advanced_optimization = false;

int n = 10;

program.materialize_runtime();
auto *root = new SNode(0, SNodeType::root);
auto *pointer = &root->dense(Axis(0), n, false);
auto *place = &pointer->insert_children(SNodeType::place);
place->dt = PrimitiveType::i32;
program.add_snode_tree(std::unique_ptr<SNode>(root), /*compile_only=*/true);

auto aot_builder = program.make_aot_module_builder(Arch::vulkan);

std::unique_ptr<Kernel> kernel_init, kernel_ret;

{
/*
@ti.kernel
def init():
for index in range(n):
place[index] = index
*/
IRBuilder builder;
auto *zero = builder.get_int32(0);
auto *n_stmt = builder.get_int32(n);
auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4);
{
auto _ = builder.get_loop_guard(loop);
auto *index = builder.get_loop_index(loop);
auto *ptr = builder.create_global_ptr(place, {index});
builder.create_global_store(ptr, index);
}

kernel_init =
std::make_unique<Kernel>(program, builder.extract_ir(), "init");
}

{
/*
@ti.kernel
def ret():
sum = 0
for index in place:
sum = sum + place[index];
return sum
*/
IRBuilder builder;
auto *sum = builder.create_local_var(PrimitiveType::i32);
auto *loop = builder.create_struct_for(pointer, 1, 0, 4);
{
auto _ = builder.get_loop_guard(loop);
auto *index = builder.get_loop_index(loop);
auto *sum_old = builder.create_local_load(sum);
auto *place_index =
builder.create_global_load(builder.create_global_ptr(place, {index}));
builder.create_local_store(sum, builder.create_add(sum_old, place_index));
}
builder.create_return(builder.create_local_load(sum));

kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
kernel_ret->insert_ret(PrimitiveType::i32);
}

aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
aot_builder->add("init", kernel_init.get());
aot_builder->add("ret", kernel_ret.get());
aot_builder->dump(".", "aot.tcb");
std::cout << "done" << std::endl;
}

int main() {
run_snode();
autograd();
aot_save();
return 0;
}

0 comments on commit 71d67bc

Please sign in to comment.