forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
aot_save.cpp
78 lines (67 loc) · 2.34 KB
/
aot_save.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <iostream>
#include <memory>
#include <utility>

#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/program/program.h"
/// Builds a small Taichi program for `arch` and serializes it as an AOT module.
///
/// The program declares one field, "place": a 1-D dense SNode of `n` (= 10)
/// i32 cells. Two kernels operate on it:
///   - "init": writes each cell's own index into that cell;
///   - "ret":  sums all cells and returns the total as an i32.
/// The field and both kernels are added to the AOT module builder returned by
/// `Program::make_aot_module_builder`. The builder is then dumped with output
/// directory "." and name `arch_name(arch) + "_aot.tcb"`. Finally, "done" is
/// printed to stdout.
///
/// @param arch  Backend the program and the AOT module are built for.
void aot_save(taichi::Arch arch) {
  using namespace taichi;
  using namespace lang;

  auto program = Program(arch);
  // NOTE(review): presumably disabled so the emitted IR stays close to what is
  // built below — confirm the intent.
  program.this_thread_config().advanced_optimization = false;

  const int n = 10;

  // program.materialize_runtime();

  // Field layout: root -> dense(axis 0, n) -> place (i32).
  // The root is owned by a unique_ptr from the moment it is created, so the
  // tree cannot leak if anything below throws before ownership is handed to
  // `program`. `pointer` and `place` are non-owning handles into the tree.
  // They are still used after the tree is moved into `program`.
  auto root = std::make_unique<SNode>(0, SNodeType::root);
  auto *pointer = &root->dense(Axis(0), n, false);
  auto *place = &pointer->insert_children(SNodeType::place);
  place->dt = PrimitiveType::i32;
  program.add_snode_tree(std::move(root), /*compile_only=*/true);

  auto aot_builder = program.make_aot_module_builder(arch, {});

  std::unique_ptr<Kernel> kernel_init;
  std::unique_ptr<Kernel> kernel_ret;

  {
    /*
    @ti.kernel
    def init():
      for index in range(n):
        place[index] = index
    */
    IRBuilder builder;
    auto *zero = builder.get_int32(0);
    auto *n_stmt = builder.get_int32(n);
    // NOTE(review): the trailing (0, 4) are loop-configuration arguments of
    // IRBuilder::create_range_for — confirm their meaning there.
    auto *loop = builder.create_range_for(zero, n_stmt, 0, 4);
    {
      // Presumably directs the statements below into `loop`'s body until `_`
      // goes out of scope.
      auto _ = builder.get_loop_guard(loop);
      auto *index = builder.get_loop_index(loop);
      auto *ptr = builder.create_global_ptr(place, {index});
      builder.create_global_store(ptr, index);
    }
    kernel_init =
        std::make_unique<Kernel>(program, builder.extract_ir(), "init");
  }

  {
    /*
    @ti.kernel
    def ret():
      sum = 0
      for index in place:
        sum += place[index]
      return sum
    */
    IRBuilder builder;
    auto *sum = builder.create_local_var(PrimitiveType::i32);
    auto *loop = builder.create_struct_for(pointer, 0, 4);
    {
      auto _ = builder.get_loop_guard(loop);
      auto *index = builder.get_loop_index(loop);
      auto *place_index =
          builder.create_global_load(builder.create_global_ptr(place, {index}));
      // `sum` is shared by every iteration of this outermost struct-for,
      // which Taichi parallelizes. The accumulation must therefore be one
      // atomic read-modify-write. A separate load / add / store (the IR form
      // of `sum = sum + x`) lets concurrent iterations overwrite each other's
      // partial sums.
      builder.create_atomic_add(sum, place_index);
    }
    builder.create_return(builder.create_local_load(sum));
    kernel_ret = std::make_unique<Kernel>(program, builder.extract_ir(), "ret");
    kernel_ret->insert_ret(PrimitiveType::i32);
  }

  // NOTE(review): the arguments after the data type look like
  // (is_scalar=true, shape={n}, rows=1, columns=1) — confirm against
  // AotModuleBuilder::add_field.
  aot_builder->add_field("place", place, true, place->dt, {n}, 1, 1);
  aot_builder->add("init", kernel_init.get());
  aot_builder->add("ret", kernel_ret.get());
  aot_builder->dump(".", taichi::arch_name(arch) + "_aot.tcb");
  std::cout << "done" << std::endl;
}