Skip to content

Commit

Permalink
writer kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanming-hu committed Oct 31, 2019
1 parent 661d4d6 commit 689de60
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 51 deletions.
11 changes: 10 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@ matrix:
include:
- os: osx
osx_image: xcode10.3
python:
- "3.7"
env:
- MATRIX_EVAL="CC=clang && CXX=clang++ && PYTHON=python3.6"
- MATRIX_EVAL="CC=clang && CXX=clang++ && PYTHON=python3.7"
include:
- os: osx
osx_image: xcode10.3
python:
- "3.6"
env:
- MATRIX_EVAL="CC=clang && CXX=clang++ && PYTHON=python3.6"
before_install:
- eval "${MATRIX_EVAL}"
- echo $CXX
Expand Down
20 changes: 13 additions & 7 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .core import taichi_lang_core
from .util import is_taichi_class
from .util import *
import traceback


Expand Down Expand Up @@ -164,18 +164,24 @@ def initialize_accessor(self):
num_ind = snode.num_active_indices()
dt_name = taichi_lang_core.data_type_short_name(snode.data_type())
self.getter = getattr(self.ptr, 'val{}_{}'.format(num_ind, dt_name))
self.setter = getattr(self.ptr, 'set_val{}_{}'.format(num_ind, dt_name))
if self.snode().data_type() == f32 or self.snode().data_type() == f64:
def setter(value, *key):
self.snode().ptr.write_float(key[0], key[1], key[2], key[3], value)
else:
def setter(value, *key):
self.snode().ptr.write_int(key[0], key[1], key[2], key[3], value)
self.setter = setter

def __setitem__(self, key, value):
if not Expr.layout_materialized:
self.materialize_layout_callback()
self.initialize_accessor()
if key is None:
self.setter(value)
else:
if not isinstance(key, tuple):
key = (key, )
self.setter(value, *key)
key = ()
if not isinstance(key, tuple):
key = (key, )
key = key + ((0, ) * (4 - len(key)))
self.setter(value, *key)

def __getitem__(self, key):
if not Expr.layout_materialized:
Expand Down
3 changes: 3 additions & 0 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ def lazy_grad(self):

# Returns a new SNode wrapper built from self.ptr.snode().parent.
def parent(self):
return SNode(self.ptr.snode().parent)

# Returns whatever data_type() reports on the wrapped object (self.ptr).
def data_type(self):
return self.ptr.data_type()
33 changes: 14 additions & 19 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::string capitalize_first(std::string s) {
std::string latex_short_digit(int v) {
std::string units = "KMGT";
int unit_id = -1;
while (v >= 1024 && unit_id + 1 < (int) units.size()) {
while (v >= 1024 && unit_id + 1 < (int)units.size()) {
TC_ASSERT(v % 1024 == 0);
v /= 1024;
unit_id++;
Expand Down Expand Up @@ -136,7 +136,7 @@ void Program::visualize_layout(const std::string &fn) {
}
emit("} ");

for (int i = 0; i < (int) snode->ch.size(); i++) {
for (int i = 0; i < (int)snode->ch.size(); i++) {
visit(snode->ch[i].get());
}
emit("]");
Expand Down Expand Up @@ -199,30 +199,25 @@ void Program::clear_all_gradients() {
}
}

void Program::get_snode_reader(SNode *snode) {
TC_NOT_IMPLEMENTED
}
void Program::get_snode_reader(SNode *snode){TC_NOT_IMPLEMENTED}

std::function<void(int, int)> Program::get_snode_writer(SNode *snode) {
Kernel &Program::get_snode_writer(SNode *snode) {
TC_ASSERT(snode->type == SNodeType::place);
auto kernel_name = fmt::format("snode_writer_{}", snode->id);
TC_ASSERT(snode->num_active_indices <= 4);
auto &ker = kernel([&] {
if (snode->num_active_indices == 1) {
(*snode->expr)[Expr::make<ArgLoadExpression>(0)] =
Expr::make<ArgLoadExpression>(1);
} else {
TC_NOT_IMPLEMENTED;
ExprGroup indices;
for (int i = 0; i < snode->num_active_indices; i++) {
indices.push_back(Expr::make<ArgLoadExpression>(i));
}
(*snode->expr)[indices] =
Expr::make<ArgLoadExpression>(snode->num_active_indices);
});
ker.name = kernel_name;
ker.insert_arg(DataType::i32, false);
ker.insert_arg(DataType::i32, false);
auto writer = [&](int i, int val) {
ker.set_arg_int(0, i);
ker.set_arg_int(1, val);
ker();
};
return writer;
for (int i = 0; i < snode->num_active_indices; i++)
ker.insert_arg(DataType::i32, false);
ker.insert_arg(snode->dt, false);
return ker;
}

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion src/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class Program {

void get_snode_reader(SNode *snode);

std::function<void(int, int)> get_snode_writer(SNode *snode);
Kernel &get_snode_writer(SNode *snode);
};

TLANG_NAMESPACE_END
2 changes: 2 additions & 0 deletions src/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ void export_lang(py::module &m) {
py::return_value_policy::reference)
.def("data_type", [](SNode *snode) { return snode->dt; })
.def("lazy_grad", &SNode::lazy_grad)
.def("write_int", &SNode::write_int)
.def("write_float", &SNode::write_float)
.def("num_active_indices",
[](SNode *snode) { return snode->num_active_indices; });

Expand Down
48 changes: 47 additions & 1 deletion src/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ bool SNode::is_primal() const {
}

bool SNode::has_grad() const {
return is_primal() && (*expr).cast<GlobalVariableExpression>()->adjoint.expr != nullptr;
return is_primal() &&
(*expr).cast<GlobalVariableExpression>()->adjoint.expr != nullptr;
}

SNode *SNode::get_grad() const {
Expand All @@ -129,4 +130,49 @@ SNode *SNode::get_grad() const {
->snode;
}

// Writes a floating-point value (float or double) into this node by
// launching its writer kernel.
//
// i, j, k, l: element coordinates; only the first num_active_indices of
//             them (at most four) are forwarded to the kernel.
// val:        value to store.
void SNode::write_float(int i, int j, int k, int l, float64 val) {
  // Fetch the writer kernel from the current program on first use and keep
  // the pointer on this node for later calls.
  if (!writer_kernel) {
    writer_kernel = &get_current_program().get_snode_writer(this);
  }

  const int coords[4] = {i, j, k, l};
  for (int axis = 0; axis < 4 && axis < num_active_indices; axis++) {
    writer_kernel->set_arg_int(axis, coords[axis]);
  }

  // The value occupies the argument slot right after the index arguments.
  writer_kernel->set_arg_float(num_active_indices, val);
  (*writer_kernel)();
}

// Placeholder reader for float/double values: the coordinates are ignored
// and the result is always zero.
float64 SNode::read_float(int i, int j, int k, int l) {
  const float64 placeholder = 0;
  return placeholder;
}

// Writes an integer value (int32 or int64) into this node by launching its
// writer kernel.
//
// i, j, k, l: element coordinates; only the first num_active_indices of
//             them (at most four) are forwarded to the kernel.
// val:        value to store.
void SNode::write_int(int i, int j, int k, int l, int64 val) {
  // Fetch the writer kernel from the current program on first use and keep
  // the pointer on this node for later calls.
  if (writer_kernel == nullptr) {
    writer_kernel = &get_current_program().get_snode_writer(this);
  }
  if (num_active_indices >= 1)
    writer_kernel->set_arg_int(0, i);
  if (num_active_indices >= 2)
    writer_kernel->set_arg_int(1, j);
  if (num_active_indices >= 3)
    writer_kernel->set_arg_int(2, k);
  if (num_active_indices >= 4)
    writer_kernel->set_arg_int(3, l);
  // The value occupies the argument slot right after the index arguments.
  // It is an integer, so it goes through the integer setter (the float
  // setter belongs to write_float).
  // NOTE(review): assumes set_arg_int accepts the full int64 range -- confirm
  // against the Kernel declaration.
  writer_kernel->set_arg_int(num_active_indices, val);
  (*writer_kernel)();
}
// Placeholder reader for int32/int64 values: the coordinates are ignored
// and the result is always zero.
int64 SNode::read_int(int i, int j, int k, int l) {
  const int64 placeholder = 0;
  return placeholder;
}

TLANG_NAMESPACE_END
14 changes: 14 additions & 0 deletions src/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class Index {
}
};

class Kernel;
// "Structural" nodes
class SNode {
public:
Expand Down Expand Up @@ -79,6 +80,8 @@ class SNode {
TypedConstant ambient_val;
// Note: parent will not be set until structural nodes are compiled!
SNode *parent;
Kernel *reader_kernel;
Kernel *writer_kernel;
std::unique_ptr<Expr> expr;

std::string data_type_name() {
Expand Down Expand Up @@ -137,6 +140,9 @@ class SNode {

llvm_type = nullptr;
llvm_element_type = nullptr;

reader_kernel = nullptr;
writer_kernel = nullptr;
}

SNode &insert_children(SNodeType t) {
Expand Down Expand Up @@ -255,6 +261,14 @@ class SNode {
return access_func(ds, i, j, k, l);
}

// Element accessors for floating-point values (float and double).
// (i, j, k, l) are the element coordinates; the last parameter of
// write_float is the value to store.
void write_float(int i, int j, int k, int l, float64);
float64 read_float(int i, int j, int k, int l);

// Element accessors for integer values (int32 and int64).
// (i, j, k, l) are the element coordinates; the last parameter of
// write_int is the value to store.
void write_int(int i, int j, int k, int l, int64);
int64 read_int(int i, int j, int k, int l);

TC_FORCE_INLINE AllocatorStat stat() {
TC_ASSERT(stat_func);
return stat_func();
Expand Down
22 changes: 0 additions & 22 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,6 @@ def place():
assert x[i] == i
assert y[i] == i + 123

# Fills x through a writer obtained from prog.get_snode_writer and y through
# direct indexing, then checks that both read back the stored values.
@ti.program_test
def test_writer():
x = ti.var(ti.i32)
y = ti.var(ti.i32)

n = 128

@ti.layout
def place():
ti.root.dense(ti.i, n).place(x)
ti.root.dense(ti.i, n).place(y)

# Touch x once before requesting the writer.
x[0] = 0
writer = ti.get_runtime().prog.get_snode_writer(x.ptr.snode())

for i in range(n):
writer(i, i * 2)
y[i] = i + 123

for i in range(n):
assert x[i] == i * 2
assert y[i] == i + 123

def test_linear_repeated():
for i in range(10):
Expand Down

0 comments on commit 689de60

Please sign in to comment.