Skip to content

Commit

Permalink
[opengl] [refactor] Reduce SSBO numbers: merge earg with args (taichi…
Browse files Browse the repository at this point in the history
  • Loading branch information
archibate authored Nov 3, 2020
1 parent eb3e251 commit 2f809aa
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
16 changes: 7 additions & 9 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,6 @@ class KernelGen : public IRVisitor {
if (used.int64)
kernel_header += "layout(std430, binding = 2) buffer args_i64 { int64_t _args_i64_[]; };\n";
}
if (used.buf_earg) {
kernel_header +=
"layout(std430, binding = 3) buffer earg_i32 { int _earg_i32_[]; };\n";
}
if (used.buf_extr) {
kernel_header +=
"layout(std430, binding = 4) buffer extr_i32 { int _extr_i32_[]; };\n"
Expand Down Expand Up @@ -427,9 +423,10 @@ class KernelGen : public IRVisitor {
const int num_indices = stmt->indices.size();
std::vector<std::string> size_var_names;
for (int i = 0; i < num_indices; i++) {
used.buf_earg = true;
used.buf_args = true;
std::string var_name = fmt::format("_s{}_{}", i, stmt->short_name());
emit("int {} = _earg_i32_[{} * {} + {}];", var_name, arg_id,
emit("int {} = _args_i32_[{} + {} * {} + {}];", var_name,
taichi_opengl_earg_base / sizeof(int), arg_id,
taichi_max_num_indices, i);
size_var_names.push_back(std::move(var_name));
}
Expand Down Expand Up @@ -693,9 +690,10 @@ class KernelGen : public IRVisitor {
const auto name = stmt->short_name();
const auto arg_id = stmt->arg_id;
const auto axis = stmt->axis;
used.buf_earg = true;
emit("int {} = _earg_i32_[{} * {} + {}];", name, arg_id,
taichi_max_num_indices, axis);
used.buf_args = true;
emit("int {} = _args_i32_[{} + {} * {} + {}];", name,
taichi_opengl_earg_base / sizeof(int), arg_id, taichi_max_num_indices,
axis);
}

std::string make_kernel_name() {
Expand Down
27 changes: 16 additions & 11 deletions taichi/backends/opengl/opengl_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,15 @@ struct GLBuffer : GLSSBO {
}

void copy_back() {
copy_back(base);
copy_back(base, size);
}

void copy_back(void *ptr) {
if (!size)
void copy_back(void *ptr, size_t len) {
if (!len)
return;
void *mapped = this->map();
TI_ASSERT(mapped);
std::memcpy(ptr, mapped, size);
std::memcpy(ptr, mapped, len);
this->unmap();
}
};
Expand Down Expand Up @@ -528,11 +528,15 @@ struct CompiledProgram::Impl {
GLBufferTable &bufs = launcher->impl->user_bufs;
std::vector<char> base_arr;
std::vector<void *> saved_ctx_ptrs;
std::vector<char> args;
args.resize(std::max(arg_count, ret_count) * sizeof(uint64_t));
// NOTE: these dirty codes are introduced by #694, TODO: RAII
/// DIRTY_BEGIN {{{
if (ext_arr_map.size()) {
bufs.add_buffer(GLBufId::Earg, ctx.extra_args,
arg_count * taichi_max_num_args * sizeof(int));
args.resize(taichi_opengl_earg_base +
arg_count * taichi_max_num_indices * sizeof(int));
std::memcpy(args.data() + taichi_opengl_earg_base, ctx.extra_args,
arg_count * taichi_max_num_indices * sizeof(int));
if (ext_arr_map.size() == 1) { // zero-copy for only one ext_arr
auto it = ext_arr_map.begin();
auto extptr = (void *)ctx.args[it->first];
Expand All @@ -558,8 +562,8 @@ struct CompiledProgram::Impl {
}
}
/// DIRTY_END }}}
bufs.add_buffer(GLBufId::Args, ctx.args,
std::max(arg_count, ret_count) * sizeof(uint64_t));
std::memcpy(args.data(), ctx.args, arg_count * sizeof(uint64_t));
bufs.add_buffer(GLBufId::Args, args.data(), args.size());
if (used.print) {
// TODO(archibate): use result_buffer for print results
auto runtime_buf = launcher->impl->core_bufs.get(GLBufId::Runtime);
Expand All @@ -571,10 +575,11 @@ struct CompiledProgram::Impl {
ker->dispatch_compute(launcher);
}
for (auto &[idx, buf] : launcher->impl->user_bufs.bufs) {
if (buf->index == GLBufId::Args)
buf->copy_back(launcher->result_buffer);
else
if (buf->index == GLBufId::Args) {
buf->copy_back(launcher->result_buffer, ret_count * sizeof(uint64_t));
} else {
buf->copy_back();
}
}
launcher->impl->user_bufs.clear();
if (used.print) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/opengl/opengl_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class SNode;

namespace opengl {

constexpr int taichi_opengl_earg_base = taichi_max_num_args * sizeof(uint64_t);

struct UsedFeature {
// types:
bool simulated_atomic_float{false};
Expand Down Expand Up @@ -60,7 +62,6 @@ enum class GLBufId {
Listman = 7,
Gtmp = 1,
Args = 2,
Earg = 3,
Extr = 4,
};

Expand Down

0 comments on commit 2f809aa

Please sign in to comment.