[refactor] Re-impl JIT and Offline Cache on LLVM backends
ghstack-source-id: f300eb101d3c0dd7bc2a99bd0764a97bb75dd9c7
Pull Request resolved: taichi-dev#7585
PGZXB authored and Taichi Gardener committed Mar 24, 2023
1 parent 01bc21f commit fb9756d
Showing 5 changed files with 75 additions and 54 deletions.
18 changes: 14 additions & 4 deletions taichi/compilation_manager/kernel_compilation_manager.cpp
@@ -93,6 +93,7 @@ void KernelCompilationManager::dump() {
if (!lock_with_file(lock_path)) {
TI_WARN("Lock {} failed. Please run 'ti cache clean -p {}' and try again.",
lock_path, config_.offline_cache_path);
caching_kernels_.clear(); // Ignore the caching kernels
return;
}

@@ -118,6 +119,8 @@ void KernelCompilationManager::dump() {
TI_ASSERT(!ok || iter->second.size == 0);
}
}
// Clear caching_kernels_
caching_kernels_.clear();
// Dump cached CompiledKernelData to disk
for (auto &[_, k] : kernels) {
if (k.compiled_kernel_data) {
@@ -126,9 +129,15 @@
if (try_lock_with_file(cache_filename)) {
std::ofstream fs{cache_filename, std::ios::out | std::ios::binary};
TI_ASSERT(fs.is_open());
k.compiled_kernel_data->dump(fs);
k.size = fs.tellp();
data.size += k.size;
auto err = k.compiled_kernel_data->dump(fs);
if (err == CompiledKernelData::Err::kNoError) {
TI_ASSERT(!!fs);
k.size = fs.tellp();
data.size += k.size;
} else {
TI_DEBUG("Dump cached CompiledKernelData(kernel_key={}) failed: {}",
k.kernel_key, CompiledKernelData::get_err_msg(err));
}
}
}
}
@@ -264,7 +273,8 @@ std::unique_ptr<CompiledKernelData> KernelCompilationManager::load_ckd(
CacheData::CacheMode KernelCompilationManager::get_cache_mode(
const CompileConfig &compile_config,
const Kernel &kernel_def) {
return compile_config.offline_cache && kernel_def.ir_is_ast()
return compile_config.offline_cache && kernel_def.ir_is_ast() &&
!kernel_def.is_evaluator
? CacheData::MemAndDiskCache
: CacheData::MemCache;
}
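
For orientation, here is a minimal stand-alone sketch (not part of the commit) of the error-checked dump pattern the hunks above switch to. `dump_one_kernel`, `ckd`, and `cache_filename` are hypothetical stand-ins; `CompiledKernelData::dump`, `Err::kNoError`, and `get_err_msg` are the APIs visible in the diff, and the Taichi headers declaring them are assumed to be on the include path.

#include <cstddef>
#include <fstream>
#include <string>

// Hypothetical helper (assumed to live in namespace taichi::lang, next to the
// code above): persist one compiled kernel and return the number of bytes
// written, or 0 on failure -- mirroring how dump() above only updates its size
// bookkeeping when Err::kNoError is returned and otherwise logs via TI_DEBUG
// with CompiledKernelData::get_err_msg(err).
std::size_t dump_one_kernel(CompiledKernelData &ckd,
                            const std::string &cache_filename) {
  std::ofstream fs{cache_filename, std::ios::out | std::ios::binary};
  if (!fs.is_open()) {
    return 0;
  }
  const auto err = ckd.dump(fs);
  if (err != CompiledKernelData::Err::kNoError) {
    return 0;
  }
  return static_cast<std::size_t>(fs.tellp());
}
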
76 changes: 53 additions & 23 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
@@ -2,6 +2,8 @@

#include "llvm/IR/Module.h"

#include "taichi/codegen/cpu/codegen_cpu.h"
#include "taichi/codegen/llvm/llvm_compiled_data.h"
#include "taichi/program/program.h"
#include "taichi/codegen/codegen.h"
#include "taichi/codegen/llvm/struct_llvm.h"
@@ -24,7 +26,45 @@
#include "taichi/codegen/dx12/codegen_dx12.h"
#endif

#include "taichi/codegen/llvm/kernel_compiler.h"
#include "taichi/codegen/llvm/compiled_kernel_data.h"

namespace taichi::lang {
namespace {
FunctionType llvm_compiled_kernel_to_executable(
Arch arch,
TaichiLLVMContext *tlctx,
LlvmRuntimeExecutor *executor,
Kernel *kernel,
LLVMCompiledKernel llvm_compiled_kernel) {
TI_ASSERT(arch_uses_llvm(arch));

FunctionType func = nullptr;
if (arch_is_cpu(arch)) {
CPUModuleToFunctionConverter converter(tlctx, executor);
func = converter.convert(kernel, std::move(llvm_compiled_kernel));
} else if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
CUDAModuleToFunctionConverter converter(tlctx, executor);
func = converter.convert(kernel, std::move(llvm_compiled_kernel));
#endif
} else if (arch == Arch::amdgpu) {
#if defined(TI_WITH_AMDGPU)
AMDGPUModuleToFunctionConverter converter(tlctx, executor);
func = converter.convert(kernel, std::move(llvm_compiled_kernel));
#endif
} else if (arch == Arch::wasm) {
// Not implemented
} else if (arch == Arch::dx12) {
// Not implemented
}

if (!func) {
TI_NOT_IMPLEMENTED;
}
return func;
}
} // namespace

LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
KernelProfilerBase *profiler)
@@ -36,9 +76,15 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,

FunctionType LlvmProgramImpl::compile(const CompileConfig &compile_config,
Kernel *kernel) {
auto codegen = KernelCodeGen::create(compile_config, kernel, kernel->ir.get(),
*runtime_exec_->get_llvm_context());
return codegen->compile_to_function();
// NOTE: Temporary implementation
// TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
auto &mgr = get_kernel_compilation_manager();
const auto &compiled = mgr.load_or_compile(compile_config, {}, *kernel);
auto &llvm_data = dynamic_cast<const llvm::CompiledKernelData &>(compiled);
return llvm_compiled_kernel_to_executable(
compile_config.arch, runtime_exec_->get_llvm_context(),
runtime_exec_.get(), kernel,
llvm_data.get_internal_data().compiled_data.clone());
}

std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
@@ -126,26 +172,10 @@ void LlvmProgramImpl::cache_field(int snode_tree_id,
cache_data_->fields[snode_tree_id] = std::move(ret);
}

void LlvmProgramImpl::dump_cache_data_to_disk() {
if (config->offline_cache) {
auto policy = offline_cache::string_to_clean_cache_policy(
config->offline_cache_cleaning_policy);
LlvmOfflineCacheFileWriter::clean_cache(
offline_cache::get_cache_path_by_arch(config->offline_cache_file_path,
config->arch),
policy, config->offline_cache_max_size_of_files,
config->offline_cache_cleaning_factor);
if (!cache_data_->kernels.empty()) {
LlvmOfflineCacheFileWriter writer{};
writer.set_data(std::move(cache_data_));

// Note: For offline-cache, new-metadata should be merged with
// old-metadata
writer.dump(offline_cache::get_cache_path_by_arch(
config->offline_cache_file_path, config->arch),
LlvmOfflineCache::LL, true);
}
}
std::unique_ptr<KernelCompiler> LlvmProgramImpl::make_kernel_compiler() {
lang::llvm::KernelCompiler::Config cfg;
cfg.tlctx = runtime_exec_->get_llvm_context();
return std::make_unique<lang::llvm::KernelCompiler>(std::move(cfg));
}

LlvmProgramImpl *get_llvm_program(Program *prog) {
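
For readers skimming the diff, the sketch below (not part of the commit) restates what the new `LlvmProgramImpl::compile` above does, pulled out into a hypothetical free function so the three steps are explicit. `KernelCompilationManager::load_or_compile`, `llvm::CompiledKernelData`, `get_internal_data().compiled_data.clone()`, and `llvm_compiled_kernel_to_executable` are the names used in the diff; the function itself, its parameters, and its placement are illustrative only (it is assumed to sit inside namespace taichi::lang, in the same translation unit as the anonymous-namespace helper above).

// Hypothetical restatement of the new JIT path:
//   1. ask the shared KernelCompilationManager for a cached-or-freshly-built
//      CompiledKernelData (the in-memory/offline cache handled in
//      kernel_compilation_manager.cpp above, falling back to the
//      llvm::KernelCompiler created by make_kernel_compiler()),
//   2. downcast it to the LLVM-specific CompiledKernelData,
//   3. convert the contained LLVM modules into a launchable FunctionType for
//      the current arch via the CPU/CUDA/AMDGPU converters shown above.
FunctionType compile_with_unified_cache(KernelCompilationManager &mgr,
                                        const CompileConfig &cfg,
                                        TaichiLLVMContext *tlctx,
                                        LlvmRuntimeExecutor *executor,
                                        Kernel *kernel) {
  const auto &compiled = mgr.load_or_compile(cfg, {}, *kernel);
  const auto &llvm_data =
      dynamic_cast<const llvm::CompiledKernelData &>(compiled);
  return llvm_compiled_kernel_to_executable(
      cfg.arch, tlctx, executor, kernel,
      llvm_data.get_internal_data().compiled_data.clone());
}
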
6 changes: 1 addition & 5 deletions taichi/runtime/program_impls/llvm/llvm_program.h
@@ -72,8 +72,6 @@ class LlvmProgramImpl : public ProgramImpl {
std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
const DeviceCapabilityConfig &caps) override;

void dump_cache_data_to_disk() override;

/* -------------------------------- */
/* ---- JIT-Runtime Interfaces ---- */
/* -------------------------------- */
@@ -276,9 +274,7 @@ class LlvmProgramImpl : public ProgramImpl {
ParallelExecutor compilation_workers; // parallel compilation

protected:
std::unique_ptr<KernelCompiler> make_kernel_compiler() override {
TI_NOT_IMPLEMENTED;
}
std::unique_ptr<KernelCompiler> make_kernel_compiler() override;

private:
std::size_t num_snode_trees_processed_{0};
4 changes: 2 additions & 2 deletions tests/python/test_cli.py
@@ -212,12 +212,12 @@ def test_cli_run():
assert args.filename == "a.py"


def _test_cli_cache(): # TODO(PGZXB): Re-enable the test
def test_cli_cache():
archs = {
ti.cpu, ti.cuda, ti.opengl, ti.vulkan, ti.metal, ti.gles, ti.amdgpu
}
archs = {v for v in archs if v in test_utils.expected_archs()}
exts = ('ll', 'bc', 'spv', 'metal', 'tcb', 'lock')
exts = ('tic', 'tcb', 'lock')
tmp_path = tempfile.mkdtemp()

@ti.kernel
25 changes: 5 additions & 20 deletions tests/python/test_offline_cache.py
@@ -18,7 +18,7 @@
OFFLINE_CACHE_TEMP_DIR = mkdtemp()
atexit.register(lambda: rmdir(OFFLINE_CACHE_TEMP_DIR))

supported_llvm_archs = set()
supported_llvm_archs = {ti.cpu, ti.cuda}
supported_gfx_archs = {ti.opengl, ti.vulkan, ti.metal}
supported_archs_offline_cache = supported_llvm_archs | supported_gfx_archs
supported_archs_offline_cache = {
@@ -28,7 +28,7 @@


def is_offline_cache_file(filename):
suffixes = ('.ll', '.bc', '.tic')
suffixes = ('.tic', )
return filename.endswith(suffixes)


@@ -45,31 +45,16 @@ def expected_num_cache_files(arch, num_offloads: List[int] = None) -> int:
assert arch in supported_archs_offline_cache
if not num_offloads:
return 0
result = 0
# code files
if arch in supported_llvm_archs:
result += len(num_offloads)
elif arch in supported_gfx_archs:
result += len(num_offloads)
# metadata files
if arch in supported_llvm_archs:
result += 2 # metadata.{json, tcb}
elif arch in supported_gfx_archs:
# ticache.tcb
result += 1
return result
# code files (*.tic) + metadata file (ticache.tcb)
return len(num_offloads) + 1


def tmp_offline_cache_file_path():
return join(OFFLINE_CACHE_TEMP_DIR, str(threading.currentThread().ident))


def backend_specified_cache_path(arch):
if arch in supported_llvm_archs:
return join(tmp_offline_cache_file_path(), 'llvm')
elif arch in supported_gfx_archs:
return tmp_offline_cache_file_path()
assert False
return tmp_offline_cache_file_path()


def current_thread_ext_options():
Expand Down
