Skip to content

Commit

Permalink
fix TsArgMax TsArgMin and reverse the back window loop order (#20)
Browse files Browse the repository at this point in the history
keep same semantics of `pd.argmax`. Also let back window loop from oldest to newest.
  • Loading branch information
Menooker authored Jul 16, 2024
1 parent f4076dd commit b384b89
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 7 deletions.
2 changes: 1 addition & 1 deletion KunQuant/passes/CodegenCpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def codegen_cpp(f: Function, input_name_to_idx: Dict[str, int], inputs: List[Tup
scope.scope.append(_CppSingleLine(scope, f"auto v{idx} = {thename}(v{inp[0]});"))
elif isinstance(op, ForeachBackWindow):
window = op.attrs["window"]
the_for = _CppFor(scope, f"for(size_t idx_{idx} = 0;idx_{idx} < {window};idx_{idx}++) ")
the_for = _CppFor(scope, f"for(int idx_{idx} = {window - 1};idx_{idx} >= 0;idx_{idx}--) ")
scope.scope.append(the_for)
loop_to_cpp_loop[op] = the_for.body
elif isinstance(op, IterValue):
Expand Down
14 changes: 9 additions & 5 deletions cpp/Kun/Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,19 @@ template <typename T, int stride>
struct ReduceMin {
using simd_t = kun_simd::vec<T, stride>;
simd_t v = std::numeric_limits<T>::infinity();
void step(simd_t input, size_t index) { v = sc_min(v, input); }
void step(simd_t input, size_t index) {
v = sc_select(sc_isnan(v, input), NAN, sc_min(v, input));
}
operator simd_t() { return v; }
};

template <typename T, int stride>
struct ReduceMax {
using simd_t = kun_simd::vec<T, stride>;
simd_t v = -std::numeric_limits<T>::infinity();
void step(simd_t input, size_t index) { v = sc_max(v, input); }
void step(simd_t input, size_t index) {
v = sc_select(sc_isnan(v, input), NAN, sc_max(v, input));
}
operator simd_t() { return v; }
};

Expand All @@ -418,11 +422,11 @@ struct ReduceDecayLinear {
static constexpr T stepSize() {
return 1.0 / ((1.0 + window) * window / 2);
}
simd_t weight = window * stepSize();
simd_t weight = stepSize();
simd_t v = 0;
void step(simd_t input, size_t index) {
v = sc_fmadd(input, weight, v);
weight = weight - stepSize();
weight = weight + stepSize();
}
operator simd_t() { return v; }
};
Expand Down Expand Up @@ -494,7 +498,7 @@ struct ReduceArgMin {
simd_t idx = 0;
void step(simd_t input, size_t index) {
auto is_nan = sc_isnan(v, input);
auto cmp = GreaterThan(v, input);
auto cmp = v > input;
v = sc_select(cmp, input, v);
v = sc_select(is_nan, NAN, v);
idx = sc_select(cmp, T(index), idx);
Expand Down
14 changes: 14 additions & 0 deletions projects/Test/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ def check_ema():
with open(paths[-1], 'w') as f:
f.write(src)

def check_argmin():
builder = Builder()
with builder:
inp1 = Input("a")
out2 = Output(TsArgMin(inp1, 5), "ou2")
Output(WindowedMin(inp1, 5), "tsmin")
Output(TsRank(inp1, 5), "tsrank")
f = Function(builder.ops)
src = compileit(f, "test_argmin", input_layout="TS", output_layout="TS")
paths.append(sys.argv[1]+"/TestArgMin.cpp")
with open(paths[-1], 'w') as f:
f.write(src)

def check_rank():
builder = Builder()
with builder:
Expand Down Expand Up @@ -123,6 +136,7 @@ def check_alpha101_double():
check_pow()
check_alpha101_double()
check_ema()
check_argmin()

# with open(sys.argv[1]+"/generated.txt", 'w') as f:
# f.write(";".join(paths))
Expand Down
3 changes: 2 additions & 1 deletion projects/Test/list.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ TestLog.cpp
TestLog64.cpp
TestPow.cpp
Alpha101.cpp
TestEMA.cpp
TestEMA.cpp
TestArgMin.cpp
1 change: 1 addition & 0 deletions tests/test_alpha158.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,4 @@ def test(inputs: Dict[str, np.ndarray], ref: Dict[str, np.ndarray]):
args = parser.parse_args()
inp, ref = load(args.inputs, args.ref)
test(inp, ref)
print("done")
18 changes: 18 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,23 @@ def test_ema():
expected = pd.DataFrame(inp).ewm(span=5, adjust=False).mean()
np.testing.assert_allclose(output, expected, rtol=1e-6, equal_nan=True)

def test_argmin_issue19():
#https://github.com/Menooker/KunQuant/issues/19
modu = lib.getModule("test_argmin")
assert(modu)
inp = np.empty((6, 8),"float32")
data = [ 0.6898481863442985, 0.6992020600574415, 0.6992020600574417, 0.6968635916291558, 0.6968635916291558, 0.6968635916291558 ]
for i in range(6):
inp[i, :] = data[i]
executor = kr.createSingleThreadExecutor()
out = kr.runGraph(executor, modu, {"a": inp}, 0, 6)
df = pd.DataFrame(inp)
expected =df.rolling(5, min_periods=1).apply(lambda x: x.argmin() + 1, raw=True)
output = out["ou2"][4:]
np.testing.assert_allclose(output, expected[4:], rtol=1e-6, equal_nan=True)
np.testing.assert_allclose(out["tsmin"], df.rolling(5).min(), rtol=1e-6, equal_nan=True)
np.testing.assert_allclose(out["tsrank"], df.rolling(5).rank(), rtol=1e-6, equal_nan=True)

test_avg_stddev_TS()
test_runtime()
test_avg_stddev()
Expand All @@ -168,4 +185,5 @@ def test_ema():
test_log("float64", "64")
test_pow()
test_ema()
test_argmin_issue19()
print("done")

0 comments on commit b384b89

Please sign in to comment.