Skip to content

Commit

Permalink
[Metal] Support assert() (taichi-dev#1959)
Browse files Browse the repository at this point in the history
* [Metal] Support assert()

* simplify
  • Loading branch information
k-ye authored Oct 16, 2020
1 parent 3c2e5e9 commit 6fea4c2
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 47 deletions.
59 changes: 54 additions & 5 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ constexpr char kContextBufferName[] = "ctx_addr";
constexpr char kContextVarName[] = "kernel_ctx_";
constexpr char kRuntimeBufferName[] = "runtime_addr";
constexpr char kRuntimeVarName[] = "runtime_";
constexpr char kPrintBufferName[] = "print_addr";
constexpr char kPrintAssertBufferName[] = "print_assert_addr";
constexpr char kPrintAllocVarName[] = "print_alloc_";
constexpr char kAssertRecorderVarName[] = "assert_rec_";
constexpr char kLinearLoopIndexName[] = "linear_loop_idx_";
constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kRandStateVarName[] = "rand_state_";
Expand All @@ -60,7 +61,7 @@ std::string buffer_to_name(BuffersEnum b) {
case BuffersEnum::Runtime:
return kRuntimeBufferName;
case BuffersEnum::Print:
return kPrintBufferName;
return kPrintAssertBufferName;
default:
TI_NOT_IMPLEMENTED;
break;
Expand Down Expand Up @@ -639,6 +640,46 @@ class KernelCodegen : public IRVisitor {
emit("}}");
}

void visit(AssertStmt *stmt) override {
used_features()->assertion = true;

const auto &args = stmt->args;
// +1 because the assertion message template itself takes one slot
const auto num_args = args.size() + 1;
TI_ASSERT_INFO(num_args <= shaders::kMetalMaxNumAssertArgs,
"[Metal] Too many args in assert()");
emit("if (!({})) {{", stmt->cond->raw_name());
{
ScopedIndent s(current_appender());
// Only record the message for the first-time assertion failure.
emit("if ({}.mark_first_failure()) {{", kAssertRecorderVarName);
{
ScopedIndent s2(current_appender());
emit("{}.set_num_args({});", kAssertRecorderVarName, num_args);
const std::string asst_var_name = stmt->raw_name() + "_msg_";
emit("PrintMsg {}({}.msg_buf_addr(), {});", asst_var_name,
kAssertRecorderVarName, num_args);
const int msg_str_id = print_strtab_->put(stmt->text);
emit("{}.pm_set_str(/*i=*/0, {});", asst_var_name, msg_str_id);
for (int i = 1; i < num_args; ++i) {
auto *arg = args[i - 1];
const auto ty = arg->element_type();
if (ty == PrimitiveType::i32 || ty == PrimitiveType::f32) {
emit("{}.pm_set_{}({}, {});", asst_var_name,
data_type_short_name(ty), i, arg->raw_name());
} else {
TI_ERROR(
"[Metal] assert() only supports i32 or f32 scalars for now.");
}
}
}
emit("}}");
// This has failed, no point executing the rest of the kernel.
emit("return;");
}
emit("}}");
}

void visit(StackAllocaStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);

Expand Down Expand Up @@ -1125,9 +1166,17 @@ class KernelCodegen : public IRVisitor {
fmt::arg("rtm", kRuntimeVarName),
fmt::arg("lidx", kLinearLoopIndexName),
fmt::arg("nums", kNumRandSeeds));
// Init PrintMsgAllocator
emit("device auto* {} = reinterpret_cast<device PrintMsgAllocator*>({});",
kPrintAllocVarName, kPrintBufferName);
// Init AssertRecorder.
emit("AssertRecorder {}({});", kAssertRecorderVarName,
kPrintAssertBufferName);
// Init PrintMsgAllocator.
// The print buffer comes after (AssertRecorder + assert message buffer),
// therefore we skip by +|kMetalAssertBufferSize|.
emit(
"device auto* {} = reinterpret_cast<device PrintMsgAllocator*>({} + "
"{});",
kPrintAllocVarName, kPrintAssertBufferName,
shaders::kMetalAssertBufferSize);
}
// We do not need additional indentation, because |func_ir| itself is a
// block, which will be indented automatically.
Expand Down
87 changes: 69 additions & 18 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "taichi/util/action_recorder.h"
#include "taichi/python/print_buffer.h"
#include "taichi/util/file_sequence_writer.h"
#include "taichi/util/str.h"

#ifdef TI_PLATFORM_OSX
#include <sys/mman.h>
Expand Down Expand Up @@ -53,7 +54,7 @@ class BufferMemoryView {
BufferMemoryView(size_t size, MemoryPool *mem_pool) {
// Both |ptr_| and |size_| must be aligned to page size.
size_ = iroundup(size, taichi_page_size);
ptr_ = mem_pool->allocate(size_, /*alignment=*/taichi_page_size);
ptr_ = (char *)mem_pool->allocate(size_, /*alignment=*/taichi_page_size);
TI_ASSERT(ptr_ != nullptr);
std::memset(ptr_, 0, size_);
}
Expand All @@ -66,13 +67,13 @@ class BufferMemoryView {
inline size_t size() const {
return size_;
}
inline void *ptr() const {
inline char *ptr() const {
return ptr_;
}

private:
size_t size_;
void *ptr_;
char *ptr_;
};

// MetalRuntime maintains a series of MTLBuffers that are shared across all the
Expand Down Expand Up @@ -578,14 +579,13 @@ class KernelManager::Impl {
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
print_mem_ = std::make_unique<BufferMemoryView>(
sizeof(shaders::PrintMsgAllocator) + shaders::kMetalPrintBufferSize,
mem_pool_);
shaders::kMetalPrintAssertBufferSize, mem_pool_);
print_buffer_ = new_mtl_buffer_no_copy(device_.get(), print_mem_->ptr(),
print_mem_->size());
TI_ASSERT(print_buffer_ != nullptr);

init_runtime(params.root_id);
init_print_buffer();
clear_print_assert_buffer();
}

void register_taichi_kernel(const std::string &taichi_kernel_name,
Expand Down Expand Up @@ -640,24 +640,30 @@ class KernelManager::Impl {
for (const auto &mk : ctk.compiled_mtl_kernels) {
mk->launch(input_buffers, cur_command_buffer_.get());
}
const bool used_print = ctk.ti_kernel_attribs.used_features.print;
if (ctx_blitter || used_print) {

const auto &used = ctk.ti_kernel_attribs.used_features;
const bool used_print_assert = (used.print || used.assertion);
if (ctx_blitter || used_print_assert) {
// TODO(k-ye): One optimization is to synchronize only when we absolutely
// need to transfer the data back to host. This includes the cases where
// an arg is 1) an array, or 2) used as return value.
std::vector<MTLBuffer *> buffers_to_blit;
if (ctx_blitter) {
buffers_to_blit.push_back(ctx_blitter->ctx_buffer());
}
if (used_print) {
if (used_print_assert) {
clear_print_assert_buffer();
buffers_to_blit.push_back(print_buffer_.get());
}
blit_buffers_and_sync(buffers_to_blit);

if (ctx_blitter) {
ctx_blitter->metal_to_host();
}
if (used_print) {
if (used.assertion) {
check_assertion_failure();
}
if (used.print) {
flush_print_buffers();
}
}
Expand Down Expand Up @@ -801,9 +807,10 @@ class KernelManager::Impl {
runtime_mem_->size());
}

void init_print_buffer() {
// TODO(k-ye): Do we need this at all?
did_modify_range(print_buffer_.get(), /*location=*/0, print_mem_->size());
void clear_print_assert_buffer() {
const auto sz = print_mem_->size();
std::memset(print_mem_->ptr(), 0, sz);
did_modify_range(print_buffer_.get(), /*location=*/0, sz);
}

void blit_buffers_and_sync(
Expand All @@ -828,10 +835,51 @@ class KernelManager::Impl {
profiler_->stop();
}

void check_assertion_failure() {
// TODO: Copy this to program's result_buffer, and let the Taichi runtime
// handle the assertion failures uniformly.
auto *asst_rec =
reinterpret_cast<shaders::AssertRecorderData *>(print_mem_->ptr());
if (!asst_rec->flag) {
return;
}
auto *msg_ptr = reinterpret_cast<int32_t *>(asst_rec + 1);
shaders::PrintMsg msg(msg_ptr, asst_rec->num_args);
using MsgType = shaders::PrintMsg::Type;
TI_ASSERT(msg.pm_get_type(0) == MsgType::Str);
const auto fmt_str = print_strtable_.get(msg.pm_get_data(0));
const auto err_str = format_error_message(fmt_str, [&msg](int argument_id) {
// +1 to skip the first arg, which is the error message template.
const int32 x = msg.pm_get_data(argument_id + 1);
return taichi_union_cast_with_different_sizes<uint64>(x);
});
// Note that we intentionally comment out the flag reset below, because it
// is ineffective at all. This is a very tricky part:
// 1. Under .managed storage mode, we need to call [didModifyRange:] to sync
// buffer data from CPU -> GPU. So ideally, after resetting the flag, we
// should just do so.
// 2. However, during the assertion (TI_ERROR), the stack unwinding seems to
// have deviated from the normal execution path. As a result, if we put
// [didModifyRange:] after TI_ERROR, it doesn't get executed...
// 3. The reason we put [didModifyRange:] after TI_ERROR is because we
// should do so after flush_print_buffers():
//
// check_assertion_failure(); <-- Code below is skipped...
// flush_print_buffers();
// memset(print_mem_->ptr(), 0, print_mem_->size());
// did_modify_range(print_buffer_);
//
// As a workaround, we put [didModifyRange:] before sync, where the program
// is still executing normally.
// asst_rec->flag = 0;
TI_ERROR("Assertion failure: {}", err_str);
}

void flush_print_buffers() {
auto *pa =
reinterpret_cast<shaders::PrintMsgAllocator *>(print_mem_->ptr());
const int used_sz = std::min(pa->next, shaders::kMetalPrintBufferSize);
auto *pa = reinterpret_cast<shaders::PrintMsgAllocator *>(
print_mem_->ptr() + shaders::kMetalAssertBufferSize);
const int used_sz =
std::min(pa->next, shaders::kMetalPrintMsgsMaxQueueSize);
using MsgType = shaders::PrintMsg::Type;
char *buf = reinterpret_cast<char *>(pa + 1);
const char *buf_end = buf + used_sz;
Expand All @@ -857,11 +905,13 @@ class KernelManager::Impl {
buf += shaders::mtl_compute_print_msg_bytes(num_entries);
}

if (pa->next >= shaders::kMetalPrintBufferSize) {
if (pa->next >= shaders::kMetalPrintMsgsMaxQueueSize) {
py_cout << "...(maximum print buffer reached)\n";
}

pa->next = 0;
// Comment out intentionally since it is ineffective otherwise. See
// check_assertion_failure() for the explanation.
// pa->next = 0;
}

static int compute_num_elems_per_chunk(int n) {
Expand Down Expand Up @@ -902,6 +952,7 @@ class KernelManager::Impl {
nsobj_unique_ptr<MTLBuffer> global_tmps_buffer_;
std::unique_ptr<BufferMemoryView> runtime_mem_;
nsobj_unique_ptr<MTLBuffer> runtime_buffer_;
// TODO: Rename these to 'print_assert_{mem|buffer}_'
std::unique_ptr<BufferMemoryView> print_mem_;
nsobj_unique_ptr<MTLBuffer> print_buffer_;
std::unordered_map<std::string, std::unique_ptr<CompiledTaichiKernel>>
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct TaichiKernelAttributes {
struct UsedFeatures {
// Whether print() is called inside this kernel.
bool print = false;
// Whether assert is called inside this kernel.
bool assertion = false;
// Whether this kernel accesses (read or write) sparse SNodes.
bool sparse = false;
// Whether [[thread_index_in_simdgroup]] is used. This is only supported
Expand Down
Loading

0 comments on commit 6fea4c2

Please sign in to comment.