Skip to content

Commit

Permalink
make factor arguments meaning more clear
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Jul 26, 2021
1 parent fc58e68 commit eecde5e
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 56 deletions.
62 changes: 40 additions & 22 deletions zvt/contract/factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,29 +207,47 @@ class Factor(DataReader, DataListener):
transformer: Transformer = None
accumulator: Accumulator = None

def __init__(self, data_schema: Type[Mixin], entity_schema: Type[TradableEntity] = None, provider: str = None,
entity_provider: str = None, entity_ids: List[str] = None, exchanges: List[str] = None,
codes: List[str] = None, start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None, columns: List = None, filters: List = None,
order: object = None, limit: int = None, level: Union[str, IntervalLevel] = None,
category_field: str = 'entity_id', time_field: str = 'timestamp', computing_window: int = None,
keep_all_timestamp: bool = False, fill_method: str = 'ffill', effective_number: int = None,
transformer: Transformer = None, accumulator: Accumulator = None, need_persist: bool = False,
dry_run: bool = False, factor_name: str = None, clear_state: bool = False,
not_load_data: bool = False) -> None:
def __init__(self,
data_schema: Type[Mixin],
entity_schema: Type[TradableEntity] = None,
provider: str = None,
entity_provider: str = None,
entity_ids: List[str] = None,
exchanges: List[str] = None,
codes: List[str] = None,
start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None,
columns: List = None,
filters: List = None,
order: object = None,
limit: int = None,
level: Union[str, IntervalLevel] = None,
category_field: str = 'entity_id',
time_field: str = 'timestamp',
computing_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = 'ffill',
effective_number: int = None,
transformer: Transformer = None,
accumulator: Accumulator = None,
need_persist: bool = False,
only_compute_factor: bool = False,
factor_name: str = None,
clear_state: bool = False,
only_load_factor: bool = False) -> None:
"""
:param computing_window: the window size for computing factor
:param keep_all_timestamp: whether fill all timestamp gap,default False
:param keep_all_timestamp:
:param fill_method:
:param effective_number:
:param transformer:
:param accumulator:
:param need_persist: whether persist factor
:param dry_run: True for just computing factor, False for backtesting
:param only_compute_factor: only compute factor nor result
:param factor_name:
:param clear_state:
:param only_load_factor: only load factor and compute result
"""

self.not_load_data = not_load_data
self.only_load_factor = only_load_factor

super().__init__(data_schema, entity_schema, provider, entity_provider, entity_ids, exchanges, codes,
start_timestamp, end_timestamp, columns, filters, order, limit, level,
Expand Down Expand Up @@ -259,7 +277,7 @@ def __init__(self, data_schema: Type[Mixin], entity_schema: Type[TradableEntity]
self.accumulator = self.__class__.accumulator

self.need_persist = need_persist
self.dry_run = dry_run
self.dry_run = only_compute_factor

# 中间结果,不持久化
# data_df->pipe_df
Expand Down Expand Up @@ -305,11 +323,11 @@ def __init__(self, data_schema: Type[Mixin], entity_schema: Type[TradableEntity]

# the compute logic is not triggered from load data
# for the case:1)load factor from db 2)compute the result
if self.not_load_data:
if self.only_load_factor:
self.compute()

def load_data(self):
if self.not_load_data:
if self.only_load_factor:
return
super().load_data()

Expand Down Expand Up @@ -372,7 +390,7 @@ def factor_encoder(self):
return None

def pre_compute(self):
if not self.not_load_data and not pd_is_not_null(self.pipe_df):
if not self.only_load_factor and not pd_is_not_null(self.pipe_df):
self.pipe_df = self.data_df

def do_compute(self):
Expand All @@ -385,7 +403,7 @@ def do_compute(self):
self.logger.info('compute result finish')

def compute_factor(self):
if self.not_load_data:
if self.only_load_factor:
return
# 无状态的转换运算
if pd_is_not_null(self.data_df) and self.transformer:
Expand All @@ -403,7 +421,7 @@ def compute_result(self):
pass

def after_compute(self):
if self.not_load_data:
if self.only_load_factor:
return
if self.keep_all_timestamp:
self.fill_gap()
Expand Down
51 changes: 31 additions & 20 deletions zvt/factors/fundamental/finance_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,27 @@ def __init__(self, data_schema: Type[Mixin] = FinanceFactor, entity_schema: Type
category_field: str = 'entity_id', time_field: str = 'timestamp', computing_window: int = None,
keep_all_timestamp: bool = False, fill_method: str = 'ffill', effective_number: int = None,
transformer: Transformer = None, accumulator: Accumulator = None, need_persist: bool = False,
dry_run: bool = False, factor_name: str = None, clear_state: bool = False,
not_load_data: bool = False) -> None:
only_compute_factor: bool = False, factor_name: str = None, clear_state: bool = False,
only_load_factor: bool = False) -> None:
if not columns:
columns = data_schema.important_cols()
super().__init__(data_schema, entity_schema, provider, entity_provider, entity_ids, exchanges, codes,
start_timestamp, end_timestamp, columns, filters, order, limit, level, category_field,
time_field, computing_window, keep_all_timestamp, fill_method, effective_number, transformer,
accumulator, need_persist, dry_run, factor_name, clear_state, not_load_data)
accumulator, need_persist, only_compute_factor, factor_name, clear_state, only_load_factor)


class GoodCompanyFactor(FinanceBaseFactor, FilterFactor):
def __init__(self, data_schema: Type[Mixin] = FinanceFactor, entity_schema: TradableEntity = Stock,
def __init__(self,
data_schema: Type[Mixin] = FinanceFactor,
entity_schema: TradableEntity = Stock,
provider: str = None,
entity_provider: str = None, entity_ids: List[str] = None, exchanges: List[str] = None,
codes: List[str] = None, start_timestamp: Union[str, pd.Timestamp] = None,
entity_provider: str = None,
entity_ids: List[str] = None,
exchanges: List[str] = None,
codes: List[str] = None,
start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None,
# 高roe,高现金流,低财务杠杆,有增长
columns: List = (FinanceFactor.roe,
FinanceFactor.op_income_growth_yoy,
FinanceFactor.net_profit_growth_yoy,
Expand All @@ -52,13 +56,22 @@ def __init__(self, data_schema: Type[Mixin] = FinanceFactor, entity_schema: Trad
FinanceFactor.sales_net_cash_flow_per_op_income >= 0.3,
FinanceFactor.current_ratio >= 1,
FinanceFactor.debt_asset_ratio <= 0.5),
order: object = None, limit: int = None,
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = True,
fill_method: str = 'ffill', effective_number: int = None, transformer: Transformer = None,
accumulator: Accumulator = None, need_persist: bool = False, dry_run: bool = False,
factor_name: str = None, clear_state: bool = False, not_load_data: bool = False,
# 3 years
order: object = None,
limit: int = None,
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = 'entity_id',
time_field: str = 'timestamp',
computing_window: int = None,
keep_all_timestamp: bool = True,
fill_method: str = 'ffill',
effective_number: int = None,
transformer: Transformer = None,
accumulator: Accumulator = None,
need_persist: bool = False,
only_compute_factor: bool = False,
factor_name: str = None,
clear_state: bool = False,
only_load_factor: bool = False,
window='1095d',
count=8,
col_period_threshold={'roe': 0.02}) -> None:
Expand All @@ -76,7 +89,7 @@ def __init__(self, data_schema: Type[Mixin] = FinanceFactor, entity_schema: Trad
super().__init__(data_schema, entity_schema, provider, entity_provider, entity_ids, exchanges, codes,
start_timestamp, end_timestamp, columns, filters, order, limit, level, category_field,
time_field, computing_window, keep_all_timestamp, fill_method, effective_number, transformer,
accumulator, need_persist, dry_run, factor_name, clear_state, not_load_data)
accumulator, need_persist, only_compute_factor, factor_name, clear_state, only_load_factor)

def compute_factor(self):
def filter_df(df):
Expand Down Expand Up @@ -127,11 +140,9 @@ def compute_result(self):
# print(f1.result_df)

# 高股息 低应收
factor2 = GoodCompanyFactor(data_schema=BalanceSheet,
columns=[BalanceSheet.accounts_receivable],
filters=[
BalanceSheet.accounts_receivable <= 0.2 * BalanceSheet.total_current_assets],
col_period_threshold=None, keep_all_timestamp=False)
factor2 = GoodCompanyFactor(data_schema=BalanceSheet, columns=[BalanceSheet.accounts_receivable], filters=[
BalanceSheet.accounts_receivable <= 0.2 * BalanceSheet.total_current_assets], keep_all_timestamp=False,
col_period_threshold=None)
print(factor2.result_df)
# the __all__ is generated
__all__ = ['FinanceBaseFactor', 'GoodCompanyFactor']
10 changes: 5 additions & 5 deletions zvt/factors/ma/ma_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = False,
fill_method: str = 'ffill', effective_number: int = None, need_persist: bool = False,
dry_run: bool = False, factor_name: str = None, clear_state: bool = False, not_load_data: bool = False,
only_compute_factor: bool = False, factor_name: str = None, clear_state: bool = False, only_load_factor: bool = False,
adjust_type: Union[AdjustType, str] = None, windows=None) -> None:
if need_persist:
self.factor_schema = get_ma_factor_schema(entity_type=entity_schema.__name__, level=level)
Expand All @@ -45,7 +45,7 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
super().__init__(entity_schema, provider, entity_provider, entity_ids, exchanges, codes, start_timestamp,
end_timestamp, columns, filters, order, limit, level, category_field, time_field,
computing_window, keep_all_timestamp, fill_method, effective_number, transformer, None,
need_persist, dry_run, factor_name, clear_state, not_load_data, adjust_type)
need_persist, only_compute_factor, factor_name, clear_state, only_load_factor, adjust_type)

def drawer_factor_df_list(self) -> Optional[List[pd.DataFrame]]:
return [self.factor_df[self.transformer.indicators]]
Expand Down Expand Up @@ -74,8 +74,8 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = False,
fill_method: str = 'ffill', effective_number: int = None, accumulator: Accumulator = None,
need_persist: bool = False, dry_run: bool = False, factor_name: str = None, clear_state: bool = False,
not_load_data: bool = False, adjust_type: Union[AdjustType, str] = None, windows=None,
need_persist: bool = False, only_compute_factor: bool = False, factor_name: str = None, clear_state: bool = False,
only_load_factor: bool = False, adjust_type: Union[AdjustType, str] = None, windows=None,
vol_windows=None) -> None:
if not windows:
windows = [250]
Expand All @@ -93,7 +93,7 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
super().__init__(entity_schema, provider, entity_provider, entity_ids, exchanges, codes, start_timestamp,
end_timestamp, columns, filters, order, limit, level, category_field, time_field,
computing_window, keep_all_timestamp, fill_method, effective_number, transformer, accumulator,
need_persist, dry_run, factor_name, clear_state, not_load_data, adjust_type)
need_persist, only_compute_factor, factor_name, clear_state, only_load_factor, adjust_type)

def compute_result(self):
super().compute_result()
Expand Down
6 changes: 3 additions & 3 deletions zvt/factors/ma/top_bottom_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(self, entity_schema: TradableEntity = Stock, provider: str = None,
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = False,
fill_method: str = 'ffill', effective_number: int = None,
accumulator: Accumulator = None, need_persist: bool = False, dry_run: bool = False,
factor_name: str = None, clear_state: bool = False, not_load_data: bool = False,
accumulator: Accumulator = None, need_persist: bool = False, only_compute_factor: bool = False,
factor_name: str = None, clear_state: bool = False, only_load_factor: bool = False,
adjust_type: Union[AdjustType, str] = None, window=30) -> None:
self.adjust_type = adjust_type

Expand All @@ -50,7 +50,7 @@ def __init__(self, entity_schema: TradableEntity = Stock, provider: str = None,
super().__init__(entity_schema, provider, entity_provider, entity_ids, exchanges, codes, start_timestamp,
end_timestamp, columns, filters, order, limit, level, category_field, time_field,
computing_window, keep_all_timestamp, fill_method, effective_number, transformer, accumulator,
need_persist, dry_run, factor_name, clear_state, not_load_data, adjust_type)
need_persist, only_compute_factor, factor_name, clear_state, only_load_factor, adjust_type)


if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions zvt/factors/technical_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = False,
fill_method: str = 'ffill', effective_number: int = None, transformer: Transformer = None,
accumulator: Accumulator = None, need_persist: bool = False, dry_run: bool = False,
factor_name: str = None, clear_state: bool = False, not_load_data: bool = False,
accumulator: Accumulator = None, need_persist: bool = False, only_compute_factor: bool = False,
factor_name: str = None, clear_state: bool = False, only_load_factor: bool = False,
adjust_type: Union[AdjustType, str] = None) -> None:
if columns is None:
columns = ['id', 'entity_id', 'timestamp', 'level', 'open', 'close', 'high', 'low']
Expand All @@ -38,7 +38,7 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
super().__init__(self.data_schema, entity_schema, provider, entity_provider, entity_ids, exchanges, codes,
start_timestamp, end_timestamp, columns, filters, order, limit, level, category_field,
time_field, computing_window, keep_all_timestamp, fill_method, effective_number, transformer,
accumulator, need_persist, dry_run, factor_name, clear_state, not_load_data)
accumulator, need_persist, only_compute_factor, factor_name, clear_state, only_load_factor)


# the __all__ is generated
Expand Down
6 changes: 3 additions & 3 deletions zvt/factors/zen/zen_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,14 +542,14 @@ def __init__(self, entity_schema: Type[TradableEntity] = Stock, provider: str =
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY, category_field: str = 'entity_id',
time_field: str = 'timestamp', computing_window: int = None, keep_all_timestamp: bool = False,
fill_method: str = 'ffill', effective_number: int = None, transformer: Transformer = None,
accumulator: Accumulator = ZenAccumulator(), need_persist: bool = False, dry_run: bool = False,
factor_name: str = None, clear_state: bool = False, not_load_data: bool = False,
accumulator: Accumulator = ZenAccumulator(), need_persist: bool = False, only_compute_factor: bool = False,
factor_name: str = None, clear_state: bool = False, only_load_factor: bool = False,
adjust_type: Union[AdjustType, str] = None) -> None:
self.factor_schema = get_zen_factor_schema(entity_type=entity_schema.__name__, level=level)
super().__init__(entity_schema, provider, entity_provider, entity_ids, exchanges, codes, start_timestamp,
end_timestamp, columns, filters, order, limit, level, category_field, time_field,
computing_window, keep_all_timestamp, fill_method, effective_number, transformer, accumulator,
need_persist, dry_run, factor_name, clear_state, not_load_data, adjust_type)
need_persist, only_compute_factor, factor_name, clear_state, only_load_factor, adjust_type)

def factor_col_map_object_hook(self) -> dict:
return {
Expand Down

0 comments on commit eecde5e

Please sign in to comment.