[refactor] Compile the Ndarray argument to a struct
ghstack-source-id: 05892df446d0747f3f1ae31b5e804163cd111baf
Pull Request resolved: taichi-dev#7809
lin-hitonami authored and Taichi Gardener committed Apr 20, 2023
1 parent 457ada6 commit a894676
Showing 20 changed files with 115 additions and 73 deletions.
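
In short: an Ndarray kernel argument is no longer compiled as a bare data pointer; it is compiled as a struct whose first member is that pointer. A rough C++ illustration of the new layout (the struct name and the float element type are made up for clarity; the real type is assembled at runtime via TypeFactory::get_struct_type):

    // Hypothetical view of what insert_ndarray_param() now describes.
    struct NdarrayArg {
      float *data_ptr;  // member {0}: base pointer to the ndarray elements
      // further members (e.g. shape metadata) could be added here in later work
    };

The diffs below thread this struct type through the frontend, the type system, the LLVM and SPIR-V code generators, and the kernel launchers.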
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_arguments.py
@@ -87,7 +87,7 @@ def decl_sparse_matrix(dtype, name):
def decl_ndarray_arg(dtype, dim, element_shape, layout, name):
dtype = cook_dtype(dtype)
element_dim = len(element_shape)
arg_id = impl.get_runtime().compiling_callable.insert_arr_param(dtype, dim, element_shape, name)
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(dtype, dim, element_shape, name)
if layout == Layout.AOS:
element_dim = -element_dim
return AnyArray(_ti_core.make_external_tensor_expr(dtype, dim, arg_id, element_dim, element_shape))
31 changes: 22 additions & 9 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -1843,6 +1843,24 @@ void TaskCodeGenLLVM::visit(MatrixPtrStmt *stmt) {
}

void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
// Index into ndarray struct
DataType operand_dtype = stmt->base_ptr->ret_type.ptr_removed()
->as<StructType>()
->get_element_type({0})
->as<PointerType>()
->get_pointee_type();
auto arg_type = operand_dtype;
if (operand_dtype->is<TensorType>()) {
arg_type = operand_dtype->as<TensorType>()->get_element_type();
}
auto ptr_type = TypeFactory::get_instance().get_pointer_type(arg_type);
auto *struct_type = tlctx->get_data_type(
TypeFactory::get_instance().get_struct_type({{ptr_type}}));
std::vector<llvm::Value *> index(2, tlctx->get_constant(0));
auto *gep =
builder->CreateGEP(struct_type, llvm_val.at(stmt->base_ptr), index);
auto *ptr_val = builder->CreateLoad(tlctx->get_data_type(ptr_type), gep);

auto argload = stmt->base_ptr->as<ArgLoadStmt>();
auto arg_id = argload->arg_id;
int num_indices = stmt->indices.size();
@@ -1902,13 +1920,8 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
However, this does not fit with Taichi's Ndarray semantics. We will have to
do pointer arithmetics to manually calculate the offset.
*/
DataType operand_dtype = argload->ret_type.ptr_removed();
if (operand_dtype->is<TensorType>()) {
// Access PtrOffset via: base_ptr + offset * sizeof(element)
auto primitive_type = operand_dtype.get_element_type();
auto primitive_ptr = builder->CreateBitCast(
llvm_val[stmt->base_ptr],
llvm::PointerType::get(tlctx->get_data_type(primitive_type), 0));

auto address_offset = builder->CreateSExt(
linear_index, llvm::Type::getInt64Ty(*llvm_context));
@@ -1928,15 +1941,15 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
// the stride for linear_index is 1, and there's nothing to do here.
}

auto ret_ptr = builder->CreateGEP(tlctx->get_data_type(primitive_type),
primitive_ptr, address_offset);
auto ret_ptr = builder->CreateGEP(tlctx->get_data_type(arg_type), ptr_val,
address_offset);
llvm_val[stmt] = builder->CreateBitCast(
ret_ptr, llvm::PointerType::get(tlctx->get_data_type(dt), 0));

} else {
auto base_ty = tlctx->get_data_type(dt);
auto base = builder->CreateBitCast(llvm_val[stmt->base_ptr],
llvm::PointerType::get(base_ty, 0));
auto base =
builder->CreateBitCast(ptr_val, llvm::PointerType::get(base_ty, 0));

llvm_val[stmt] = builder->CreateGEP(base_ty, base, linear_index);
}
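
Conceptually, the LLVM codegen above now loads the data pointer out of the argument struct (the GEP with indices {0, 0}) before applying the element offset. A minimal C++ sketch of the equivalent address computation, with illustrative names and a primitive element type assumed:

    #include <cstdint>

    using element_t = float;                    // stand-in for the ndarray element type
    struct NdarrayArg { element_t *data_ptr; };

    element_t *external_elem(const NdarrayArg *arg, int64_t linear_index) {
      element_t *base = arg->data_ptr;          // CreateGEP {0, 0} + CreateLoad above
      return base + linear_index;               // CreateGEP over the element type
    }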
10 changes: 7 additions & 3 deletions taichi/codegen/spirv/spirv_codegen.cpp
@@ -567,7 +567,10 @@ class TaskCodegen : public IRVisitor {
void visit(ArgLoadStmt *stmt) override {
const auto arg_id = stmt->arg_id;
const auto arg_type = ctx_attribs_->args_type()->get_element_type({arg_id});
if (arg_type->is<PointerType>()) {
if (arg_type->is<PointerType>() ||
(arg_type->is<lang::StructType>() && arg_type->as<lang::StructType>()
->get_element_type({0})
->is<PointerType>())) {
// Do not shift! We are indexing the buffers at byte granularity.
// spirv::Value val =
// ir_->int_immediate_number(ir_->i32_type(), offset_in_mem);
@@ -692,7 +695,7 @@ class TaskCodegen : public IRVisitor {
spv::OpShiftLeftLogical, ir_->i32_type(), linear_offset,
ir_->int_immediate_number(ir_->i32_type(),
log2int(ir_->get_primitive_type_size(
argload->ret_type.ptr_removed()))));
stmt->ret_type.ptr_removed()))));
if (caps_->get(DeviceCapability::spirv_has_no_integer_wrap_decoration)) {
ir_->decorate(spv::OpDecorate, linear_offset,
spv::DecorationNoSignedWrap);
@@ -703,7 +706,8 @@ class TaskCodegen : public IRVisitor {
spv::OpAccessChain,
ir_->get_pointer_type(ir_->u64_type(), spv::StorageClassUniform),
get_buffer_value(BufferType::Args, PrimitiveType::i32),
ir_->int_immediate_number(ir_->i32_type(), arg_id));
ir_->int_immediate_number(ir_->i32_type(), arg_id),
ir_->int_immediate_number(ir_->i32_type(), 0));
spirv::Value addr = ir_->load_variable(addr_ptr, ir_->u64_type());
addr = ir_->add(addr, ir_->make_value(spv::OpSConvert, ir_->u64_type(),
linear_offset));
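
In the SPIR-V path the args buffer entry for an ndarray is likewise no longer a bare 64-bit device address but a struct containing one, which is why the OpAccessChain above gains the trailing index 0. A hedged C-style sketch of the resulting address computation (names are illustrative):

    #include <cstdint>

    struct NdarrayArgSlot { uint64_t device_addr; };  // member {0} of the arg struct

    uint64_t element_addr(const NdarrayArgSlot *args, int arg_id, int64_t byte_offset) {
      // access chain {arg_id, 0}, load the address, then add the scaled linear offset
      return args[arg_id].device_addr + byte_offset;
    }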
3 changes: 3 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
@@ -350,6 +350,9 @@ SType IRBuilder::from_taichi_type(const DataType &dt, bool has_buffer_ptr) {
}

size_t IRBuilder::get_primitive_type_size(const DataType &dt) const {
if (!dt->is<PrimitiveType>()) {
TI_ERROR("Type {} not supported.", dt->to_string());
}
if (dt == PrimitiveType::i64 || dt == PrimitiveType::u64 ||
dt == PrimitiveType::f64) {
return 8;
11 changes: 8 additions & 3 deletions taichi/ir/frontend_ir.cpp
@@ -596,9 +596,14 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
// turned-on by default.
// The scalarization should happen after
// irpass::lower_access()
auto prim_dt = dt;
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true,
/*is_grad=*/is_grad, /*create_load=*/true);
auto ret_type = TypeFactory::get_instance().get_pointer_type(dt);
std::vector<StructMember> members;
members.push_back({ret_type, "data_ptr"});
auto type = TypeFactory::get_instance().get_struct_type(members);

auto ptr =
Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
/*is_grad=*/is_grad, /*create_load=*/false);

ptr->tb = tb;
ctx->push_back(std::move(ptr));
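
The same three-step pattern (make a pointer type, wrap it in a one-member struct named "data_ptr") recurs in ExternalTensorExpression::flatten above and in several files below. As a sketch only, a hypothetical helper that is not part of this commit:

    // Hypothetical helper illustrating the recurring construction.
    auto make_ndarray_arg_type = [](const DataType &dt) {
      auto ptr = TypeFactory::get_instance().get_pointer_type(dt);
      std::vector<StructMember> members = {{ptr, "data_ptr"}};
      return TypeFactory::get_instance().get_struct_type(members);
    };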
9 changes: 9 additions & 0 deletions taichi/ir/ir_builder.cpp
@@ -492,5 +492,14 @@ MeshRelationAccessStmt *IRBuilder::get_relation_access(
MeshPatchIndexStmt *IRBuilder::get_patch_index() {
return insert(Stmt::make_typed<MeshPatchIndexStmt>());
}
ArgLoadStmt *IRBuilder::create_ndarray_arg_load(int arg_id, DataType dt) {
auto ret_type = TypeFactory::get_instance().get_pointer_type(dt);
std::vector<StructMember> members;
members.push_back({ret_type, "data_ptr"});
auto type = TypeFactory::get_instance().get_struct_type(members);

return insert(Stmt::make_typed<ArgLoadStmt>(
arg_id, type, /*is_ptr=*/true, /*is_grad=*/false, /*create_load=*/false));
}

} // namespace taichi::lang
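
Typical usage of the new builder entry point, mirroring the updated C++ tests further down (a fragment that assumes the surrounding test setup):

    IRBuilder builder;
    // Argument 0 is an i32 ndarray; the statement's ret_type is the
    // struct { i32* data_ptr } constructed above.
    auto *arg = builder.create_ndarray_arg_load(/*arg_id=*/0, get_data_type<int>());
    auto *zero = builder.get_int32(0);
    auto *a0ptr = builder.create_external_ptr(arg, {zero});
    builder.create_atomic_add(a0ptr, builder.get_int32(1));  // a[0] += 1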
2 changes: 2 additions & 0 deletions taichi/ir/ir_builder.h
@@ -145,6 +145,8 @@ class IRBuilder {

// Load kernel arguments.
ArgLoadStmt *create_arg_load(int arg_id, DataType dt, bool is_ptr);
// Load kernel arguments.
ArgLoadStmt *create_ndarray_arg_load(int arg_id, DataType dt);

// The return value of the kernel.
ReturnStmt *create_return(Stmt *value);
19 changes: 18 additions & 1 deletion taichi/program/callable.cpp
@@ -29,6 +29,23 @@ int Callable::insert_arr_param(const DataType &dt,
return (int)parameter_list.size() - 1;
}

int Callable::insert_ndarray_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name) {
// Transform ndarray param to a struct type with a pointer to `dt`.
std::vector<StructMember> members;
members.push_back(
{TypeFactory::get_instance().get_pointer_type(dt->get_compute_type()),
"data_ptr"});

auto *type = TypeFactory::get_instance().get_struct_type(members);
parameter_list.emplace_back(type, /*is_array=*/true,
/*size=*/0, total_dim, element_shape);
parameter_list.back().name = name;
return (int)parameter_list.size() - 1;
}

int Callable::insert_texture_param(int total_dim, const std::string &name) {
// FIXME: we shouldn't abuse is_array for texture parameters
parameter_list.emplace_back(PrimitiveType::f32, /*is_array=*/true, 0,
@@ -73,7 +90,7 @@ void Callable::finalize_params() {
for (int i = 0; i < parameter_list.size(); i++) {
auto &param = parameter_list[i];
members.push_back(
{param.is_array
{param.is_array && !param.get_dtype()->is<StructType>()
? TypeFactory::get_instance().get_pointer_type(param.get_dtype())
: (const Type *)param.get_dtype(),
fmt::format("arg_{}", i)});
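
Kernel-side declaration now goes through insert_ndarray_param, and finalize_params leaves the resulting struct-typed parameter unwrapped when assembling args_type. A short sketch following the updated tests:

    ker->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1,
                              /*element_shape=*/{1});
    ker->insert_scalar_param(get_data_type<int>());
    ker->finalize_params();  // struct-typed array params get no extra pointer wrapper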
4 changes: 4 additions & 0 deletions taichi/program/callable.h
@@ -113,6 +113,10 @@ class TI_DLL_EXPORT Callable : public CallableBase {
int total_dim,
std::vector<int> element_shape,
const std::string &name = "");
int insert_ndarray_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name = "");
int insert_texture_param(int total_dim, const std::string &name = "");
int insert_pointer_param(const DataType &dt, const std::string &name = "");
int insert_rw_texture_param(int total_dim,
8 changes: 5 additions & 3 deletions taichi/program/launch_context_builder.cpp
@@ -48,10 +48,12 @@ void LaunchContextBuilder::set_arg_float(int arg_id, float64 d) {
template <typename T>
void LaunchContextBuilder::set_struct_arg(std::vector<int> arg_indices, T d) {
auto dt = kernel_->args_type->get_element_type(arg_indices);
TI_ASSERT_INFO(dt->is<PrimitiveType>(),
"Assigning scalar value to external (numpy) array argument is "
"not allowed.");

TI_ASSERT(dt->is<PrimitiveType>() || dt->is<PointerType>());
if (dt->is<PointerType>()) {
set_struct_arg_impl(arg_indices, (uint64)d);
return;
}
PrimitiveTypeID typeId = dt->as<PrimitiveType>()->type;

switch (typeId) {
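
With pointer-typed members accepted here, the host side stores an ndarray's base address into member 0 of its argument struct. A sketch of the convention used by the launchers below (arg_id and alloc are placeholders):

    // For ndarray argument arg_id, member {arg_id, 0} holds the data pointer as a uint64.
    uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*alloc);
    ctx.set_struct_arg({arg_id, 0}, host_ptr);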
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
@@ -676,6 +676,7 @@ void export_lang(py::module &m) {
})
.def("insert_scalar_param", &Kernel::insert_scalar_param)
.def("insert_arr_param", &Kernel::insert_arr_param)
.def("insert_ndarray_param", &Kernel::insert_ndarray_param)
.def("insert_texture_param", &Kernel::insert_texture_param)
.def("insert_pointer_param", &Kernel::insert_pointer_param)
.def("insert_rw_texture_param", &Kernel::insert_rw_texture_param)
4 changes: 2 additions & 2 deletions taichi/runtime/cpu/kernel_launcher.cpp
@@ -18,7 +18,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
if (parameters[i].is_array &&
ctx.device_allocation_type[i] ==
LaunchContextBuilder::DevAllocType::kNone) {
ctx.set_arg(i, (uint64)ctx.array_ptrs[{i}]);
ctx.set_struct_arg({i, 0}, (uint64)ctx.array_ptrs[{i}]);
}
if (parameters[i].is_array &&
ctx.device_allocation_type[i] !=
@@ -27,7 +27,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
DeviceAllocation *ptr =
static_cast<DeviceAllocation *>(ctx.array_ptrs[{i}]);
uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
ctx.set_arg(i, host_ptr);
ctx.set_struct_arg({i, 0}, host_ptr);
ctx.set_array_device_allocation_type(
i, LaunchContextBuilder::DevAllocType::kNone);

4 changes: 2 additions & 2 deletions taichi/runtime/cuda/kernel_launcher.cpp
@@ -62,7 +62,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
device_buffers[i] = arg_buffers[i];
}
// device_buffers[i] saves a raw ptr on CUDA device.
ctx.set_arg(i, (uint64)device_buffers[i]);
ctx.set_struct_arg({i, 0}, (uint64)device_buffers[i]);

} else if (arr_sz > 0) {
// arg_buffers[i] is a DeviceAllocation*
@@ -77,7 +77,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
arg_buffers[i] = device_buffers[i];

// device_buffers[i] saves the unwrapped raw ptr from arg_buffers[i]
ctx.set_arg(i, (uint64)device_buffers[i]);
ctx.set_struct_arg({i, 0}, (uint64)device_buffers[i]);
}
}
}
2 changes: 1 addition & 1 deletion taichi/runtime/gfx/runtime.cpp
@@ -81,7 +81,7 @@ class HostDeviceContextBlitter {
DeviceCapability::spirv_has_physical_storage_buffer)) {
uint64_t addr =
device_->get_memory_physical_pointer(ext_arrays.at(i));
host_ctx_.set_arg(i, addr);
host_ctx_.set_struct_arg({i, 0}, addr);
}
}
}
11 changes: 7 additions & 4 deletions taichi/transforms/scalarize.cpp
@@ -838,10 +838,13 @@ class MergeExternalAndMatrixPtr : public BasicStmtVisitor {
auto fused = std::make_unique<ExternalPtrStmt>(
origin->base_ptr, indices, element_shape, element_dim);
fused->ret_type = stmt->ret_type;
// Note: Update base_ptr's ret_type so that it matches the
// ExternalPtrStmt with flattened indices. Main goal is to keep all the
// hacks in a single place so that they're easier to remove
origin->base_ptr->as<ArgLoadStmt>()->ret_type = stmt->ret_type;
// Note: Update base_ptr's ret_type so that it matches the ExternalPtrStmt
// with flattened indices. Main goal is to keep all the hacks in a single
// place so that they're easier to remove
std::vector<StructMember> members;
members.push_back({stmt->ret_type, "data_ptr"});
auto type = TypeFactory::get_instance().get_struct_type(members);
origin->base_ptr->as<ArgLoadStmt>()->ret_type = type;
stmt->replace_usages_with(fused.get());
modifier_.insert_before(stmt, std::move(fused));
modifier_.erase(stmt);
4 changes: 3 additions & 1 deletion taichi/transforms/type_check.cpp
@@ -462,7 +462,9 @@ class TypeCheck : public IRVisitor {
if (stmt->overrided_dtype) {
// pass
} else {
stmt->ret_type = arg_load_stmt->ret_type;
stmt->ret_type = arg_load_stmt->ret_type.ptr_removed()
->as<StructType>()
->get_element_type({0});
}

stmt->ret_type.set_is_pointer(true);
12 changes: 6 additions & 6 deletions tests/cpp/ir/ir_builder_test.cpp
@@ -96,8 +96,8 @@ TEST(IRBuilder, ExternalPtr) {
auto array = std::make_unique<int[]>(size);
array[0] = 2;
array[2] = 40;
auto *arg = builder.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg =
builder.create_ndarray_arg_load(/*arg_id=*/0, get_data_type<int>());
auto *zero = builder.get_int32(0);
auto *one = builder.get_int32(1);
auto *two = builder.get_int32(2);
@@ -111,7 +111,7 @@
builder.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
auto block = builder.extract_ir();
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker->finalize_params();
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array_with_shape(
@@ -164,15 +164,15 @@ TEST(IRBuilder, AtomicOp) {
auto array = std::make_unique<int[]>(size);
array[0] = 2;
array[2] = 40;
auto *arg = builder.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg =
builder.create_ndarray_arg_load(/*arg_id=*/0, get_data_type<int>());
auto *zero = builder.get_int32(0);
auto *one = builder.get_int32(1);
auto *a0ptr = builder.create_external_ptr(arg, {zero});
builder.create_atomic_add(a0ptr, one); // a[0] += 1
auto block = builder.extract_ir();
auto ker = std::make_unique<Kernel>(*test_prog.prog(), std::move(block));
ker->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker->finalize_params();
auto launch_ctx = ker->make_launch_context();
launch_ctx.set_arg_external_array_with_shape(
12 changes: 6 additions & 6 deletions tests/cpp/ir/ndarray_kernel.cpp
@@ -5,8 +5,8 @@ namespace taichi::lang {
std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
IRBuilder builder1;
{
auto *arg = builder1.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg =
builder1.create_ndarray_arg_load(/*arg_id=*/0, get_data_type<int>());
auto *zero = builder1.get_int32(0);
auto *one = builder1.get_int32(1);
auto *two = builder1.get_int32(2);
@@ -21,7 +21,7 @@ std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
}
auto block = builder1.extract_ir();
auto ker1 = std::make_unique<Kernel>(*prog, std::move(block), "ker1");
ker1->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker1->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker1->finalize_params();
ker1->finalize_rets();
return ker1;
@@ -31,8 +31,8 @@ std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
IRBuilder builder2;

{
auto *arg0 = builder2.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg0 =
builder2.create_ndarray_arg_load(/*arg_id=*/0, get_data_type<int>());
auto *arg1 = builder2.create_arg_load(/*arg_id=*/1, get_data_type<int>(),
/*is_ptr=*/false);
auto *one = builder2.get_int32(1);
@@ -41,7 +41,7 @@ std::unique_ptr<Kernel> setup_kernel2(Program *prog) {
}
auto block2 = builder2.extract_ir();
auto ker2 = std::make_unique<Kernel>(*prog, std::move(block2), "ker2");
ker2->insert_arr_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker2->insert_ndarray_param(get_data_type<int>(), /*total_dim=*/1, {1});
ker2->insert_scalar_param(get_data_type<int>());
ker2->finalize_params();
ker2->finalize_rets();
10 changes: 8 additions & 2 deletions tests/cpp/transforms/half2_vectorization_test.cpp
@@ -58,9 +58,15 @@ TEST(Half2Vectorization, Ndarray) {
alloca_stmt1->replace_all_usages_with(old_val1);
*/

auto ret_type =
TypeFactory::get_instance().get_pointer_type(PrimitiveType::f16);
std::vector<StructMember> members;
members.push_back({ret_type, "data_ptr"});
auto type = TypeFactory::get_instance().get_struct_type(members);

auto argload_stmt = block->push_back<ArgLoadStmt>(
0 /*arg_id*/, PrimitiveType::f16, /*is_ptr*/ false, /*is_grad*/ false,
/*create_load*/ true);
0 /*arg_id*/, type, /*is_ptr*/ true, /*is_grad*/ false,
/*create_load*/ false);
auto const_0_stmt = block->push_back<ConstStmt>(TypedConstant(0));
auto const_1_stmt = block->push_back<ConstStmt>(TypedConstant(1));
