Skip to content

Commit

Permalink
reader kernel; reproduce test_io failure
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Nov 1, 2019
1 parent e736819 commit 2f19ac3
Showing 11 changed files with 102 additions and 52 deletions.
1 change: 0 additions & 1 deletion examples/mpm_lagrangian_forces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import taichi as ti
import os
import random

real = ti.f32
dim = 2
24 changes: 14 additions & 10 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -161,15 +161,19 @@ def initialize_accessor(self):
if self.getter:
return
snode = self.ptr.snode()
num_ind = snode.num_active_indices()
dt_name = taichi_lang_core.data_type_short_name(snode.data_type())
self.getter = getattr(self.ptr, 'val{}_{}'.format(num_ind, dt_name))

if self.snode().data_type() == f32 or self.snode().data_type() == f64:
def getter(*key):
return snode.read_float(key[0], key[1], key[2], key[3])
def setter(value, *key):
self.snode().ptr.write_float(key[0], key[1], key[2], key[3], value)
snode.write_float(key[0], key[1], key[2], key[3], value)
else:
def getter(*key):
return snode.read_int(key[0], key[1], key[2], key[3])
def setter(value, *key):
self.snode().ptr.write_int(key[0], key[1], key[2], key[3], value)
snode.write_int(key[0], key[1], key[2], key[3], value)

self.getter = getter
self.setter = setter

def __setitem__(self, key, value):
@@ -188,11 +192,11 @@ def __getitem__(self, key):
self.materialize_layout_callback()
self.initialize_accessor()
if key is None:
return self.getter()
else:
if not isinstance(key, tuple):
key = (key, )
return self.getter(*key)
key = ()
if not isinstance(key, tuple):
key = (key, )
key = key + ((0, ) * (4 - len(key)))
return self.getter(*key)

def loop_range(self):
return self
20 changes: 19 additions & 1 deletion src/backends/codegen_llvm.h
Original file line number Diff line number Diff line change
@@ -214,7 +214,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
task.compile();
}
auto offloaded_tasks_local = offloaded_tasks;
return [=](Context context) {
return [=](Context &context) {
for (auto task : offloaded_tasks_local) {
task(&context);
}
@@ -688,6 +688,24 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
}
}

void visit(ArgStoreStmt *stmt) {
if (stmt->is_ptr) {
TC_NOT_IMPLEMENTED
} else {
auto intermediate_bits =
tlctx->get_data_type(stmt->val->ret_type.data_type)
->getPrimitiveSizeInBits();
llvm::Type *intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
auto extended = builder->CreateZExt(
builder->CreateBitCast(stmt->val->value, intermediate_type), dest_ty);
builder->CreateCall(
get_runtime_function("Context_set_args"),
{get_context(), tlctx->get_constant(stmt->arg_id), extended});
}
}

void visit(LocalLoadStmt *stmt) {
TC_ASSERT(stmt->width() == 1);
stmt->value = builder->CreateLoad(stmt->ptr[0].var->value);
7 changes: 5 additions & 2 deletions src/ir.h
Original file line number Diff line number Diff line change
@@ -943,8 +943,9 @@ class FrontendArgStoreStmt : public Stmt {
FrontendArgStoreStmt(int arg_id, Expr expr) : arg_id(arg_id), expr(expr) {
}

// Arguments are considered global (nonlocal)
virtual bool has_global_side_effect() const override {
return false;
return true;
}

DEFINE_ACCEPT
@@ -957,10 +958,12 @@ class ArgStoreStmt : public Stmt {
Stmt *val;

ArgStoreStmt(int arg_id, Stmt *val) : arg_id(arg_id), val(val) {
add_operand(this->val);
}

// Arguments are considered global (nonlocal)
virtual bool has_global_side_effect() const override {
return false;
return true;
}

DEFINE_ACCEPT
5 changes: 3 additions & 2 deletions src/program.cpp
Original file line number Diff line number Diff line change
@@ -205,8 +205,9 @@ Kernel &Program::get_snode_reader(SNode *snode) {
for (int i = 0; i < snode->num_active_indices; i++) {
indices.push_back(Expr::make<ArgLoadExpression>(i));
}
Stmt::make<FrontendArgStoreStmt>(snode->num_active_indices,
(*snode->expr)[indices]);
auto ret = Stmt::make<FrontendArgStoreStmt>(
snode->num_active_indices, load_if_ptr((*snode->expr)[indices]));
current_ast_builder().insert(std::move(ret));
});
ker.set_arch(get_host_arch());
ker.name = kernel_name;
27 changes: 2 additions & 25 deletions src/python_bindings.cpp
Original file line number Diff line number Diff line change
@@ -24,25 +24,6 @@ void expr_assign(const Expr &lhs_, const Expr &rhs, std::string tb) {

std::vector<std::unique_ptr<IRBuilder::ScopeGuard>> scope_stack;

template <typename T, typename C>
void export_accessors(C &c) {
c.def(
fmt::format("val0_{}", data_type_short_name(get_data_type<T>())).c_str(),
&Expr::val<T>);
c.def(
fmt::format("val1_{}", data_type_short_name(get_data_type<T>())).c_str(),
&Expr::val<T, int>);
c.def(
fmt::format("val2_{}", data_type_short_name(get_data_type<T>())).c_str(),
&Expr::val<T, int, int>);
c.def(
fmt::format("val3_{}", data_type_short_name(get_data_type<T>())).c_str(),
&Expr::val<T, int, int, int>);
c.def(
fmt::format("val4_{}", data_type_short_name(get_data_type<T>())).c_str(),
&Expr::val<T, int, int, int, int>);
}

void compile_runtimes();
std::string libdevice_path();

@@ -117,6 +98,8 @@ void export_lang(py::module &m) {
py::return_value_policy::reference)
.def("data_type", [](SNode *snode) { return snode->dt; })
.def("lazy_grad", &SNode::lazy_grad)
.def("read_int", &SNode::read_int)
.def("read_float", &SNode::read_float)
.def("write_int", &SNode::write_int)
.def("write_float", &SNode::write_float)
.def("num_active_indices",
@@ -141,12 +124,6 @@ void export_lang(py::module &m) {
.def("set_grad", &Expr::set_grad)
.def("get_raw_address", [](Expr *expr) { return (uint64)expr; });

export_accessors<int32>(expr);
export_accessors<int64>(expr);

export_accessors<float32>(expr);
export_accessors<float64>(expr);

py::class_<ExprGroup>(m, "ExprGroup")
.def(py::init<>())
.def("push_back", &ExprGroup::push_back)
43 changes: 34 additions & 9 deletions src/snode.cpp
Original file line number Diff line number Diff line change
@@ -148,19 +148,25 @@ void SNode::write_float(int i, int j, int k, int l, float64 val) {
}

float64 SNode::read_float(int i, int j, int k, int l) {
if (writer_kernel == nullptr) {
writer_kernel = &get_current_program().get_snode_writer(this);
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
}
if (num_active_indices >= 1)
writer_kernel->set_arg_int(0, i);
reader_kernel->set_arg_int(0, i);
if (num_active_indices >= 2)
writer_kernel->set_arg_int(1, j);
reader_kernel->set_arg_int(1, j);
if (num_active_indices >= 3)
writer_kernel->set_arg_int(2, k);
reader_kernel->set_arg_int(2, k);
if (num_active_indices >= 4)
writer_kernel->set_arg_int(3, l);
(*writer_kernel)();
return get_current_program().context.get_arg<float32>(num_active_indices);
reader_kernel->set_arg_int(3, l);
(*reader_kernel)();
if (dt == DataType::f32) {
return get_current_program().context.get_arg<float32>(num_active_indices);
} else if (dt == DataType::f64) {
return get_current_program().context.get_arg<float64>(num_active_indices);
} else {
TC_NOT_IMPLEMENTED
}
}

// for int32 and int64
@@ -179,8 +185,27 @@ void SNode::write_int(int i, int j, int k, int l, int64 val) {
writer_kernel->set_arg_float(num_active_indices, val);
(*writer_kernel)();
}

int64 SNode::read_int(int i, int j, int k, int l) {
return 0;
if (reader_kernel == nullptr) {
reader_kernel = &get_current_program().get_snode_reader(this);
}
if (num_active_indices >= 1)
reader_kernel->set_arg_int(0, i);
if (num_active_indices >= 2)
reader_kernel->set_arg_int(1, j);
if (num_active_indices >= 3)
reader_kernel->set_arg_int(2, k);
if (num_active_indices >= 4)
reader_kernel->set_arg_int(3, l);
(*reader_kernel)();
if (dt == DataType::i32) {
return get_current_program().context.get_arg<int32>(num_active_indices);
} else if (dt == DataType::i64) {
return get_current_program().context.get_arg<int64>(num_active_indices);
} else {
TC_NOT_IMPLEMENTED
}
}

TLANG_NAMESPACE_END
6 changes: 6 additions & 0 deletions src/taichi_llvm_context.cpp
Original file line number Diff line number Diff line change
@@ -63,6 +63,12 @@ TaichiLLVMContext::~TaichiLLVMContext() {
llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
if (dt == DataType::i32) {
return llvm::Type::getInt32Ty(*ctx);
} else if (dt == DataType::i8) {
return llvm::Type::getInt8Ty(*ctx);
} else if (dt == DataType::i16) {
return llvm::Type::getInt16Ty(*ctx);
} else if (dt == DataType::i64) {
return llvm::Type::getInt64Ty(*ctx);
} else if (dt == DataType::f32) {
return llvm::Type::getFloatTy(*ctx);
} else if (dt == DataType::f64) {
2 changes: 1 addition & 1 deletion src/util.h
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ real measure_cpe(std::function<void()> target,

struct Context;

using FunctionType = std::function<void(Context)>;
using FunctionType = std::function<void(Context &)>;

enum class DataType : int {
f16,
1 change: 1 addition & 0 deletions tests/python/test_abs.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
@ti.program_test
def test_abs():
ti.reset()
ti.cfg.print_ir = True
x = ti.var(ti.f32)
y = ti.var(ti.f32)

18 changes: 17 additions & 1 deletion tests/python/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import taichi as ti

@ti.program_test

def test_simple():
x = ti.var(ti.i32)

@@ -44,3 +43,20 @@ def func():
assert x[i] == i + 123


@ti.program_test
def test_io():
ti.cfg.print_ir = True
x = ti.var(ti.i32)

n = 128

@ti.layout
def place():
ti.root.dense(ti.i, n).place(x)

x[3] = 123
x[4] = 456
assert x[3] == 123
assert x[4] == 456

test_io()

0 comments on commit 2f19ac3

Please sign in to comment.