Skip to content

Commit

Permalink
[type] Support bit-level read and write in Python-scope (taichi-dev#2029
Browse files Browse the repository at this point in the history
)

Co-authored-by: Yuanming Hu <[email protected]>
Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Xuanda Yang <[email protected]>
  • Loading branch information
4 people authored Nov 7, 2020
1 parent 9ea17a6 commit 36232a5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
18 changes: 15 additions & 3 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,14 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
} else {
TI_ASSERT(!stmt->ret_type->is<PointerType>());
dest_ty = tlctx->get_data_type(stmt->ret_type);
if (auto cit = stmt->ret_type->cast<CustomIntType>()) {
if (cit->get_is_signed())
dest_ty = tlctx->get_data_type(PrimitiveType::i32);
else
dest_ty = tlctx->get_data_type(PrimitiveType::u32);
} else {
dest_ty = tlctx->get_data_type(stmt->ret_type);
}
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
auto truncated = builder->CreateTrunc(
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));
Expand All @@ -899,8 +906,13 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) {
if (stmt->ret_type.is_pointer()) {
TI_NOT_IMPLEMENTED
} else {
auto intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits();
auto intermediate_bits = 0;
if (stmt->value->ret_type->is<CustomIntType>()) {
intermediate_bits = 32;
} else {
intermediate_bits =
tlctx->get_data_type(stmt->value->ret_type)->getPrimitiveSizeInBits();
}
llvm::Type *intermediate_type =
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
Expand Down
11 changes: 11 additions & 0 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ void Kernel::LaunchContextBuilder::set_arg_int(int i, int64 d) {
ctx_->set_arg(i, (float32)d);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
ctx_->set_arg(i, (float64)d);
} else if (auto cit = dt->cast<CustomIntType>()) {
if (cit->get_is_signed())
ctx_->set_arg(i, (int32)d);
else
ctx_->set_arg(i, (uint32)d);
} else {
TI_INFO(dt->to_string());
TI_NOT_IMPLEMENTED
}
}
Expand Down Expand Up @@ -295,6 +301,11 @@ int64 Kernel::get_ret_int(int i) {
return (int64)get_current_program().fetch_result<float32>(i);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return (int64)get_current_program().fetch_result<float64>(i);
} else if (auto cit = dt->cast<CustomIntType>()) {
if (cit->get_is_signed())
return (int64)get_current_program().fetch_result<int32>(i);
else
return (int64)get_current_program().fetch_result<uint32>(i);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/test_bit_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def verify_val():
set_val()
verify_val()

# Test bit_struct SNode read and write in Python-scope by calling the wrapped, untranslated function body
set_val.__wrapped__()
verify_val.__wrapped__()


@ti.test(arch=ti.cpu, debug=True, cfg_optimization=False)
def test_custom_int_load_and_store():
Expand Down Expand Up @@ -66,3 +70,8 @@ def verify_val(idx: ti.i32):
for idx in range(len(test_case_np)):
set_val(idx)
verify_val(idx)

# Test bit_struct SNode read and write in Python-scope by calling the wrapped, untranslated function body
for idx in range(len(test_case_np)):
set_val.__wrapped__(idx)
verify_val.__wrapped__(idx)

0 comments on commit 36232a5

Please sign in to comment.