Skip to content

Commit

Permalink
add base layout interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Menooker committed Jan 9, 2024
1 parent 9bc7db8 commit 8a21ee9
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 49 deletions.
19 changes: 16 additions & 3 deletions KunQuant/Driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class _Partition:
num_in_dep = 0
is_rank = False

def compileit(f: Function, module_name: str, input_stride: int, output_stride: int, partition_factor = 4):
def compileit(f: Function, module_name: str, partition_factor = 4, output_layout = "ST8s"):
if output_layout not in ["ST8s", "FTS"]:
raise RuntimeError("Bad output_layout name " + output_layout)
input_name_to_idx: Dict[str, int] = dict()
buffer_names: List[_Buffer] = []
partitions: typing.OrderedDict[str, _Partition] = OrderedDict()
Expand All @@ -58,6 +60,14 @@ def insert_name(op: OpBase, kind: str) -> _Buffer:
return newbuf
return buffer_names[input_name_to_idx[name]]

def set_buffer_layout(op: OpBase, buf: _Buffer):
    """Tag ``op`` with the memory layout of the buffer it is bound to.

    TEMP and INPUT buffers are always tagged "ST8s". OUTPUT buffers are
    tagged with ``output_layout`` taken from the enclosing scope. Any other
    buffer kind leaves ``op.attrs`` untouched.
    """
    kind = buf.kind
    if kind in ("TEMP", "INPUT"):
        op.attrs["layout"] = "ST8s"
        return
    if kind == "OUTPUT":
        op.attrs["layout"] = output_layout

for op in f.ops:
if isinstance(op, Input):
insert_name(op, "INPUT")
Expand Down Expand Up @@ -85,11 +95,13 @@ def insert_name(op: OpBase, kind: str) -> _Buffer:
buf = insert_name(op, "TEMP")
pins.append(buf)
ins.append(op)
set_buffer_layout(op, buf)
elif isinstance(op, Output):
buf = insert_name(op, "TEMP")
pouts.append(buf)
outs.append(op)
src = codegen_cpp(func, input_stride, output_stride, input_name_to_idx, ins, outs)
set_buffer_layout(op, buf)
src = codegen_cpp(func, input_name_to_idx, ins, outs)
impl_src.append(src)
newparti = _Partition(func.name, len(partitions), pins, pouts)
if len(func.ops) == 3 and isinstance(func.ops[1], Rank):
Expand Down Expand Up @@ -143,6 +155,7 @@ def insert_name(op: OpBase, kind: str) -> _Buffer:
{len(partitions)},
__stages,
{len(buffer_names)},
__buffers
__buffers,
OutputLayout::{output_layout}
}};''')
return "\n\n".join(impl_src)
18 changes: 9 additions & 9 deletions KunQuant/passes/CodegenCpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,26 @@ def _get_buffer_name(op: OpBase, idx: int) -> str:

vector_len = 8

def codegen_cpp(f: Function, input_stride: int, output_stride: int, input_name_to_idx: Dict[str, int], inputs: List[Input], outputs: List[Output]) -> str:
def codegen_cpp(f: Function, input_name_to_idx: Dict[str, int], inputs: List[Input], outputs: List[Output]) -> str:
if len(f.ops) == 3 and isinstance(f.ops[1], Rank):
return f'''static auto stage_{f.name} = RankST8sTimeStride8;'''
return f'''static auto stage_{f.name} = RankStocks{f.ops[0].attrs["layout"]}_{f.ops[2].attrs["layout"]};'''
header = f'''void stage_{f.name}(Context* __ctx, size_t __stock_idx, size_t __total_time, size_t __start, size_t __length) '''
toplevel = _CppScope(None)
buffer_type: Dict[OpBase, str] = dict()
# currently only support ST8s format
assert(input_stride == vector_len)
for inp in inputs:
name = inp.attrs["name"]
layout = inp.attrs["layout"]
idx_in_ctx = input_name_to_idx[name]
buffer_type[inp] = f"Input<{input_stride}>"
code = f"Input<{input_stride}> buf_{name}{{__ctx->buffers[{idx_in_ctx}].ptr + __stock_idx * __total_time * {input_stride} + __start * {input_stride} }};"
buffer_type[inp] = f"Input{layout}"
code = f"Input{layout} buf_{name}{{__ctx->buffers[{idx_in_ctx}].ptr, __stock_idx, __total_time, __start}};"
toplevel.scope.append(_CppSingleLine(toplevel, code))
assert(output_stride == vector_len)

for idx, outp in enumerate(outputs):
name = outp.attrs["name"]
# Read the layout from this output op; `inp` is only the leftover variable of the inputs loop.
layout = outp.attrs["layout"]
idx_in_ctx = input_name_to_idx[name]
buffer_type[outp] = f"Output<{output_stride}>"
code = f"Output<{output_stride}> buf_{name}{{__ctx->buffers[{idx_in_ctx}].ptr + __stock_idx * __length * {input_stride}}};"
buffer_type[outp] = f"Output{layout}"
code = f"Output{layout} buf_{name}{{__ctx->buffers[{idx_in_ctx}].ptr, __stock_idx, __length, __start}};"
toplevel.scope.append(_CppSingleLine(toplevel, code))
for op in f.ops:
if op.get_parent() is None and isinstance(op, WindowedTempOutput):
Expand Down
2 changes: 1 addition & 1 deletion cpp/Kun/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct Context {

KUN_API std::shared_ptr<Executor> createSingleThreadExecutor();
namespace ops {
KUN_API void RankST8sTimeStride8(RuntimeStage *stage, size_t __stock_idx,
KUN_API void RankStocksST8s_ST8s(RuntimeStage *stage, size_t __stock_idx,
size_t __total_time, size_t __start, size_t __length);
}
} // namespace kun
22 changes: 14 additions & 8 deletions cpp/Kun/Module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@

namespace kun {

// Memory layout tag stored in Module::layout. The enumerator names match the
// `output_layout` strings accepted by the Python driver ("ST8s", "FTS"), which
// writes `OutputLayout::<name>` into the generated Module initializer.
enum class OutputLayout {
// Stocks grouped into SIMD-width blocks, indexed as [stock block][time][lane]
// (see the index computation in MapperST8s).
ST8s,
// NOTE(review): the element ordering for FTS is not shown in this change;
// confirm against the runtime before relying on it.
FTS,
};

struct Module {
size_t num_stages;
Stage* stages;
Stage *stages;
size_t num_buffers;
BufferInfo* buffers;
BufferInfo *buffers;
OutputLayout layout;
};

struct Library {
void* handle;
const Module* getModule(const char* name);
static std::shared_ptr<Library> load(const char* filename);
Library(const Library&) = delete;
Library(void* handle): handle{handle} {}
void *handle;
const Module *getModule(const char *name);
static std::shared_ptr<Library> load(const char *filename);
Library(const Library &) = delete;
Library(void *handle) : handle{handle} {}
~Library();
};

}
} // namespace kun
14 changes: 8 additions & 6 deletions cpp/Kun/Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ inline f32x8 Select(f32x8 cond, f32x8 vtrue, f32x8 vfalse) {
return _mm256_blendv_ps(vfalse, vtrue, cond);
}

template <int stride>
struct Input : DataSource<true> {
struct InputST8s : DataSource<true> {
constexpr static size_t stride = 8;
float *buf;
Input(float *buf) : buf{buf} {}
InputST8s(float *base, size_t stock_idx, size_t total_time, size_t start)
: buf{base + stock_idx * total_time * stride + start * stride} {}
f32x8 step(size_t index) { return _mm256_loadu_ps(&buf[index * stride]); }

f32x8 getWindow(size_t index, size_t offset) {
Expand All @@ -39,10 +40,11 @@ struct Input : DataSource<true> {
}
};

template <int stride>
struct Output : DataSource<true> {
struct OutputST8s : DataSource<true> {
constexpr static size_t stride = 8;
float *buf;
Output(float *buf) : buf{buf} {}
OutputST8s(float *base, size_t stock_idx, size_t length, size_t start)
: buf{base + stock_idx * length * stride} {}
void store(size_t index, const f32x8 &v) {
_mm256_storeu_ps(&buf[index * stride], v);
}
Expand Down
38 changes: 26 additions & 12 deletions cpp/Kun/Rank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,55 @@
namespace kun {
namespace ops {

void RankST8sTimeStride8(RuntimeStage *stage, size_t __stock_idx,
size_t __total_time, size_t __start, size_t __length) {
// Index mapper for the ST8s layout: converts a (stock, time) coordinate into a
// flat element offset. Stocks are grouped into blocks of `simd_len`; within a
// block, elements are ordered by time step and then by lane.
struct MapperST8s {
    static size_t call(size_t stockid, size_t t, size_t num_time) {
        const auto block = stockid / simd_len; // which group of simd_len stocks
        const auto lane = stockid % simd_len;  // position inside that group
        return block * num_time * simd_len + t * simd_len + lane;
    }
};

template <typename INPUT, typename OUTPUT>
static void RankStocks(RuntimeStage *stage, size_t time_idx,
size_t __total_time, size_t __start, size_t __length) {
auto num_stocks = stage->ctx->stock_count;
auto num_time = stage->ctx->total_time;
float* input = stage->ctx->buffers[stage->stage->in_buffers[0]->id].ptr;
float* output = stage->ctx->buffers[stage->stage->out_buffers[0]->id].ptr;
float *input = stage->ctx->buffers[stage->stage->in_buffers[0]->id].ptr;
float *output = stage->ctx->buffers[stage->stage->out_buffers[0]->id].ptr;
auto time_end =
std::min(__start + (__stock_idx + 1) * time_stride, __start + __length);
std::min(__start + (time_idx + 1) * time_stride, __start + __length);
std::vector<float> data;
data.reserve(num_stocks);
for (size_t t = __start + __stock_idx * time_stride; t < time_end; t++) {
for (size_t t = __start + time_idx * time_stride; t < time_end; t++) {
for (size_t i = 0; i < num_stocks; i++) {
auto S = i / simd_len;
float in = input[S * num_time * simd_len + t * simd_len + i % simd_len];
if(!std::isnan(in)) {
float in = input[INPUT::call(i, t, num_time)];
if (!std::isnan(in)) {
data.push_back(in);
}
}
std::sort(data.begin(), data.end());
for (size_t i = 0; i < num_stocks; i++) {
auto S = i / simd_len;
float in = input[S * num_time * simd_len + t * simd_len + i % simd_len];
float in = input[INPUT::call(i, t, num_time)];
float out;
if(!std::isnan(in)) {
if (!std::isnan(in)) {
auto pos = std::equal_range(data.begin(), data.end(), in);
auto start = pos.first - data.begin();
auto end = pos.second - data.begin();
out = ((start+end-1)/2.0f + 1.0f)/data.size();
out = ((start + end - 1) / 2.0f + 1.0f) / data.size();
} else {
out = NAN;
}
output[S * num_time * simd_len + t * simd_len + i % simd_len] = out;
output[OUTPUT::call(i, t, num_time)] = out;
}
data.clear();
}
}

// Exported rank stage for ST8s input and ST8s output. It forwards every
// argument to RankStocks, instantiated with MapperST8s for both the input and
// the output index mapping. The name follows the pattern the Python code
// generator emits: RankStocks<input layout>_<output layout>.
// NOTE(review): the second parameter is named `time_idx` here, while the
// declaration in Context.hpp calls it `__stock_idx`.
void RankStocksST8s_ST8s(RuntimeStage *stage, size_t time_idx,
size_t __total_time, size_t __start, size_t __length) {
RankStocks<MapperST8s, MapperST8s>(stage, time_idx, __total_time,
__start, __length);
}
} // namespace ops
} // namespace kun
3 changes: 2 additions & 1 deletion tests/cpp/TestRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,6 @@ KUN_API Module testRuntimeModule{
arraySize(stages),
stages,
arraySize(buffers),
buffers
buffers,
OutputLayout::ST8s
};
18 changes: 9 additions & 9 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def build_avg_and_stddev():

def check_1():
f = build_avg_and_stddev()
src = compileit(f, "avg_and_stddev", 8, 8)
src = compileit(f, "avg_and_stddev")
with open("./tests/cpp/generated/AvgAndStddev.cpp", 'w') as f:
f.write(src)

Expand All @@ -28,7 +28,7 @@ def check_rank():
inp1 = Input("a")
out2 = Output(Rank(inp1), "ou2")
f = Function(builder.ops)
src = compileit(f, "test_rank", 8, 8)
src = compileit(f, "test_rank")
with open("./tests/cpp/generated/TestRank.cpp", 'w') as f:
f.write(src)

Expand All @@ -41,7 +41,7 @@ def check_rank2():
v3 = Add(v2, v1)
Output(v3, "out")
f = Function(builder.ops)
src = compileit(f, "test_rank2", 8, 8)
src = compileit(f, "test_rank2")
with open("./tests/cpp/generated/TestRank2.cpp", 'w') as f:
f.write(src)

Expand All @@ -51,7 +51,7 @@ def check_log():
inp1 = Input("a")
Output(Log(inp1), "outlog")
f = Function(builder.ops)
src = compileit(f, "test_log", 8, 8)
src = compileit(f, "test_log")
with open("./tests/cpp/generated/TestLog.cpp", 'w') as f:
f.write(src)

Expand All @@ -62,12 +62,12 @@ def check_alpha101():
for f in all_alpha:
f(all_data)
f = Function(builder.ops)
src = compileit(f, "alpha_101", 8, 8, 4)
src = compileit(f, "alpha_101", output_layout="ST8s")
with open("./tests/cpp/generated/Alpha101.cpp", 'w') as f:
f.write(src)

#check_1()
#check_rank()
#check_rank2()
#check_log()
check_1()
check_rank()
check_rank2()
check_log()
check_alpha101()

0 comments on commit 8a21ee9

Please sign in to comment.