diff --git a/examples/Kaleidoscope/BuildingAJIT/CMakeLists.txt b/examples/Kaleidoscope/BuildingAJIT/CMakeLists.txt index 8315eb6e0e5e..17e280c16717 100644 --- a/examples/Kaleidoscope/BuildingAJIT/CMakeLists.txt +++ b/examples/Kaleidoscope/BuildingAJIT/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(Chapter1) add_subdirectory(Chapter2) add_subdirectory(Chapter3) add_subdirectory(Chapter4) +add_subdirectory(Chapter5) diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/CMakeLists.txt b/examples/Kaleidoscope/BuildingAJIT/Chapter5/CMakeLists.txt new file mode 100644 index 000000000000..d5b832b49550 --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/CMakeLists.txt @@ -0,0 +1,21 @@ +add_subdirectory(Server) + +set(LLVM_LINK_COMPONENTS + Analysis + Core + ExecutionEngine + InstCombine + Object + OrcJIT + RuntimeDyld + ScalarOpts + Support + TransformUtils + native + ) + +add_kaleidoscope_chapter(BuildingAJIT-Ch5 + toy.cpp + ) + +export_executable_symbols(BuildingAJIT-Ch5) diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/KaleidoscopeJIT.h b/examples/Kaleidoscope/BuildingAJIT/Chapter5/KaleidoscopeJIT.h new file mode 100644 index 000000000000..d72dfdde3433 --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/KaleidoscopeJIT.h @@ -0,0 +1,263 @@ +//===----- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Contains a simple JIT definition for use in the kaleidoscope tutorials. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H +#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H + +#include "RemoteJITUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/LambdaResolver.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Mangler.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include +#include +#include +#include + +class PrototypeAST; +class ExprAST; + +/// FunctionAST - This class represents a function definition itself. +class FunctionAST { + std::unique_ptr Proto; + std::unique_ptr Body; + +public: + FunctionAST(std::unique_ptr Proto, + std::unique_ptr Body) + : Proto(std::move(Proto)), Body(std::move(Body)) {} + const PrototypeAST& getProto() const; + const std::string& getName() const; + llvm::Function *codegen(); +}; + +/// This will compile FnAST to IR, rename the function to add the given +/// suffix (needed to prevent a name-clash with the function's stub), +/// and then take ownership of the module that the function was compiled +/// into. +std::unique_ptr +irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix); + +namespace llvm { +namespace orc { + +// Typedef the remote-client API. +typedef remote::OrcRemoteTargetClient MyRemote; + +class KaleidoscopeJIT { +private: + MyRemote &Remote; + std::unique_ptr TM; + const DataLayout DL; + JITCompileCallbackManager *CompileCallbackMgr; + std::unique_ptr IndirectStubsMgr; + ObjectLinkingLayer<> ObjectLayer; + IRCompileLayer CompileLayer; + + typedef std::function(std::unique_ptr)> + OptimizeFunction; + + IRTransformLayer OptimizeLayer; + +public: + typedef decltype(OptimizeLayer)::ModuleSetHandleT ModuleHandle; + + KaleidoscopeJIT(MyRemote &Remote) + : Remote(Remote), + TM(EngineBuilder().selectTarget()), + DL(TM->createDataLayout()), + CompileLayer(ObjectLayer, SimpleCompiler(*TM)), + OptimizeLayer(CompileLayer, + [this](std::unique_ptr M) { + return optimizeModule(std::move(M)); + }) { + auto CCMgrOrErr = Remote.enableCompileCallbacks(0); + if (!CCMgrOrErr) { + logAllUnhandledErrors(CCMgrOrErr.takeError(), errs(), + "Error enabling remote compile callbacks:"); + exit(1); + } + CompileCallbackMgr = &*CCMgrOrErr; + std::unique_ptr ISM; + if (auto Err = Remote.createIndirectStubsManager(ISM)) { + logAllUnhandledErrors(std::move(Err), errs(), + "Error creating indirect stubs manager:"); + exit(1); + } + IndirectStubsMgr = std::move(ISM); + llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr); + } + + TargetMachine &getTargetMachine() { return *TM; } + + ModuleHandle addModule(std::unique_ptr M) { + + // Build our symbol resolver: + // Lambda 1: Look back into the JIT itself to find symbols that are part of + // the same "logical dylib". + // Lambda 2: Search for external symbols in the host process. + auto Resolver = createLambdaResolver( + [&](const std::string &Name) { + if (auto Sym = IndirectStubsMgr->findStub(Name, false)) + return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags()); + if (auto Sym = OptimizeLayer.findSymbol(Name, false)) + return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags()); + return RuntimeDyld::SymbolInfo(nullptr); + }, + [&](const std::string &Name) { + if (auto AddrOrErr = Remote.getSymbolAddress(Name)) + return RuntimeDyld::SymbolInfo(*AddrOrErr, + JITSymbolFlags::Exported); + else { + logAllUnhandledErrors(AddrOrErr.takeError(), errs(), + "Error resolving remote symbol:"); + exit(1); + } + return RuntimeDyld::SymbolInfo(nullptr); + }); + + std::unique_ptr MemMgr; + if (auto Err = Remote.createRemoteMemoryManager(MemMgr)) { + logAllUnhandledErrors(std::move(Err), errs(), + "Error creating remote memory manager:"); + exit(1); + } + + // Build a singlton module set to hold our module. + std::vector> Ms; + Ms.push_back(std::move(M)); + + // Add the set to the JIT with the resolver we created above and a newly + // created SectionMemoryManager. + return OptimizeLayer.addModuleSet(std::move(Ms), + std::move(MemMgr), + std::move(Resolver)); + } + + Error addFunctionAST(std::unique_ptr FnAST) { + // Create a CompileCallback - this is the re-entry point into the compiler + // for functions that haven't been compiled yet. + auto CCInfo = CompileCallbackMgr->getCompileCallback(); + + // Create an indirect stub. This serves as the functions "canonical + // definition" - an unchanging (constant address) entry point to the + // function implementation. + // Initially we point the stub's function-pointer at the compile callback + // that we just created. In the compile action for the callback (see below) + // we will update the stub's function pointer to point at the function + // implementation that we just implemented. + if (auto Err = IndirectStubsMgr->createStub(mangle(FnAST->getName()), + CCInfo.getAddress(), + JITSymbolFlags::Exported)) + return Err; + + // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support + // capture-by-move, which is be required for unique_ptr. + auto SharedFnAST = std::shared_ptr(std::move(FnAST)); + + // Set the action to compile our AST. This lambda will be run if/when + // execution hits the compile callback (via the stub). + // + // The steps to compile are: + // (1) IRGen the function. + // (2) Add the IR module to the JIT to make it executable like any other + // module. + // (3) Use findSymbol to get the address of the compiled function. + // (4) Update the stub pointer to point at the implementation so that + /// subsequent calls go directly to it and bypass the compiler. + // (5) Return the address of the implementation: this lambda will actually + // be run inside an attempted call to the function, and we need to + // continue on to the implementation to complete the attempted call. + // The JIT runtime (the resolver block) will use the return address of + // this function as the address to continue at once it has reset the + // CPU state to what it was immediately before the call. + CCInfo.setCompileAction( + [this, SharedFnAST]() { + auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl"); + addModule(std::move(M)); + auto Sym = findSymbol(SharedFnAST->getName() + "$impl"); + assert(Sym && "Couldn't find compiled function?"); + TargetAddress SymAddr = Sym.getAddress(); + if (auto Err = + IndirectStubsMgr->updatePointer(mangle(SharedFnAST->getName()), + SymAddr)) { + logAllUnhandledErrors(std::move(Err), errs(), + "Error updating function pointer: "); + exit(1); + } + + return SymAddr; + }); + + return Error::success(); + } + + Error executeRemoteExpr(TargetAddress ExprAddr) { + return Remote.callVoidVoid(ExprAddr); + } + + JITSymbol findSymbol(const std::string Name) { + return OptimizeLayer.findSymbol(mangle(Name), true); + } + + void removeModule(ModuleHandle H) { + OptimizeLayer.removeModuleSet(H); + } + +private: + + std::string mangle(const std::string &Name) { + std::string MangledName; + raw_string_ostream MangledNameStream(MangledName); + Mangler::getNameWithPrefix(MangledNameStream, Name, DL); + return MangledNameStream.str(); + } + + std::unique_ptr optimizeModule(std::unique_ptr M) { + // Create a function pass manager. + auto FPM = llvm::make_unique(M.get()); + + // Add some optimizations. + FPM->add(createInstructionCombiningPass()); + FPM->add(createReassociatePass()); + FPM->add(createGVNPass()); + FPM->add(createCFGSimplificationPass()); + FPM->doInitialization(); + + // Run the optimizations over all functions in the module being added to + // the JIT. + for (auto &F : *M) + FPM->run(F); + + return M; + } + +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h b/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h new file mode 100644 index 000000000000..869d0a7ef39d --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h @@ -0,0 +1,74 @@ +//===-- RemoteJITUtils.h - Utilities for remote-JITing with LLI -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Utilities for remote-JITing with LLI. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H +#define LLVM_TOOLS_LLI_REMOTEJITUTILS_H + +#include "llvm/ExecutionEngine/Orc/RPCChannel.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" +#include + +#if !defined(_MSC_VER) && !defined(__MINGW32__) +#include +#else +#include +#endif + +/// RPC channel that reads from and writes from file descriptors. +class FDRPCChannel final : public llvm::orc::remote::RPCChannel { +public: + FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + + llvm::Error readBytes(char *Dst, unsigned Size) override { + assert(Dst && "Attempt to read into null."); + ssize_t Completed = 0; + while (Completed < static_cast(Size)) { + ssize_t Read = ::read(InFD, Dst + Completed, Size - Completed); + if (Read <= 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Read; + } + return llvm::Error::success(); + } + + llvm::Error appendBytes(const char *Src, unsigned Size) override { + assert(Src && "Attempt to append from null."); + ssize_t Completed = 0; + while (Completed < static_cast(Size)) { + ssize_t Written = ::write(OutFD, Src + Completed, Size - Completed); + if (Written < 0) { + auto ErrNo = errno; + if (ErrNo == EAGAIN || ErrNo == EINTR) + continue; + else + return llvm::errorCodeToError( + std::error_code(errno, std::generic_category())); + } + Completed += Written; + } + return llvm::Error::success(); + } + + llvm::Error send() override { return llvm::Error::success(); } + +private: + int InFD, OutFD; +}; + +#endif diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/CMakeLists.txt b/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/CMakeLists.txt new file mode 100644 index 000000000000..15dd53516ceb --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_LINK_COMPONENTS + Analysis + Core + ExecutionEngine + InstCombine + Object + OrcJIT + RuntimeDyld + ScalarOpts + Support + TransformUtils + native + ) + +add_kaleidoscope_chapter(BuildingAJIT-Ch5-Server + server.cpp + ) diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/server.cpp b/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/server.cpp new file mode 100644 index 000000000000..c53e22fe83ae --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/Server/server.cpp @@ -0,0 +1,119 @@ +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h" +#include "llvm/ExecutionEngine/Orc/OrcABISupport.h" + +#include "../RemoteJITUtils.h" + +#include +#include +#include +#include + + +using namespace llvm; +using namespace llvm::orc; + +// Command line argument for TCP port. +cl::opt Port("port", + cl::desc("TCP port to listen on"), + cl::init(20000)); + +ExitOnError ExitOnErr; + +typedef int (*MainFun)(int, const char*[]); + +template +NativePtrT MakeNative(uint64_t P) { + return reinterpret_cast(static_cast(P)); +} + +extern "C" +void printExprResult(double Val) { + printf("Expression evaluated to: %f\n", Val); +} + +// --- LAZY COMPILE TEST --- +int main(int argc, char* argv[]) { + + if (argc == 0) + ExitOnErr.setBanner("jit_server: "); + else + ExitOnErr.setBanner(std::string(argv[0]) + ": "); + + // --- Initialize LLVM --- + cl::ParseCommandLineOptions(argc, argv, "LLVM lazy JIT example.\n"); + + InitializeNativeTarget(); + InitializeNativeTargetAsmPrinter(); + InitializeNativeTargetAsmParser(); + + if (sys::DynamicLibrary::LoadLibraryPermanently(nullptr)) { + errs() << "Error loading program symbols.\n"; + return 1; + } + + // --- Initialize remote connection --- + + int sockfd = socket(PF_INET, SOCK_STREAM, 0); + sockaddr_in servAddr, clientAddr; + socklen_t clientAddrLen = sizeof(clientAddr); + bzero(&servAddr, sizeof(servAddr)); + servAddr.sin_family = PF_INET; + servAddr.sin_family = INADDR_ANY; + servAddr.sin_port = htons(Port); + + { + // avoid "Address already in use" error. + int yes=1; + if (setsockopt(sockfd,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(int)) == -1) { + errs() << "Error calling setsockopt.\n"; + return 1; + } + } + + if (bind(sockfd, reinterpret_cast(&servAddr), + sizeof(servAddr)) < 0) { + errs() << "Error on binding.\n"; + return 1; + } + listen(sockfd, 1); + int newsockfd = accept(sockfd, reinterpret_cast(&clientAddr), + &clientAddrLen); + + auto SymbolLookup = + [](const std::string &Name) { + return RTDyldMemoryManager::getSymbolAddressInProcess(Name); + }; + + auto RegisterEHFrames = + [](uint8_t *Addr, uint32_t Size) { + RTDyldMemoryManager::registerEHFramesInProcess(Addr, Size); + }; + + auto DeregisterEHFrames = + [](uint8_t *Addr, uint32_t Size) { + RTDyldMemoryManager::deregisterEHFramesInProcess(Addr, Size); + }; + + FDRPCChannel TCPChannel(newsockfd, newsockfd); + typedef remote::OrcRemoteTargetServer MyServerT; + + MyServerT Server(TCPChannel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames); + + while (1) { + MyServerT::JITFuncId Id = MyServerT::InvalidId; + ExitOnErr(Server.startReceivingFunction(TCPChannel, (uint32_t&)Id)); + switch (Id) { + case MyServerT::TerminateSessionId: + ExitOnErr(Server.handleTerminateSession()); + return 0; + default: + ExitOnErr(Server.handleKnownFunction(Id)); + break; + } + } + + llvm_unreachable("Fell through server command loop."); +} diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp b/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp new file mode 100644 index 000000000000..9c21098971a6 --- /dev/null +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp @@ -0,0 +1,1294 @@ +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "KaleidoscopeJIT.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace llvm; +using namespace llvm::orc; + +// Command line argument for TCP hostname. +cl::opt HostName("hostname", + cl::desc("TCP hostname to connect to"), + cl::init("localhost")); + +// Command line argument for TCP port. +cl::opt Port("port", + cl::desc("TCP port to connect to"), + cl::init(20000)); + +//===----------------------------------------------------------------------===// +// Lexer +//===----------------------------------------------------------------------===// + +// The lexer returns tokens [0-255] if it is an unknown character, otherwise one +// of these for known things. +enum Token { + tok_eof = -1, + + // commands + tok_def = -2, + tok_extern = -3, + + // primary + tok_identifier = -4, + tok_number = -5, + + // control + tok_if = -6, + tok_then = -7, + tok_else = -8, + tok_for = -9, + tok_in = -10, + + // operators + tok_binary = -11, + tok_unary = -12, + + // var definition + tok_var = -13 +}; + +static std::string IdentifierStr; // Filled in if tok_identifier +static double NumVal; // Filled in if tok_number + +/// gettok - Return the next token from standard input. +static int gettok() { + static int LastChar = ' '; + + // Skip any whitespace. + while (isspace(LastChar)) + LastChar = getchar(); + + if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* + IdentifierStr = LastChar; + while (isalnum((LastChar = getchar()))) + IdentifierStr += LastChar; + + if (IdentifierStr == "def") + return tok_def; + if (IdentifierStr == "extern") + return tok_extern; + if (IdentifierStr == "if") + return tok_if; + if (IdentifierStr == "then") + return tok_then; + if (IdentifierStr == "else") + return tok_else; + if (IdentifierStr == "for") + return tok_for; + if (IdentifierStr == "in") + return tok_in; + if (IdentifierStr == "binary") + return tok_binary; + if (IdentifierStr == "unary") + return tok_unary; + if (IdentifierStr == "var") + return tok_var; + return tok_identifier; + } + + if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ + std::string NumStr; + do { + NumStr += LastChar; + LastChar = getchar(); + } while (isdigit(LastChar) || LastChar == '.'); + + NumVal = strtod(NumStr.c_str(), nullptr); + return tok_number; + } + + if (LastChar == '#') { + // Comment until end of line. + do + LastChar = getchar(); + while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); + + if (LastChar != EOF) + return gettok(); + } + + // Check for end of file. Don't eat the EOF. + if (LastChar == EOF) + return tok_eof; + + // Otherwise, just return the character as its ascii value. + int ThisChar = LastChar; + LastChar = getchar(); + return ThisChar; +} + +//===----------------------------------------------------------------------===// +// Abstract Syntax Tree (aka Parse Tree) +//===----------------------------------------------------------------------===// + +/// ExprAST - Base class for all expression nodes. +class ExprAST { +public: + virtual ~ExprAST() {} + virtual Value *codegen() = 0; +}; + +/// NumberExprAST - Expression class for numeric literals like "1.0". +class NumberExprAST : public ExprAST { + double Val; + +public: + NumberExprAST(double Val) : Val(Val) {} + Value *codegen() override; +}; + +/// VariableExprAST - Expression class for referencing a variable, like "a". +class VariableExprAST : public ExprAST { + std::string Name; + +public: + VariableExprAST(const std::string &Name) : Name(Name) {} + const std::string &getName() const { return Name; } + Value *codegen() override; +}; + +/// UnaryExprAST - Expression class for a unary operator. +class UnaryExprAST : public ExprAST { + char Opcode; + std::unique_ptr Operand; + +public: + UnaryExprAST(char Opcode, std::unique_ptr Operand) + : Opcode(Opcode), Operand(std::move(Operand)) {} + Value *codegen() override; +}; + +/// BinaryExprAST - Expression class for a binary operator. +class BinaryExprAST : public ExprAST { + char Op; + std::unique_ptr LHS, RHS; + +public: + BinaryExprAST(char Op, std::unique_ptr LHS, + std::unique_ptr RHS) + : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} + Value *codegen() override; +}; + +/// CallExprAST - Expression class for function calls. +class CallExprAST : public ExprAST { + std::string Callee; + std::vector> Args; + +public: + CallExprAST(const std::string &Callee, + std::vector> Args) + : Callee(Callee), Args(std::move(Args)) {} + Value *codegen() override; +}; + +/// IfExprAST - Expression class for if/then/else. +class IfExprAST : public ExprAST { + std::unique_ptr Cond, Then, Else; + +public: + IfExprAST(std::unique_ptr Cond, std::unique_ptr Then, + std::unique_ptr Else) + : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {} + Value *codegen() override; +}; + +/// ForExprAST - Expression class for for/in. +class ForExprAST : public ExprAST { + std::string VarName; + std::unique_ptr Start, End, Step, Body; + +public: + ForExprAST(const std::string &VarName, std::unique_ptr Start, + std::unique_ptr End, std::unique_ptr Step, + std::unique_ptr Body) + : VarName(VarName), Start(std::move(Start)), End(std::move(End)), + Step(std::move(Step)), Body(std::move(Body)) {} + Value *codegen() override; +}; + +/// VarExprAST - Expression class for var/in +class VarExprAST : public ExprAST { + std::vector>> VarNames; + std::unique_ptr Body; + +public: + VarExprAST( + std::vector>> VarNames, + std::unique_ptr Body) + : VarNames(std::move(VarNames)), Body(std::move(Body)) {} + Value *codegen() override; +}; + +/// PrototypeAST - This class represents the "prototype" for a function, +/// which captures its name, and its argument names (thus implicitly the number +/// of arguments the function takes), as well as if it is an operator. +class PrototypeAST { + std::string Name; + std::vector Args; + bool IsOperator; + unsigned Precedence; // Precedence if a binary op. + +public: + PrototypeAST(const std::string &Name, std::vector Args, + bool IsOperator = false, unsigned Prec = 0) + : Name(Name), Args(std::move(Args)), IsOperator(IsOperator), + Precedence(Prec) {} + Function *codegen(); + const std::string &getName() const { return Name; } + + bool isUnaryOp() const { return IsOperator && Args.size() == 1; } + bool isBinaryOp() const { return IsOperator && Args.size() == 2; } + + char getOperatorName() const { + assert(isUnaryOp() || isBinaryOp()); + return Name[Name.size() - 1]; + } + + unsigned getBinaryPrecedence() const { return Precedence; } +}; + +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// + +/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current +/// token the parser is looking at. getNextToken reads another token from the +/// lexer and updates CurTok with its results. +static int CurTok; +static int getNextToken() { return CurTok = gettok(); } + +/// BinopPrecedence - This holds the precedence for each binary operator that is +/// defined. +static std::map BinopPrecedence; + +/// GetTokPrecedence - Get the precedence of the pending binary operator token. +static int GetTokPrecedence() { + if (!isascii(CurTok)) + return -1; + + // Make sure it's a declared binop. + int TokPrec = BinopPrecedence[CurTok]; + if (TokPrec <= 0) + return -1; + return TokPrec; +} + +/// LogError* - These are little helper functions for error handling. +std::unique_ptr LogError(const char *Str) { + fprintf(stderr, "Error: %s\n", Str); + return nullptr; +} + +std::unique_ptr LogErrorP(const char *Str) { + LogError(Str); + return nullptr; +} + +static std::unique_ptr ParseExpression(); + +/// numberexpr ::= number +static std::unique_ptr ParseNumberExpr() { + auto Result = llvm::make_unique(NumVal); + getNextToken(); // consume the number + return std::move(Result); +} + +/// parenexpr ::= '(' expression ')' +static std::unique_ptr ParseParenExpr() { + getNextToken(); // eat (. + auto V = ParseExpression(); + if (!V) + return nullptr; + + if (CurTok != ')') + return LogError("expected ')'"); + getNextToken(); // eat ). + return V; +} + +/// identifierexpr +/// ::= identifier +/// ::= identifier '(' expression* ')' +static std::unique_ptr ParseIdentifierExpr() { + std::string IdName = IdentifierStr; + + getNextToken(); // eat identifier. + + if (CurTok != '(') // Simple variable ref. + return llvm::make_unique(IdName); + + // Call. + getNextToken(); // eat ( + std::vector> Args; + if (CurTok != ')') { + while (true) { + if (auto Arg = ParseExpression()) + Args.push_back(std::move(Arg)); + else + return nullptr; + + if (CurTok == ')') + break; + + if (CurTok != ',') + return LogError("Expected ')' or ',' in argument list"); + getNextToken(); + } + } + + // Eat the ')'. + getNextToken(); + + return llvm::make_unique(IdName, std::move(Args)); +} + +/// ifexpr ::= 'if' expression 'then' expression 'else' expression +static std::unique_ptr ParseIfExpr() { + getNextToken(); // eat the if. + + // condition. + auto Cond = ParseExpression(); + if (!Cond) + return nullptr; + + if (CurTok != tok_then) + return LogError("expected then"); + getNextToken(); // eat the then + + auto Then = ParseExpression(); + if (!Then) + return nullptr; + + if (CurTok != tok_else) + return LogError("expected else"); + + getNextToken(); + + auto Else = ParseExpression(); + if (!Else) + return nullptr; + + return llvm::make_unique(std::move(Cond), std::move(Then), + std::move(Else)); +} + +/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression +static std::unique_ptr ParseForExpr() { + getNextToken(); // eat the for. + + if (CurTok != tok_identifier) + return LogError("expected identifier after for"); + + std::string IdName = IdentifierStr; + getNextToken(); // eat identifier. + + if (CurTok != '=') + return LogError("expected '=' after for"); + getNextToken(); // eat '='. + + auto Start = ParseExpression(); + if (!Start) + return nullptr; + if (CurTok != ',') + return LogError("expected ',' after for start value"); + getNextToken(); + + auto End = ParseExpression(); + if (!End) + return nullptr; + + // The step value is optional. + std::unique_ptr Step; + if (CurTok == ',') { + getNextToken(); + Step = ParseExpression(); + if (!Step) + return nullptr; + } + + if (CurTok != tok_in) + return LogError("expected 'in' after for"); + getNextToken(); // eat 'in'. + + auto Body = ParseExpression(); + if (!Body) + return nullptr; + + return llvm::make_unique(IdName, std::move(Start), std::move(End), + std::move(Step), std::move(Body)); +} + +/// varexpr ::= 'var' identifier ('=' expression)? +// (',' identifier ('=' expression)?)* 'in' expression +static std::unique_ptr ParseVarExpr() { + getNextToken(); // eat the var. + + std::vector>> VarNames; + + // At least one variable name is required. + if (CurTok != tok_identifier) + return LogError("expected identifier after var"); + + while (true) { + std::string Name = IdentifierStr; + getNextToken(); // eat identifier. + + // Read the optional initializer. + std::unique_ptr Init = nullptr; + if (CurTok == '=') { + getNextToken(); // eat the '='. + + Init = ParseExpression(); + if (!Init) + return nullptr; + } + + VarNames.push_back(std::make_pair(Name, std::move(Init))); + + // End of var list, exit loop. + if (CurTok != ',') + break; + getNextToken(); // eat the ','. + + if (CurTok != tok_identifier) + return LogError("expected identifier list after var"); + } + + // At this point, we have to have 'in'. + if (CurTok != tok_in) + return LogError("expected 'in' keyword after 'var'"); + getNextToken(); // eat 'in'. + + auto Body = ParseExpression(); + if (!Body) + return nullptr; + + return llvm::make_unique(std::move(VarNames), std::move(Body)); +} + +/// primary +/// ::= identifierexpr +/// ::= numberexpr +/// ::= parenexpr +/// ::= ifexpr +/// ::= forexpr +/// ::= varexpr +static std::unique_ptr ParsePrimary() { + switch (CurTok) { + default: + return LogError("unknown token when expecting an expression"); + case tok_identifier: + return ParseIdentifierExpr(); + case tok_number: + return ParseNumberExpr(); + case '(': + return ParseParenExpr(); + case tok_if: + return ParseIfExpr(); + case tok_for: + return ParseForExpr(); + case tok_var: + return ParseVarExpr(); + } +} + +/// unary +/// ::= primary +/// ::= '!' unary +static std::unique_ptr ParseUnary() { + // If the current token is not an operator, it must be a primary expr. + if (!isascii(CurTok) || CurTok == '(' || CurTok == ',') + return ParsePrimary(); + + // If this is a unary operator, read it. + int Opc = CurTok; + getNextToken(); + if (auto Operand = ParseUnary()) + return llvm::make_unique(Opc, std::move(Operand)); + return nullptr; +} + +/// binoprhs +/// ::= ('+' unary)* +static std::unique_ptr ParseBinOpRHS(int ExprPrec, + std::unique_ptr LHS) { + // If this is a binop, find its precedence. + while (true) { + int TokPrec = GetTokPrecedence(); + + // If this is a binop that binds at least as tightly as the current binop, + // consume it, otherwise we are done. + if (TokPrec < ExprPrec) + return LHS; + + // Okay, we know this is a binop. + int BinOp = CurTok; + getNextToken(); // eat binop + + // Parse the unary expression after the binary operator. + auto RHS = ParseUnary(); + if (!RHS) + return nullptr; + + // If BinOp binds less tightly with RHS than the operator after RHS, let + // the pending operator take RHS as its LHS. + int NextPrec = GetTokPrecedence(); + if (TokPrec < NextPrec) { + RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); + if (!RHS) + return nullptr; + } + + // Merge LHS/RHS. + LHS = + llvm::make_unique(BinOp, std::move(LHS), std::move(RHS)); + } +} + +/// expression +/// ::= unary binoprhs +/// +static std::unique_ptr ParseExpression() { + auto LHS = ParseUnary(); + if (!LHS) + return nullptr; + + return ParseBinOpRHS(0, std::move(LHS)); +} + +/// prototype +/// ::= id '(' id* ')' +/// ::= binary LETTER number? (id, id) +/// ::= unary LETTER (id) +static std::unique_ptr ParsePrototype() { + std::string FnName; + + unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary. + unsigned BinaryPrecedence = 30; + + switch (CurTok) { + default: + return LogErrorP("Expected function name in prototype"); + case tok_identifier: + FnName = IdentifierStr; + Kind = 0; + getNextToken(); + break; + case tok_unary: + getNextToken(); + if (!isascii(CurTok)) + return LogErrorP("Expected unary operator"); + FnName = "unary"; + FnName += (char)CurTok; + Kind = 1; + getNextToken(); + break; + case tok_binary: + getNextToken(); + if (!isascii(CurTok)) + return LogErrorP("Expected binary operator"); + FnName = "binary"; + FnName += (char)CurTok; + Kind = 2; + getNextToken(); + + // Read the precedence if present. + if (CurTok == tok_number) { + if (NumVal < 1 || NumVal > 100) + return LogErrorP("Invalid precedecnce: must be 1..100"); + BinaryPrecedence = (unsigned)NumVal; + getNextToken(); + } + break; + } + + if (CurTok != '(') + return LogErrorP("Expected '(' in prototype"); + + std::vector ArgNames; + while (getNextToken() == tok_identifier) + ArgNames.push_back(IdentifierStr); + if (CurTok != ')') + return LogErrorP("Expected ')' in prototype"); + + // success. + getNextToken(); // eat ')'. + + // Verify right number of names for operator. + if (Kind && ArgNames.size() != Kind) + return LogErrorP("Invalid number of operands for operator"); + + return llvm::make_unique(FnName, ArgNames, Kind != 0, + BinaryPrecedence); +} + +/// definition ::= 'def' prototype expression +static std::unique_ptr ParseDefinition() { + getNextToken(); // eat def. + auto Proto = ParsePrototype(); + if (!Proto) + return nullptr; + + if (auto E = ParseExpression()) + return llvm::make_unique(std::move(Proto), std::move(E)); + return nullptr; +} + +/// toplevelexpr ::= expression +static std::unique_ptr ParseTopLevelExpr() { + if (auto E = ParseExpression()) { + + auto PEArgs = std::vector>(); + PEArgs.push_back(std::move(E)); + auto PrintExpr = + llvm::make_unique("printExprResult", std::move(PEArgs)); + + // Make an anonymous proto. + auto Proto = llvm::make_unique("__anon_expr", + std::vector()); + return llvm::make_unique(std::move(Proto), + std::move(PrintExpr)); + } + return nullptr; +} + +/// external ::= 'extern' prototype +static std::unique_ptr ParseExtern() { + getNextToken(); // eat extern. + return ParsePrototype(); +} + +//===----------------------------------------------------------------------===// +// Code Generation +//===----------------------------------------------------------------------===// + +static LLVMContext TheContext; +static IRBuilder<> Builder(TheContext); +static std::unique_ptr TheModule; +static std::map NamedValues; +static std::unique_ptr TheJIT; +static std::map> FunctionProtos; +static ExitOnError ExitOnErr; + +Value *LogErrorV(const char *Str) { + LogError(Str); + return nullptr; +} + +Function *getFunction(std::string Name) { + // First, see if the function has already been added to the current module. + if (auto *F = TheModule->getFunction(Name)) + return F; + + // If not, check whether we can codegen the declaration from some existing + // prototype. + auto FI = FunctionProtos.find(Name); + if (FI != FunctionProtos.end()) + return FI->second->codegen(); + + // If no existing prototype exists, return null. + return nullptr; +} + +/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of +/// the function. This is used for mutable variables etc. +static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction, + const std::string &VarName) { + IRBuilder<> TmpB(&TheFunction->getEntryBlock(), + TheFunction->getEntryBlock().begin()); + return TmpB.CreateAlloca(Type::getDoubleTy(TheContext), nullptr, VarName); +} + +Value *NumberExprAST::codegen() { + return ConstantFP::get(TheContext, APFloat(Val)); +} + +Value *VariableExprAST::codegen() { + // Look this variable up in the function. + Value *V = NamedValues[Name]; + if (!V) + return LogErrorV("Unknown variable name"); + + // Load the value. + return Builder.CreateLoad(V, Name.c_str()); +} + +Value *UnaryExprAST::codegen() { + Value *OperandV = Operand->codegen(); + if (!OperandV) + return nullptr; + + Function *F = getFunction(std::string("unary") + Opcode); + if (!F) + return LogErrorV("Unknown unary operator"); + + return Builder.CreateCall(F, OperandV, "unop"); +} + +Value *BinaryExprAST::codegen() { + // Special case '=' because we don't want to emit the LHS as an expression. + if (Op == '=') { + // Assignment requires the LHS to be an identifier. + // This assume we're building without RTTI because LLVM builds that way by + // default. If you build LLVM with RTTI this can be changed to a + // dynamic_cast for automatic error checking. + VariableExprAST *LHSE = static_cast(LHS.get()); + if (!LHSE) + return LogErrorV("destination of '=' must be a variable"); + // Codegen the RHS. + Value *Val = RHS->codegen(); + if (!Val) + return nullptr; + + // Look up the name. + Value *Variable = NamedValues[LHSE->getName()]; + if (!Variable) + return LogErrorV("Unknown variable name"); + + Builder.CreateStore(Val, Variable); + return Val; + } + + Value *L = LHS->codegen(); + Value *R = RHS->codegen(); + if (!L || !R) + return nullptr; + + switch (Op) { + case '+': + return Builder.CreateFAdd(L, R, "addtmp"); + case '-': + return Builder.CreateFSub(L, R, "subtmp"); + case '*': + return Builder.CreateFMul(L, R, "multmp"); + case '<': + L = Builder.CreateFCmpULT(L, R, "cmptmp"); + // Convert bool 0/1 to double 0.0 or 1.0 + return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp"); + default: + break; + } + + // If it wasn't a builtin binary operator, it must be a user defined one. Emit + // a call to it. + Function *F = getFunction(std::string("binary") + Op); + assert(F && "binary operator not found!"); + + Value *Ops[] = {L, R}; + return Builder.CreateCall(F, Ops, "binop"); +} + +Value *CallExprAST::codegen() { + // Look up the name in the global module table. + Function *CalleeF = getFunction(Callee); + if (!CalleeF) + return LogErrorV("Unknown function referenced"); + + // If argument mismatch error. + if (CalleeF->arg_size() != Args.size()) + return LogErrorV("Incorrect # arguments passed"); + + std::vector ArgsV; + for (unsigned i = 0, e = Args.size(); i != e; ++i) { + ArgsV.push_back(Args[i]->codegen()); + if (!ArgsV.back()) + return nullptr; + } + + return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); +} + +Value *IfExprAST::codegen() { + Value *CondV = Cond->codegen(); + if (!CondV) + return nullptr; + + // Convert condition to a bool by comparing equal to 0.0. + CondV = Builder.CreateFCmpONE( + CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond"); + + Function *TheFunction = Builder.GetInsertBlock()->getParent(); + + // Create blocks for the then and else cases. Insert the 'then' block at the + // end of the function. + BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction); + BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else"); + BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont"); + + Builder.CreateCondBr(CondV, ThenBB, ElseBB); + + // Emit then value. + Builder.SetInsertPoint(ThenBB); + + Value *ThenV = Then->codegen(); + if (!ThenV) + return nullptr; + + Builder.CreateBr(MergeBB); + // Codegen of 'Then' can change the current block, update ThenBB for the PHI. + ThenBB = Builder.GetInsertBlock(); + + // Emit else block. + TheFunction->getBasicBlockList().push_back(ElseBB); + Builder.SetInsertPoint(ElseBB); + + Value *ElseV = Else->codegen(); + if (!ElseV) + return nullptr; + + Builder.CreateBr(MergeBB); + // Codegen of 'Else' can change the current block, update ElseBB for the PHI. + ElseBB = Builder.GetInsertBlock(); + + // Emit merge block. + TheFunction->getBasicBlockList().push_back(MergeBB); + Builder.SetInsertPoint(MergeBB); + PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp"); + + PN->addIncoming(ThenV, ThenBB); + PN->addIncoming(ElseV, ElseBB); + return PN; +} + +// Output for-loop as: +// var = alloca double +// ... +// start = startexpr +// store start -> var +// goto loop +// loop: +// ... +// bodyexpr +// ... +// loopend: +// step = stepexpr +// endcond = endexpr +// +// curvar = load var +// nextvar = curvar + step +// store nextvar -> var +// br endcond, loop, endloop +// outloop: +Value *ForExprAST::codegen() { + Function *TheFunction = Builder.GetInsertBlock()->getParent(); + + // Create an alloca for the variable in the entry block. + AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName); + + // Emit the start code first, without 'variable' in scope. + Value *StartVal = Start->codegen(); + if (!StartVal) + return nullptr; + + // Store the value into the alloca. + Builder.CreateStore(StartVal, Alloca); + + // Make the new basic block for the loop header, inserting after current + // block. + BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction); + + // Insert an explicit fall through from the current block to the LoopBB. + Builder.CreateBr(LoopBB); + + // Start insertion in LoopBB. + Builder.SetInsertPoint(LoopBB); + + // Within the loop, the variable is defined equal to the PHI node. If it + // shadows an existing variable, we have to restore it, so save it now. + AllocaInst *OldVal = NamedValues[VarName]; + NamedValues[VarName] = Alloca; + + // Emit the body of the loop. This, like any other expr, can change the + // current BB. Note that we ignore the value computed by the body, but don't + // allow an error. + if (!Body->codegen()) + return nullptr; + + // Emit the step value. + Value *StepVal = nullptr; + if (Step) { + StepVal = Step->codegen(); + if (!StepVal) + return nullptr; + } else { + // If not specified, use 1.0. + StepVal = ConstantFP::get(TheContext, APFloat(1.0)); + } + + // Compute the end condition. + Value *EndCond = End->codegen(); + if (!EndCond) + return nullptr; + + // Reload, increment, and restore the alloca. This handles the case where + // the body of the loop mutates the variable. + Value *CurVar = Builder.CreateLoad(Alloca, VarName.c_str()); + Value *NextVar = Builder.CreateFAdd(CurVar, StepVal, "nextvar"); + Builder.CreateStore(NextVar, Alloca); + + // Convert condition to a bool by comparing equal to 0.0. + EndCond = Builder.CreateFCmpONE( + EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond"); + + // Create the "after loop" block and insert it. + BasicBlock *AfterBB = + BasicBlock::Create(TheContext, "afterloop", TheFunction); + + // Insert the conditional branch into the end of LoopEndBB. + Builder.CreateCondBr(EndCond, LoopBB, AfterBB); + + // Any new code will be inserted in AfterBB. + Builder.SetInsertPoint(AfterBB); + + // Restore the unshadowed variable. + if (OldVal) + NamedValues[VarName] = OldVal; + else + NamedValues.erase(VarName); + + // for expr always returns 0.0. + return Constant::getNullValue(Type::getDoubleTy(TheContext)); +} + +Value *VarExprAST::codegen() { + std::vector OldBindings; + + Function *TheFunction = Builder.GetInsertBlock()->getParent(); + + // Register all variables and emit their initializer. + for (unsigned i = 0, e = VarNames.size(); i != e; ++i) { + const std::string &VarName = VarNames[i].first; + ExprAST *Init = VarNames[i].second.get(); + + // Emit the initializer before adding the variable to scope, this prevents + // the initializer from referencing the variable itself, and permits stuff + // like this: + // var a = 1 in + // var a = a in ... # refers to outer 'a'. + Value *InitVal; + if (Init) { + InitVal = Init->codegen(); + if (!InitVal) + return nullptr; + } else { // If not specified, use 0.0. + InitVal = ConstantFP::get(TheContext, APFloat(0.0)); + } + + AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName); + Builder.CreateStore(InitVal, Alloca); + + // Remember the old variable binding so that we can restore the binding when + // we unrecurse. + OldBindings.push_back(NamedValues[VarName]); + + // Remember this binding. + NamedValues[VarName] = Alloca; + } + + // Codegen the body, now that all vars are in scope. + Value *BodyVal = Body->codegen(); + if (!BodyVal) + return nullptr; + + // Pop all our variables from scope. + for (unsigned i = 0, e = VarNames.size(); i != e; ++i) + NamedValues[VarNames[i].first] = OldBindings[i]; + + // Return the body computation. + return BodyVal; +} + +Function *PrototypeAST::codegen() { + // Make the function type: double(double,double) etc. + std::vector Doubles(Args.size(), Type::getDoubleTy(TheContext)); + FunctionType *FT = + FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false); + + Function *F = + Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); + + // Set names for all arguments. + unsigned Idx = 0; + for (auto &Arg : F->args()) + Arg.setName(Args[Idx++]); + + return F; +} + +const PrototypeAST& FunctionAST::getProto() const { + return *Proto; +} + +const std::string& FunctionAST::getName() const { + return Proto->getName(); +} + +Function *FunctionAST::codegen() { + // Transfer ownership of the prototype to the FunctionProtos map, but keep a + // reference to it for use below. + auto &P = *Proto; + Function *TheFunction = getFunction(P.getName()); + if (!TheFunction) + return nullptr; + + // If this is an operator, install it. + if (P.isBinaryOp()) + BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence(); + + // Create a new basic block to start insertion into. + BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction); + Builder.SetInsertPoint(BB); + + // Record the function arguments in the NamedValues map. + NamedValues.clear(); + for (auto &Arg : TheFunction->args()) { + // Create an alloca for this variable. + AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName()); + + // Store the initial value into the alloca. + Builder.CreateStore(&Arg, Alloca); + + // Add arguments to variable symbol table. + NamedValues[Arg.getName()] = Alloca; + } + + if (Value *RetVal = Body->codegen()) { + // Finish off the function. + Builder.CreateRet(RetVal); + + // Validate the generated code, checking for consistency. + verifyFunction(*TheFunction); + + return TheFunction; + } + + // Error reading body, remove function. + TheFunction->eraseFromParent(); + + if (P.isBinaryOp()) + BinopPrecedence.erase(Proto->getOperatorName()); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// Top-Level parsing and JIT Driver +//===----------------------------------------------------------------------===// + +static void InitializeModule() { + // Open a new module. + TheModule = llvm::make_unique("my cool jit", TheContext); + TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout()); +} + +std::unique_ptr +irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix) { + if (auto *F = FnAST.codegen()) { + F->setName(F->getName() + Suffix); + auto M = std::move(TheModule); + // Start a new module. + InitializeModule(); + return M; + } else + report_fatal_error("Couldn't compile lazily JIT'd function"); +} + +static void HandleDefinition() { + if (auto FnAST = ParseDefinition()) { + FunctionProtos[FnAST->getProto().getName()] = + llvm::make_unique(FnAST->getProto()); + ExitOnErr(TheJIT->addFunctionAST(std::move(FnAST))); + } else { + // Skip token for error recovery. + getNextToken(); + } +} + +static void HandleExtern() { + if (auto ProtoAST = ParseExtern()) { + if (auto *FnIR = ProtoAST->codegen()) { + fprintf(stderr, "Read extern: "); + FnIR->dump(); + FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); + } + } else { + // Skip token for error recovery. + getNextToken(); + } +} + +static void HandleTopLevelExpression() { + // Evaluate a top-level expression into an anonymous function. + if (auto FnAST = ParseTopLevelExpr()) { + FunctionProtos[FnAST->getName()] = + llvm::make_unique(FnAST->getProto()); + if (FnAST->codegen()) { + // JIT the module containing the anonymous expression, keeping a handle so + // we can free it later. + auto H = TheJIT->addModule(std::move(TheModule)); + InitializeModule(); + + // Search the JIT for the __anon_expr symbol. + auto ExprSymbol = TheJIT->findSymbol("__anon_expr"); + assert(ExprSymbol && "Function not found"); + + // Get the symbol's address and cast it to the right type (takes no + // arguments, returns a double) so we can call it as a native function. + ExitOnErr(TheJIT->executeRemoteExpr(ExprSymbol.getAddress())); + + // Delete the anonymous expression module from the JIT. + TheJIT->removeModule(H); + } + } else { + // Skip token for error recovery. + getNextToken(); + } +} + +/// top ::= definition | external | expression | ';' +static void MainLoop() { + while (true) { + fprintf(stderr, "ready> "); + switch (CurTok) { + case tok_eof: + return; + case ';': // ignore top-level semicolons. + getNextToken(); + break; + case tok_def: + HandleDefinition(); + break; + case tok_extern: + HandleExtern(); + break; + default: + HandleTopLevelExpression(); + break; + } + } +} + +//===----------------------------------------------------------------------===// +// "Library" functions that can be "extern'd" from user code. +//===----------------------------------------------------------------------===// + +/// putchard - putchar that takes a double and returns 0. +extern "C" double putchard(double X) { + fputc((char)X, stderr); + return 0; +} + +/// printd - printf that takes a double prints it as "%f\n", returning 0. +extern "C" double printd(double X) { + fprintf(stderr, "%f\n", X); + return 0; +} + +//===----------------------------------------------------------------------===// +// TCP / Connection setup code. +//===----------------------------------------------------------------------===// + +std::unique_ptr connect() { + int sockfd = socket(PF_INET, SOCK_STREAM, 0); + hostent *server = gethostbyname(HostName.c_str()); + + if (!server) { + errs() << "Could not find host " << HostName << "\n"; + exit(1); + } + + sockaddr_in servAddr; + bzero(&servAddr, sizeof(servAddr)); + servAddr.sin_family = PF_INET; + bcopy(server->h_addr, &servAddr.sin_addr.s_addr, server->h_length); + servAddr.sin_port = htons(Port); + if (connect(sockfd, reinterpret_cast(&servAddr), + sizeof(servAddr)) < 0) { + errs() << "Failure to connect.\n"; + exit(1); + } + + return llvm::make_unique(sockfd, sockfd); +} + +//===----------------------------------------------------------------------===// +// Main driver code. +//===----------------------------------------------------------------------===// + +int main(int argc, char *argv[]) { + // Parse the command line options. + cl::ParseCommandLineOptions(argc, argv, "Building A JIT - Client.\n"); + + InitializeNativeTarget(); + InitializeNativeTargetAsmPrinter(); + InitializeNativeTargetAsmParser(); + + ExitOnErr.setBanner("Kaleidoscope: "); + + // Install standard binary operators. + // 1 is lowest precedence. + BinopPrecedence['='] = 2; + BinopPrecedence['<'] = 10; + BinopPrecedence['+'] = 20; + BinopPrecedence['-'] = 20; + BinopPrecedence['*'] = 40; // highest. + + auto TCPChannel = connect(); + MyRemote Remote = ExitOnErr(MyRemote::Create(*TCPChannel)); + TheJIT = llvm::make_unique(Remote); + + // Automatically inject a definition for 'printExprResult'. + FunctionProtos["printExprResult"] = + llvm::make_unique("printExprResult", + std::vector({"Val"})); + + // Prime the first token. + fprintf(stderr, "ready> "); + getNextToken(); + + InitializeModule(); + + // Run the main "interpreter loop" now. + MainLoop(); + + // Delete the JIT before the Remote and Channel go out of scope, otherwise + // we'll crash in the JIT destructor when it tries to release remote + // resources over a channel that no longer exists. + TheJIT = nullptr; + + // Send a terminate message to the remote to tell it to exit cleanly. + ExitOnErr(Remote.terminateSession()); + + return 0; +}