Skip to content

Commit

Permalink
add cpp-fake-jit (#27)
Browse files Browse the repository at this point in the history
Users can now compile & run factors without cmake
  • Loading branch information
Menooker authored Nov 25, 2024
1 parent 87d4e08 commit a5c56ba
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 0 deletions.
Empty file added KunQuant/jit/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions KunQuant/jit/cfake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import subprocess
import tempfile
from typing import List, Tuple
import KunRunner
from KunQuant.Driver import compileit as driver_compileit
from KunQuant.Stage import Function
from KunQuant.passes import Util
import timeit

_cpp_root = os.path.join(os.path.dirname(__file__), "..", "..", "cpp")
_include_path = [_cpp_root]

def call_cpp_compiler(source: str, module_name: str, compiler: str, options: List[str], tempdir: str) -> str:
inpath = os.path.join(tempdir, f"{module_name}.cpp")
with open(inpath, 'w') as f:
f.write(source)
outpath = os.path.join(tempdir, f"{module_name}.so")
if Util.jit_debug_mode:
print("[KUN_JIT] temp jit files:", inpath, outpath)
cmd = [compiler] + options + [inpath, "-o", outpath]
subprocess.check_call(cmd, shell=False)
return outpath

class _fake_temp:
def __init__(self, dir) -> None:
self.dir = dir

def __enter__(self):
return self.dir

def __exit__(self, exception_type, exception_value, exception_traceback):
pass

def compile_cpp_and_load(source: str, module_name: str, tempdir: str, compiler: str, options: List[str], keep_files: bool) -> KunRunner.Library:
tempclass = _fake_temp if keep_files else tempfile.TemporaryDirectory
with tempclass(dir=tempdir) as tmpdirname:
outdir = call_cpp_compiler(source, module_name, compiler, options, tmpdirname)
lib = KunRunner.Library.load(outdir)
return lib

def compileit(f: Function, module_name: str, compiler: str = "g++", tempdir: str = None, keep_files: bool = False, **kwargs) -> Tuple[KunRunner.Library, KunRunner.Module]:
lib = None
src = None
if keep_files and not tempdir:
raise RuntimeError("if keep_files=True, tempdir should not be empty")
def kuncompile():
nonlocal src
src = driver_compileit(f, module_name, **kwargs)
def dowork():
nonlocal lib
lib = compile_cpp_and_load(src, module_name, tempdir, compiler, ["-std=c++11", "-O2", "-shared", "-fPIC", "-march=native"] + [f"-I{v}" for v in _include_path], keep_files)
if Util.jit_debug_mode:
print("[KUN_JIT] Source generation takes ", timeit.timeit(kuncompile, number=1), "s")
print("[KUN_JIT] C++ compiler takes ", timeit.timeit(dowork, number=1), "s")
else:
kuncompile()
dowork()
return lib, lib.getModule(module_name)
1 change: 1 addition & 0 deletions KunQuant/passes/Util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def _safe_cast(val):


debug_mode = _safe_cast(os.environ.get("KUN_DEBUG", "0"))
jit_debug_mode = _safe_cast(os.environ.get("KUN_DEBUG_JIT", ""))


def kun_pass(p):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@
import sys
import warnings
import os
from KunQuant.jit import cfake
from KunQuant.Op import Input, Output, Builder
from KunQuant.Stage import Function

base_dir = "./build/Release/projects" if os.name == "nt" else "./build/projects"
base_dir2 = "./build/Release/" if os.name == "nt" else "./build/"
sys.path.append(base_dir2)
import KunRunner as kr


def test_cfake():
builder = Builder()
with builder:
inp1 = Input("a")
inp2 = Input("b")
Output(inp1 * inp2 + 10, "out")
f = Function(builder.ops)
lib, mod = cfake.compileit(f, "test1", input_layout="TS", output_layout="TS")

inp = np.random.rand(10, 24).astype("float32")
inp2 = np.random.rand(10, 24).astype("float32")
executor = kr.createSingleThreadExecutor()
out = kr.runGraph(executor, mod, {"a": inp, "b": inp2}, 0, 10)
np.testing.assert_allclose(inp * inp2 + 10, out["out"])

# inp = np.ndarray((3, 100, 8), dtype="float32")

lib = kr.Library.load(base_dir+"/Test.dll" if os.name == "nt" else base_dir+"/libTest.so")
Expand Down Expand Up @@ -244,4 +262,5 @@ def check(inp, timelen):
test_argmin_issue19()
test_aligned()
test_rank029()
test_cfake()
print("done")

0 comments on commit a5c56ba

Please sign in to comment.