Skip to content

Commit

Permalink
used llvm orcjit example from BuildingAJIT3 instead of BuildingAJIT2 …
Browse files Browse the repository at this point in the history
…to avoid excaption crashing after JIT destruction
  • Loading branch information
yuanming-hu committed Oct 18, 2019
1 parent 14b46e5 commit 2f32eb4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 52 deletions.
6 changes: 3 additions & 3 deletions lang/src/backends/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
}

virtual FunctionType compile_module_to_executable() {
llvm::cantFail(jit->addModule(std::move(module)));
jit->addModule(std::move(module));

auto kernel_symbol = llvm::cantFail(jit->lookup(kernel_name));
auto kernel_symbol = jit->lookup(kernel_name);
TC_ASSERT_INFO(kernel_symbol, "Function not found");

auto f = (int32(*)(void *))(void *)(kernel_symbol.getAddress());
auto f = (int32(*)(void *))(void *)(llvm::cantFail(kernel_symbol.getAddress()));
return [=](Context context) { f(&context); };
}

Expand Down
136 changes: 88 additions & 48 deletions lang/src/backends/llvm_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,22 @@ static_assert(false, "please use C++17.");
// https://llvm.org/docs/tutorial/BuildingAJIT2.html
#include "../util.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
#include "llvm/ExecutionEngine/RuntimeDyld.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
Expand All @@ -40,35 +45,57 @@ int compile_ptx_and_launch(const std::string &ptx,
class TaichiLLVMJIT {
private:
ExecutionSession ES;
RTDyldObjectLinkingLayer ObjectLayer;
IRCompileLayer CompileLayer;
IRTransformLayer OptimizeLayer;
std::map<VModuleKey, std::shared_ptr<SymbolResolver>> Resolvers;
std::unique_ptr<TargetMachine> TM;
const DataLayout DL;
LegacyRTDyldObjectLinkingLayer ObjectLayer;
LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;

DataLayout DL;
MangleAndInterner Mangle;
ThreadSafeContext Ctx;
using OptimizeFunction =
std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;

LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction>
OptimizeLayer;

std::unique_ptr<JITCompileCallbackManager> CompileCallbackManager;
LegacyCompileOnDemandLayer<decltype(OptimizeLayer)> CODLayer;

public:
TaichiLLVMJIT(JITTargetMachineBuilder JTMB, DataLayout DL)
: ObjectLayer(ES,
[]() { return llvm::make_unique<SectionMemoryManager>(); }),
CompileLayer(ES, ObjectLayer, ConcurrentIRCompiler(std::move(JTMB))),
OptimizeLayer(ES, CompileLayer, optimizeModule),
DL(std::move(DL)),
Mangle(ES, this->DL),
Ctx(llvm::make_unique<LLVMContext>()) {
ES.getMainJITDylib().setGenerator(
cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(DL)));
: TM(EngineBuilder().selectTarget()),
DL(TM->createDataLayout()),
ObjectLayer(ES,
[this](VModuleKey K) {
return LegacyRTDyldObjectLinkingLayer::Resources{
std::make_shared<SectionMemoryManager>(),
Resolvers[K]};
}),
CompileLayer(ObjectLayer, SimpleCompiler(*TM)),
OptimizeLayer(CompileLayer,
[this](std::unique_ptr<Module> M) {
return optimizeModule(std::move(M));
}),
CompileCallbackManager(cantFail(
orc::createLocalCompileCallbackManager(TM->getTargetTriple(),
ES,
0))),
CODLayer(ES,
OptimizeLayer,
[&](orc::VModuleKey K) { return Resolvers[K]; },
[&](orc::VModuleKey K, std::shared_ptr<SymbolResolver> R) {
Resolvers[K] = std::move(R);
},
[](Function &F) { return std::set<Function *>({&F}); },
*CompileCallbackManager,
orc::createLocalIndirectStubsManagerBuilder(
TM->getTargetTriple())) {
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
}

const DataLayout &getDataLayout() const {
return DL;
}

LLVMContext &getContext() {
return *Ctx.getContext();
}

static Expected<std::unique_ptr<TaichiLLVMJIT>> create(Arch arch) {
std::unique_ptr<JITTargetMachineBuilder> jtmb;
if (arch == Arch::x86_64) {
Expand All @@ -89,48 +116,59 @@ class TaichiLLVMJIT {
return llvm::make_unique<TaichiLLVMJIT>(std::move(*jtmb), std::move(*DL));
}

Error addModule(std::unique_ptr<Module> M) {
return OptimizeLayer.add(ES.getMainJITDylib(),
ThreadSafeModule(std::move(M), Ctx));
VModuleKey addModule(std::unique_ptr<Module> M) {
// Create a new VModuleKey.
VModuleKey K = ES.allocateVModule();

// Build a resolver and associate it with the new key.
Resolvers[K] = createLegacyLookupResolver(
ES,
[this](const std::string &Name) -> JITSymbol {
if (auto Sym = CompileLayer.findSymbol(Name, false))
return Sym;
else if (auto Err = Sym.takeError())
return std::move(Err);
if (auto SymAddr =
RTDyldMemoryManager::getSymbolAddressInProcess(Name))
return JITSymbol(SymAddr, JITSymbolFlags::Exported);
return nullptr;
},
[](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); });

// Add the module to the JIT with the new key.
cantFail(CODLayer.addModule(K, std::move(M)));
return K;
}

JITSymbol lookup(const std::string Name) {
std::string MangledName;
raw_string_ostream MangledNameStream(MangledName);
Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
return CODLayer.findSymbol(MangledNameStream.str(), true);
}

Expected<JITEvaluatedSymbol> lookup(StringRef Name) {
return ES.lookup({&ES.getMainJITDylib()}, Mangle(Name.str()));
void removeModule(VModuleKey K) {
cantFail(CODLayer.removeModule(K));
}

private:
static Expected<ThreadSafeModule> optimizeModule(
ThreadSafeModule TSM,
const MaterializationResponsibility &R) {
std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
// Create a function pass manager.
auto FPM = llvm::make_unique<legacy::FunctionPassManager>(TSM.getModule());
auto FPM = llvm::make_unique<legacy::FunctionPassManager>(M.get());

// Add some optimizations.
// FPM->add(createFunctionInliningPass());
FPM->add(createInstructionCombiningPass());
FPM->add(createReassociatePass());
FPM->add(createGVNPass());
FPM->add(createCFGSimplificationPass());

FPM->doInitialization();

/*
llvm::ModulePassManager MPM;
llvm::ModuleAnalysisManager moduleAnalysisManager;
MPM.addPass(createFunctionInliningPass());
MPM.run(*TSM.getModule(), moduleAnalysisManager);
*/

// Run the optimizations over all functions in the module being added to
// the JIT.

for (auto &F : *TSM.getModule()) {
for (auto &F : *M)
FPM->run(F);
// TC_INFO("Function IR Optimized");
// F.print(errs(), nullptr);
}

return TSM;
return M;
}

public:
Expand All @@ -140,8 +178,10 @@ class TaichiLLVMJIT {
};

inline void *jit_lookup_name(TaichiLLVMJIT *jit, const std::string &name) {
llvm::ExitOnError exit_on_err;
return (void *)exit_on_err(jit->lookup(name)).getAddress();
auto ExprSymbol = jit->lookup(name);
if (!ExprSymbol)
TC_ERROR("Function not found");
return (void *)(llvm::cantFail(ExprSymbol.getAddress()));
}

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion lang/src/backends/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ void StructCompilerLLVM::run(SNode &root, bool host) {

tlctx->set_struct_module(module);

llvm::cantFail(tlctx->jit->addModule(std::move(module)));
(tlctx->jit->addModule(std::move(module)));

if (host) {
for (auto n : snodes) {
Expand Down

0 comments on commit 2f32eb4

Please sign in to comment.