Skip to content

Commit

Permalink
多元回归整理
Browse files Browse the repository at this point in the history
  • Loading branch information
wukan1986 committed Feb 26, 2024
1 parent c4a8ac7 commit 6c8461f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion polars_ta/wq/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def round_down(x: Expr, f: int = 1) -> Expr:


def s_log_1p(x: Expr) -> Expr:
    """Return the signed log1p transform of *x*: ``sign(x) * log(1 + |x|)``.

    Args:
        x: Input expression.

    Returns:
        An expression that applies ``log1p`` to the absolute value of *x*
        and multiplies the result by the sign of *x*.
    """
    # log1p(|x|) is used instead of log(|x| + 1): the two are mathematically
    # equal, but log1p does not lose precision when |x| is very small.
    return x.abs().log1p() * x.sign()


def sign(x: Expr) -> Expr:
Expand Down
12 changes: 9 additions & 3 deletions polars_ta/wq/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from polars import Expr, Series
from polars import Expr, Series, map_batches


def cs_standardize_zscore(x: Expr, ddof: int = 0) -> Expr:
Expand Down Expand Up @@ -42,6 +42,7 @@ def cs_neutralize_demean(x: Expr) -> Expr:


def cs_neutralize_residual_simple(y: Expr, x: Expr) -> Expr:
"""一元回归"""
# https://stackoverflow.com/a/74906705/1894479
# 一元回归时,这个版本更快,不需再补充常量1
# e_i = y_i - a - bx_i
Expand All @@ -56,7 +57,7 @@ def cs_neutralize_residual_simple(y: Expr, x: Expr) -> Expr:
return y_demeaned - beta * x_demeaned


def cs_neutralize_residual_multiple(cols) -> Series:
def residual_multiple(cols) -> Series:
# https://stackoverflow.com/a/74906705/1894479
# 比struct.unnest要快一些
cols = [c.to_numpy() for c in cols]
Expand All @@ -66,4 +67,9 @@ def cs_neutralize_residual_multiple(cols) -> Series:
coef = np.linalg.lstsq(A, y, rcond=None)[0]
y_hat = np.sum(A * coef, axis=1)
residual = y - y_hat
return Series(residual)
return Series(residual, nan_to_null=True)


def cs_neutralize_residual_multiple(y: Expr, x: Expr, *more_x: Expr) -> Expr:
    """Multiple regression.

    Builds an expression that hands ``y``, ``x`` and any additional
    expressions in ``more_x`` (in that order) to ``residual_multiple`` as a
    single batch via ``map_batches``.

    Args:
        y: First expression of the batch.
            NOTE(review): presumably the dependent variable; confirm against
            ``residual_multiple``.
        x: Second expression of the batch.
        *more_x: Any further expressions appended after ``x``.

    Returns:
        The expression produced by ``map_batches``; when evaluated it yields
        the Series returned by ``residual_multiple``.
    """
    batch_inputs = [y, x, *more_x]
    return map_batches(batch_inputs, residual_multiple)

0 comments on commit 6c8461f

Please sign in to comment.