
Commit

fixed ti.func; added ArgStore stmts
yuanming-hu committed Oct 31, 2019
1 parent 702e380 commit e736819
Showing 12 changed files with 132 additions and 22 deletions.
4 changes: 0 additions & 4 deletions examples/recompile.py

This file was deleted.

22 changes: 12 additions & 10 deletions python/taichi/lang/kernel.py
@@ -22,33 +22,35 @@ def remove_indent(lines):

return '\n'.join(cleaned)

# The ti.func decorator
def func(foo):
-    # What is this func used for?
-    assert False
+    from .impl import get_runtime
    src = remove_indent(inspect.getsource(foo))
    tree = ast.parse(src)

    func_body = tree.body[0]
    func_body.decorator_list = []

    visitor = ASTTransformer(transform_args=False)
    visitor.visit(tree)
    ast.fix_missing_locations(tree)

-    if self.runtime.print_preprocessed:
+    if get_runtime().print_preprocessed:
        import astor
        print(astor.to_source(tree.body[0], indent_with=' '))

    ast.increment_lineno(tree, inspect.getsourcelines(foo)[1] - 1)

-    self.runtime.inside_kernel = True
+    get_runtime().inside_kernel = True
    frame = inspect.currentframe().f_back
    exec(compile(tree, filename=inspect.getsourcefile(foo), mode='exec'),
         dict(frame.f_globals, **frame.f_locals), locals())
-    self.runtime.inside_kernel = False
+    get_runtime().inside_kernel = False
    compiled = locals()[foo.__name__]
    return compiled


class KernelTemplateMapper:
    def __init__(self, num_args, template_slot_locations):
        self.num_args = num_args
2 changes: 2 additions & 0 deletions src/inc/statements.inc.h
@@ -9,6 +9,7 @@ PER_STATEMENT(FrontendAtomicStmt)
PER_STATEMENT(FrontendEvalStmt)
PER_STATEMENT(FrontendSNodeOpStmt) // activate, deactivate, append, clear
PER_STATEMENT(FrontendAssertStmt)
PER_STATEMENT(FrontendArgStoreStmt)

// Midend statement

@@ -35,6 +36,7 @@ PER_STATEMENT(LocalStoreStmt)
PER_STATEMENT(SNodeOpStmt)
PER_STATEMENT(RangeAssumptionStmt)
PER_STATEMENT(AssertStmt)
PER_STATEMENT(ArgStoreStmt)

// SNodeOps
PER_STATEMENT(IntegerOffsetStmt)
34 changes: 32 additions & 2 deletions src/ir.h
@@ -212,8 +212,6 @@ class IRBuilder {
Stmt *get_last_stmt();
};

-IRBuilder &current_ast_uilder();

inline Expr load_if_ptr(const Expr &ptr);
inline Expr smart_load(const Expr &var);

@@ -936,6 +934,38 @@ class ArgLoadExpression : public Expression {
}
};

// For return values
class FrontendArgStoreStmt : public Stmt {
 public:
  int arg_id;
  Expr expr;

  FrontendArgStoreStmt(int arg_id, Expr expr) : arg_id(arg_id), expr(expr) {
  }

  virtual bool has_global_side_effect() const override {
    return false;
  }

  DEFINE_ACCEPT
};

// For return values
class ArgStoreStmt : public Stmt {
 public:
  int arg_id;
  Stmt *val;

  ArgStoreStmt(int arg_id, Stmt *val) : arg_id(arg_id), val(val) {
  }

  virtual bool has_global_side_effect() const override {
    return false;
  }

  DEFINE_ACCEPT
};

class RandStmt : public Stmt {
public:
RandStmt(DataType dt) {
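These two statements follow the frontend/midend pairing used throughout this file (compare FrontendAssertStmt and AssertStmt in statements.inc.h): the frontend node captures an unexpanded Expr while the kernel AST is being built, and the lowering pass rewrites it into the midend node, which points at a concrete flattened statement. A minimal sketch of that relationship, using only factory calls that appear elsewhere in this commit; value_expr is a hypothetical placeholder and this is not a standalone translation unit:

// Sketch only. Frontend: record "store value_expr into return slot 2"
// while tracing the kernel body.
auto front = Stmt::make<FrontendArgStoreStmt>(/*arg_id=*/2, value_expr);

// Lowering (see lower_ast.cpp below): flatten value_expr into a chain of
// statements, then express the store against the last flattened statement.
VecStatement flattened;
value_expr->flatten(flattened);
flattened.push_back(
    Stmt::make<ArgStoreStmt>(/*arg_id=*/2, flattened.back().get()));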
9 changes: 7 additions & 2 deletions src/kernel.cpp
@@ -57,7 +57,8 @@ void Kernel::operator()() {
set_arg_nparray(i, (uint64)device_buffers[i], args[i].size);
cudaMemcpy(device_buffers[i], host_buffers[i], args[i].size,
cudaMemcpyHostToDevice);
-} }
+}
+}
if (has_buffer)
cudaDeviceSynchronize();
auto c = program.get_context();
@@ -131,6 +132,10 @@ void Kernel::set_arg_int(int i, int64 d) {
}
}

void Kernel::mark_arg_return_value(int i, bool is_return) {
  args[i].is_return_value = is_return;
}

void Kernel::set_arg_nparray(int i, uint64 d, uint64 size) {
TC_ASSERT_INFO(args[i].is_nparray,
"Setting numpy array to scalar argument is not allowed");
@@ -144,7 +149,7 @@ void Kernel::set_arch(Arch arch) {
}

int Kernel::insert_arg(DataType dt, bool is_nparray) {
-  args.push_back({dt, is_nparray});
+  args.push_back(Arg{dt, is_nparray, 0, false});
  return args.size() - 1;
}

13 changes: 13 additions & 0 deletions src/kernel.h
@@ -20,6 +20,17 @@ class Kernel {
    DataType dt;
    bool is_nparray;
    std::size_t size;
    bool is_return_value;

    Arg(DataType dt = DataType::unknown,
        bool is_nparray = false,
        std::size_t size = 0,
        bool is_return_value = false)
        : dt(dt),
          is_nparray(is_nparray),
          size(size),
          is_return_value(is_return_value) {
    }
  };
  std::vector<Arg> args;
  bool benchmarking;
@@ -45,6 +56,8 @@ class Kernel {

void set_arg_int(int i, int64 d);

void mark_arg_return_value(int i, bool is_return = true);

void set_arg_nparray(int i, uint64 ptr, uint64 size);

void set_arch(Arch arch);
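With is_return_value on Arg, a kernel hands a result back to the host through a designated argument slot rather than a true return value. A hedged sketch of the intended calling sequence, pieced together from the methods declared here and their uses in Program::get_snode_reader and SNode::read_float below; the exact launch and context details are assumptions:

// Sketch only: a kernel that "returns" one f32 through its last argument slot.
int idx_arg = ker.insert_arg(DataType::i32, /*is_nparray=*/false);  // input index
int ret_arg = ker.insert_arg(DataType::f32, /*is_nparray=*/false);  // result slot
ker.mark_arg_return_value(ret_arg);

ker.set_arg_int(idx_arg, 3);  // pass the input
ker();                        // the body ends with an ArgStoreStmt into ret_arg
auto result = get_current_program().context.get_arg<float32>(ret_arg);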
23 changes: 21 additions & 2 deletions src/program.cpp
@@ -106,7 +106,7 @@ void Program::visualize_layout(const std::string &fn) {
\Tree)";
emit(header);

-  std::function<void(SNode *snode)> visit = [&](SNode *snode) {
+  std::function<void(SNode * snode)> visit = [&](SNode *snode) {
emit("[.{");
if (snode->type == SNodeType::place) {
emit(snode->name);
@@ -196,7 +196,26 @@ void Program::clear_all_gradients() {
}
}

-void Program::get_snode_reader(SNode *snode){TC_NOT_IMPLEMENTED}
Kernel &Program::get_snode_reader(SNode *snode) {
  TC_ASSERT(snode->type == SNodeType::place);
  auto kernel_name = fmt::format("snode_reader_{}", snode->id);
  TC_ASSERT(snode->num_active_indices <= 4);
  auto &ker = kernel([&] {
    ExprGroup indices;
    for (int i = 0; i < snode->num_active_indices; i++) {
      indices.push_back(Expr::make<ArgLoadExpression>(i));
    }
    Stmt::make<FrontendArgStoreStmt>(snode->num_active_indices,
                                     (*snode->expr)[indices]);
  });
  ker.set_arch(get_host_arch());
  ker.name = kernel_name;
  for (int i = 0; i < snode->num_active_indices; i++)
    ker.insert_arg(DataType::i32, false);
  auto ret_val = ker.insert_arg(snode->dt, false);
  ker.mark_arg_return_value(ret_val);
  return ker;
}

Kernel &Program::get_snode_writer(SNode *snode) {
TC_ASSERT(snode->type == SNodeType::place);
2 changes: 1 addition & 1 deletion src/program.h
@@ -170,7 +170,7 @@ class Program {

void clear_all_gradients();

-  void get_snode_reader(SNode *snode);
+  Kernel &get_snode_reader(SNode *snode);

Kernel &get_snode_writer(SNode *snode);

14 changes: 13 additions & 1 deletion src/snode.cpp
@@ -148,7 +148,19 @@ void SNode::write_float(int i, int j, int k, int l, float64 val) {
}

float64 SNode::read_float(int i, int j, int k, int l) {
-  return 0;
  if (reader_kernel == nullptr) {
    reader_kernel = &get_current_program().get_snode_reader(this);
  }
  if (num_active_indices >= 1)
    reader_kernel->set_arg_int(0, i);
  if (num_active_indices >= 2)
    reader_kernel->set_arg_int(1, j);
  if (num_active_indices >= 3)
    reader_kernel->set_arg_int(2, k);
  if (num_active_indices >= 4)
    reader_kernel->set_arg_int(3, l);
  (*reader_kernel)();
  return get_current_program().context.get_arg<float32>(num_active_indices);
}

// for int32 and int64
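The read path in read_float works because the marked return slot travels back through the context's argument buffer, which the host simply reads after the kernel runs. Below is a small self-contained illustration of such a slot-based round trip; it is a toy analogue of the idea (64-bit slots with bit-casts), not Taichi's actual Context implementation:

#include <cstdint>
#include <cstring>
#include <iostream>

// Toy argument buffer: every scalar argument, including "return value"
// slots, lives in a 64-bit cell and is bit-cast in and out.
struct ToyContext {
  std::uint64_t args[8] = {};

  template <typename T>
  void set_arg(int i, T v) {
    std::memcpy(&args[i], &v, sizeof(T));  // bit-cast T into the slot
  }

  template <typename T>
  T get_arg(int i) const {
    T v{};
    std::memcpy(&v, &args[i], sizeof(T));  // bit-cast back out
    return v;
  }
};

int main() {
  ToyContext ctx;
  ctx.set_arg<int>(0, 3);        // index argument, as set_arg_int would pass it
  ctx.set_arg<float>(2, 1.25f);  // pretend the kernel's ArgStore wrote the result here
  std::cout << ctx.get_arg<float>(2) << "\n";  // host reads back the "returned" 1.25
  return 0;
}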
10 changes: 10 additions & 0 deletions src/transforms/ir_printer.cpp
@@ -264,6 +264,16 @@ class IRPrinter : public IRVisitor {
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
}

  void visit(FrontendArgStoreStmt *stmt) override {
    print("{}{} : store arg {} <- {}", stmt->type_hint(), stmt->name(),
          stmt->arg_id, stmt->expr->serialize());
  }

  void visit(ArgStoreStmt *stmt) override {
    print("{}{} : store arg {} <- {}", stmt->type_hint(), stmt->name(),
          stmt->arg_id, stmt->val->name());
  }

void visit(LocalLoadStmt *stmt) override {
print("{}{} = local load [{}]", stmt->type_hint(), stmt->name(),
to_string(stmt->ptr));
11 changes: 11 additions & 0 deletions src/transforms/lower_ast.cpp
@@ -258,6 +258,17 @@ class LowerAST : public IRVisitor {
throw IRModified();
}

  void visit(FrontendArgStoreStmt *stmt) override {
    // expand value
    Stmt *val_stmt = nullptr;
    VecStatement flattened;
    stmt->expr->flatten(flattened);
    flattened.push_back(
        Stmt::make<ArgStoreStmt>(stmt->arg_id, flattened.back().get()));
    stmt->parent->replace_with(stmt, flattened);
    throw IRModified();
  }

static void run(IRNode *node) {
LowerAST inst;
while (true) {
10 changes: 10 additions & 0 deletions src/transforms/type_check.cpp
@@ -251,6 +251,16 @@ class TypeCheck : public IRVisitor {
stmt->ret_type = VectorType(1, args[stmt->arg_id].dt);
}

  void visit(ArgStoreStmt *stmt) {
    auto &args = get_current_program().get_current_kernel().args;
    TC_ASSERT(0 <= stmt->arg_id && stmt->arg_id < args.size());
    auto arg = args[stmt->arg_id];
    auto arg_type = arg.dt;
    TC_ASSERT(arg.is_return_value);
    TC_ASSERT(stmt->val->ret_type.data_type == arg_type);
    stmt->ret_type = VectorType(1, arg_type);
  }

void visit(ExternalPtrStmt *stmt) {
stmt->ret_type = VectorType(stmt->base_ptrs.size(),
stmt->base_ptrs[0]->ret_type.data_type);
