diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 2157bfa801365..34b74ebd89f01 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1,5 +1,7 @@ #include "codegen_llvm.h" +#include "taichi/struct/struct_llvm.h" + TLANG_NAMESPACE_BEGIN // TODO: sort function definitions to match declaration order in header @@ -206,14 +208,6 @@ std::unique_ptr CodeGenLLVM::emit_struct_meta_object( TI_P(snode_type_name(snode->type)); TI_NOT_IMPLEMENTED; } - if (false) { - // auto ptr_type = llvm::Type::getInt8PtrTy(*llvm_context, 0); - auto ptr_type = llvm::PointerType::get(meta->type, 0); - auto ptr = meta->ptr; // builder->CreatePointerCast(meta->ptr, ptr_type); - auto struct_meta_size = tlctx->get_type_size(meta->type); - builder->CreateIntrinsic(llvm::Intrinsic::invariant_start, {ptr_type}, - {tlctx->get_constant(struct_meta_size), ptr}); - } return meta; } @@ -223,13 +217,17 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name, RuntimeObject common("StructMeta", this, builder.get(), node_meta); std::size_t element_size; if (snode->type == SNodeType::dense) { - auto element_ty = snode_attr[snode].llvm_body_type->getArrayElementType(); + auto body_type = + StructCompilerLLVM::get_llvm_body_type(module.get(), snode); + auto element_ty = body_type->getArrayElementType(); element_size = tlctx->get_type_size(element_ty); } else if (snode->type == SNodeType::pointer) { - auto element_ty = tlctx->snode_attr[snode->ch[0]].llvm_type; + auto element_ty = StructCompilerLLVM::get_llvm_node_type( + module.get(), snode->ch[0].get()); element_size = tlctx->get_type_size(element_ty); } else { - auto element_ty = tlctx->snode_attr[snode].llvm_element_type; + auto element_ty = + StructCompilerLLVM::get_llvm_element_type(module.get(), snode); element_size = tlctx->get_type_size(element_ty); } common.set("snode_id", tlctx->get_constant(snode->id)); @@ -266,13 +264,12 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name, } CodeGenLLVM::CodeGenLLVM(Kernel *kernel, IRNode *ir) - // TODO: simplify ModuleBuilder ctor input - : ModuleBuilder(kernel->program.get_llvm_context(kernel->arch) - ->clone_struct_module()), + // TODO: simplify LLVMModuleBuilder ctor input + : LLVMModuleBuilder(kernel->program.get_llvm_context(kernel->arch) + ->clone_struct_module()), kernel(kernel), ir(ir), - prog(&kernel->program), - snode_attr(prog->get_llvm_context(kernel->arch)->snode_attr) { + prog(&kernel->program) { if (ir == nullptr) this->ir = kernel->ir; initialize_context(); @@ -1117,8 +1114,9 @@ llvm::Value *CodeGenLLVM::call(SNode *snode, void CodeGenLLVM::visit(GetRootStmt *stmt) { llvm_val[stmt] = builder->CreateBitCast( - get_root(), - PointerType::get(snode_attr[prog->snode_root.get()].llvm_type, 0)); + get_root(), PointerType::get(StructCompilerLLVM::get_llvm_node_type( + module.get(), prog->snode_root.get()), + 0)); } void CodeGenLLVM::visit(OffsetAndExtractBitsStmt *stmt) { @@ -1184,7 +1182,9 @@ void CodeGenLLVM::visit(GetChStmt *stmt) { {builder->CreateBitCast(llvm_val[stmt->input_ptr], PointerType::getInt8PtrTy(*llvm_context))}); llvm_val[stmt] = builder->CreateBitCast( - ch, PointerType::get(snode_attr[stmt->output_snode].llvm_type, 0)); + ch, PointerType::get(StructCompilerLLVM::get_llvm_node_type( + module.get(), stmt->output_snode), + 0)); } void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { diff --git a/taichi/codegen/codegen_llvm.h b/taichi/codegen/codegen_llvm.h index 57be93d965ea3..030649ddadb13 100644 --- a/taichi/codegen/codegen_llvm.h +++ b/taichi/codegen/codegen_llvm.h @@ -48,7 +48,7 @@ class FunctionCreationGuard { ~FunctionCreationGuard(); }; -class CodeGenLLVM : public IRVisitor, public ModuleBuilder { +class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { public: static uint64 task_counter; @@ -66,7 +66,6 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { llvm::BasicBlock *current_while_after_loop; llvm::FunctionType *task_function_type; OffloadedStmt *current_offloaded_stmt; - SNodeAttributes &snode_attr; std::unordered_map llvm_val; llvm::Function *func; std::unique_ptr current_task; @@ -74,7 +73,7 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder { BasicBlock *func_body_bb; using IRVisitor::visit; - using ModuleBuilder::call; + using LLVMModuleBuilder::call; CodeGenLLVM(Kernel *kernel, IRNode *ir = nullptr); diff --git a/taichi/codegen/codegen_llvm_cuda.cpp b/taichi/codegen/codegen_llvm_cuda.cpp index 24487a296f24d..7ed6f18051686 100644 --- a/taichi/codegen/codegen_llvm_cuda.cpp +++ b/taichi/codegen/codegen_llvm_cuda.cpp @@ -135,7 +135,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { auto format_str = "[debug] " + stmt->str + " = " + format + "\n"; - llvm_val[stmt] = ModuleBuilder::call( + llvm_val[stmt] = LLVMModuleBuilder::call( builder.get(), "vprintf", builder->CreateGlobalStringPtr(format_str, "format_string"), builder->CreateBitCast(values, diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index b0c9aa061758d..aec200bc84f9b 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -256,28 +256,4 @@ class SNode { uint64 fetch_reader_result(); // TODO: refactor }; -class SNodeAttribute { - public: - llvm::Type *llvm_type, *llvm_body_type, *llvm_aux_type; - llvm::Type *llvm_element_type; -}; - -class SNodeAttributes { - private: - std::map snode_llvm_attr; - - public: - SNodeAttribute &operator[](SNode &snode) { - return snode_llvm_attr[&snode]; - } - - SNodeAttribute &operator[](SNode *snode) { - return snode_llvm_attr[snode]; - } - - SNodeAttribute &operator[](const std::unique_ptr &snode) { - return snode_llvm_attr[snode.get()]; - } -}; - TLANG_NAMESPACE_END diff --git a/taichi/llvm/llvm_codegen_utils.h b/taichi/llvm/llvm_codegen_utils.h index 177e10d75704f..b4696aaa0a32f 100644 --- a/taichi/llvm/llvm_codegen_utils.h +++ b/taichi/llvm/llvm_codegen_utils.h @@ -50,7 +50,7 @@ inline bool check_func_call_signature(llvm::Value *func, Args &&... args) { return check_func_call_signature(func, {args...}); } -class ModuleBuilder { +class LLVMModuleBuilder { public: std::unique_ptr module; llvm::BasicBlock *entry_block; @@ -58,7 +58,7 @@ class ModuleBuilder { TaichiLLVMContext *tlctx; llvm::LLVMContext *llvm_context; - ModuleBuilder(std::unique_ptr &&module) + LLVMModuleBuilder(std::unique_ptr &&module) : module(std::move(module)) { } @@ -131,12 +131,12 @@ class RuntimeObject { public: std::string cls_name; llvm::Value *ptr; - ModuleBuilder *mb; + LLVMModuleBuilder *mb; llvm::Type *type; llvm::IRBuilder<> *builder; RuntimeObject(const std::string &cls_name, - ModuleBuilder *mb, + LLVMModuleBuilder *mb, llvm::IRBuilder<> *builder, llvm::Value *init = nullptr) : cls_name(cls_name), mb(mb), builder(builder) { diff --git a/taichi/llvm/llvm_context.h b/taichi/llvm/llvm_context.h index 4975d95f74380..018dd6b8b4031 100644 --- a/taichi/llvm/llvm_context.h +++ b/taichi/llvm/llvm_context.h @@ -24,8 +24,6 @@ class TaichiLLVMContext { std::mutex mut; Arch arch; - SNodeAttributes snode_attr; - TaichiLLVMContext(Arch arch); std::unique_ptr get_init_module(); diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index cf45c6e4e6c3d..c028b2af9b886 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -11,6 +11,7 @@ #include "taichi/codegen/codegen_opengl.h" #include "taichi/codegen/codegen_cpu.h" #include "taichi/struct/struct.h" +#include "taichi/struct/struct_llvm.h" #include "taichi/struct/struct_metal.h" #include "taichi/struct/struct_opengl.h" #include "taichi/system/unified_allocator.h" @@ -215,15 +216,15 @@ void Program::initialize_runtime_system(StructCompiler *scomp) { for (int i = 0; i < (int)snodes.size(); i++) { if (is_gc_able(snodes[i]->type)) { std::size_t node_size; - if (snodes[i]->type == SNodeType::pointer) - node_size = tlctx->get_type_size( - scomp->snode_attr[snodes[i]].llvm_element_type); - else { + auto element_size = + tlctx->get_type_size(StructCompilerLLVM::get_llvm_element_type( + tlctx->struct_module.get(), snodes[i])); + if (snodes[i]->type == SNodeType::pointer) { + // pointer. Allocators are for single elements + node_size = element_size; + } else { // dynamic. Allocators are for the chunks - node_size = sizeof(void *) + - tlctx->get_type_size( - scomp->snode_attr[snodes[i]].llvm_element_type) * - snodes[i]->chunk_size; + node_size = sizeof(void *) + element_size * snodes[i]->chunk_size; } TI_TRACE("Initializing allocator for snode {} (node size {})", snodes[i]->id, node_size); diff --git a/taichi/struct/struct.h b/taichi/struct/struct.h index a56a96e57ac5f..4570786ec5412 100644 --- a/taichi/struct/struct.h +++ b/taichi/struct/struct.h @@ -14,8 +14,6 @@ class StructCompiler { std::size_t root_size; Program *prog; - SNodeAttributes snode_attr; - explicit StructCompiler(Program *prog); virtual ~StructCompiler() = default; diff --git a/taichi/struct/struct_llvm.cpp b/taichi/struct/struct_llvm.cpp index 7958aaa606096..58f335aa5ce4f 100644 --- a/taichi/struct/struct_llvm.cpp +++ b/taichi/struct/struct_llvm.cpp @@ -13,7 +13,7 @@ using namespace llvm; StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch) : StructCompiler(prog), - ModuleBuilder(prog->get_llvm_context(arch)->get_init_module()), + LLVMModuleBuilder(prog->get_llvm_context(arch)->get_init_module()), arch(arch) { tlctx = prog->get_llvm_context(arch); llvm_ctx = tlctx->ctx.get(); @@ -22,7 +22,7 @@ StructCompilerLLVM::StructCompilerLLVM(Program *prog, Arch arch) void StructCompilerLLVM::generate_types(SNode &snode) { TI_AUTO_PROF; auto type = snode.type; - llvm::Type *llvm_type = nullptr; + llvm::Type *node_type = nullptr; auto ctx = llvm_ctx; @@ -30,15 +30,12 @@ void StructCompilerLLVM::generate_types(SNode &snode) { std::vector ch_types; for (int i = 0; i < snode.ch.size(); i++) { - auto ch = snode_attr[snode.ch[i]].llvm_type; + auto ch = get_llvm_node_type(module.get(), snode.ch[i].get()); ch_types.push_back(ch); } auto ch_type = llvm::StructType::create(*ctx, ch_types, snode.node_type_name + "_ch"); - ch_type->setName(snode.node_type_name + "_ch"); - - snode_attr[snode].llvm_element_type = ch_type; llvm::Type *body_type = nullptr, *aux_type = nullptr; if (type == SNodeType::dense || type == SNodeType::bitmasked) { @@ -69,15 +66,21 @@ void StructCompilerLLVM::generate_types(SNode &snode) { TI_NOT_IMPLEMENTED; } if (aux_type != nullptr) { - llvm_type = llvm::StructType::create(*ctx, {aux_type, body_type}, ""); + node_type = llvm::StructType::create(*ctx, {aux_type, body_type}, ""); } else { - llvm_type = body_type; + node_type = body_type; } - TI_ASSERT(llvm_type != nullptr); - snode_attr[snode].llvm_type = llvm_type; - snode_attr[snode].llvm_aux_type = aux_type; - snode_attr[snode].llvm_body_type = body_type; + TI_ASSERT(node_type != nullptr); + TI_ASSERT(body_type != nullptr); + + // Here we create a stub holding 4 LLVM types as struct members. + // The aim is to give a **unique** name to the stub, so that we can look up + // these types using this name. This decouples them from the LLVM context. + // Note that body_type might not have a unique name, since literal structs + // (such as {i32, i32}) are uniqued in LLVM. + llvm::StructType::create(*ctx, {node_type, body_type, aux_type, ch_type}, + type_stub_name(&snode)); } void StructCompilerLLVM::generate_refine_coordinates(SNode *snode) { @@ -140,7 +143,7 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) { auto parent = snode.parent; auto inp_type = - llvm::PointerType::get(snode_attr[parent].llvm_element_type, 0); + llvm::PointerType::get(get_llvm_element_type(module.get(), parent), 0); auto ft = llvm::FunctionType::get(llvm::Type::getInt8PtrTy(*llvm_ctx), @@ -174,6 +177,10 @@ void StructCompilerLLVM::generate_child_accessors(SNode &snode) { stack.pop_back(); } +std::string StructCompilerLLVM::type_stub_name(SNode *snode) { + return snode->node_type_name + "_type_stubs"; +} + void StructCompilerLLVM::run(SNode &root, bool host) { TI_AUTO_PROF; // bottom to top @@ -197,11 +204,44 @@ void StructCompilerLLVM::run(SNode &root, bool host) { TI_ASSERT((int)snodes.size() <= taichi_max_num_snodes); - root_size = - tlctx->get_data_layout().getTypeAllocSize(snode_attr[root].llvm_type); + auto node_type = get_llvm_node_type(module.get(), &root); + root_size = tlctx->get_data_layout().getTypeAllocSize(node_type); tlctx->set_struct_module(module); - tlctx->snode_attr = snode_attr; +} + +llvm::Type *StructCompilerLLVM::get_stub(llvm::Module *module, + SNode *snode, + uint32 index) { + TI_ASSERT(module); + TI_ASSERT(snode); + auto stub = module->getTypeByName(type_stub_name(snode)); + TI_ASSERT(stub); + TI_ASSERT(stub->getStructNumElements() == 4); + TI_ASSERT(0 <= index && index < 4); + auto type = stub->getContainedType(index); + TI_ASSERT(type); + return type; +} + +llvm::Type *StructCompilerLLVM::get_llvm_node_type(llvm::Module *module, + SNode *snode) { + return get_stub(module, snode, 0); +} + +llvm::Type *StructCompilerLLVM::get_llvm_body_type(llvm::Module *module, + SNode *snode) { + return get_stub(module, snode, 1); +} + +llvm::Type *StructCompilerLLVM::get_llvm_aux_type(llvm::Module *module, + SNode *snode) { + return get_stub(module, snode, 2); +} + +llvm::Type *StructCompilerLLVM::get_llvm_element_type(llvm::Module *module, + SNode *snode) { + return get_stub(module, snode, 3); } std::unique_ptr StructCompiler::make(Program *prog, Arch arch) { diff --git a/taichi/struct/struct_llvm.h b/taichi/struct/struct_llvm.h index 7cad7b4d04ff4..7930f1a6d74d8 100644 --- a/taichi/struct/struct_llvm.h +++ b/taichi/struct/struct_llvm.h @@ -1,11 +1,11 @@ // Codegen for the hierarchical data structure (LLVM) -#include "struct.h" +#include "taichi/struct/struct.h" #include "taichi/llvm/llvm_codegen_utils.h" TLANG_NAMESPACE_BEGIN -class StructCompilerLLVM : public StructCompiler, public ModuleBuilder { +class StructCompilerLLVM : public StructCompiler, public LLVMModuleBuilder { public: StructCompilerLLVM(Program *prog, Arch arch); @@ -20,6 +20,18 @@ class StructCompilerLLVM : public StructCompiler, public ModuleBuilder { void run(SNode &node, bool host) override; void generate_refine_coordinates(SNode *snode); + + static std::string type_stub_name(SNode *snode); + + static llvm::Type *get_stub(llvm::Module *module, SNode *snode, uint32 index); + + static llvm::Type *get_llvm_node_type(llvm::Module *module, SNode *snode); + + static llvm::Type *get_llvm_body_type(llvm::Module *module, SNode *snode); + + static llvm::Type *get_llvm_aux_type(llvm::Module *module, SNode *snode); + + static llvm::Type *get_llvm_element_type(llvm::Module *module, SNode *snode); }; TLANG_NAMESPACE_END