Skip to content

Commit

Permalink
Support refresh entities
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Apr 4, 2023
1 parent 5dffad4 commit 5351207
Show file tree
Hide file tree
Showing 38 changed files with 1,536 additions and 514 deletions.
4 changes: 2 additions & 2 deletions examples/factors/tech_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
Expand Down
91 changes: 15 additions & 76 deletions examples/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

from examples.utils import add_to_eastmoney
from zvt import zvt_config
from zvt.api import get_top_volume_entities, get_top_performance_entities, TopType
from zvt.api import get_top_volume_entities, TopType
from zvt.api.kdata import get_latest_kdata_date, get_kdata_schema, default_adjust_type
from zvt.api.selector import get_entity_ids_by_filter
from zvt.api.stats import get_top_performance_entities_by_periods
from zvt.contract import IntervalLevel
from zvt.contract.api import get_entities, get_entity_schema
from zvt.contract.factor import Factor
from zvt.factors import TargetSelector, SelectMode
from zvt.contract.factor import Factor, TargetType
from zvt.informer import EmailInformer
from zvt.utils import next_date

Expand Down Expand Up @@ -109,9 +108,6 @@ def report_targets(
current_entity_pool = set(factor_kv.pop("entity_ids"))

# add the factor
my_selector = TargetSelector(
start_timestamp=start_timestamp, end_timestamp=target_date, select_mode=SelectMode.condition_or
)
entity_schema = get_entity_schema(entity_type=entity_type)
tech_factor = factor_cls(
entity_schema=entity_schema,
Expand All @@ -123,11 +119,8 @@ def report_targets(
adjust_type=adjust_type,
**factor_kv,
)
my_selector.add_factor(tech_factor)

my_selector.run()

long_stocks = my_selector.get_open_long_targets(timestamp=target_date)
long_stocks = tech_factor.get_targets(timestamp=target_date, target_type=TargetType.positive)

inform(
informer,
Expand Down Expand Up @@ -174,79 +167,25 @@ def report_top_entities(

while error_count <= 10:
try:
if periods is None:
periods = [7, 30, 365]
if not adjust_type:
adjust_type = default_adjust_type(entity_type=entity_type)
kdata_schema = get_kdata_schema(entity_type=entity_type, adjust_type=adjust_type)
entity_schema = get_entity_schema(entity_type=entity_type)

target_date = get_latest_kdata_date(
provider=data_provider, entity_type=entity_type, adjust_type=adjust_type
)

filter_entity_ids = get_entity_ids_by_filter(
provider=entity_provider,
ignore_st=ignore_st,
selected = get_top_performance_entities_by_periods(
entity_provider=entity_provider,
data_provider=data_provider,
periods=periods,
ignore_new_stock=ignore_new_stock,
entity_schema=entity_schema,
target_date=target_date,
ignore_st=ignore_st,
entity_ids=entity_ids,
entity_type=entity_type,
adjust_type=adjust_type,
top_count=top_count,
turnover_threshold=turnover_threshold,
turnover_rate_threshold=turnover_rate_threshold,
return_type=return_type,
)

if not filter_entity_ids:
msg = f"{entity_type} no entity_ids selected"
logger.error(msg)
informer.send_message(zvt_config["email_username"], "report_top_stats error", msg)
return

filter_turnover_df = kdata_schema.query_data(
filters=[
kdata_schema.turnover >= turnover_threshold,
kdata_schema.turnover_rate >= turnover_rate_threshold,
],
provider=data_provider,
start_timestamp=target_date,
index="entity_id",
columns=["entity_id", "code"],
)
if filter_entity_ids:
filter_entity_ids = set(filter_entity_ids) & set(filter_turnover_df.index.tolist())
else:
filter_entity_ids = filter_turnover_df.index.tolist()

if not filter_entity_ids:
msg = f"{entity_type} no entity_ids selected"
logger.error(msg)
informer.send_message(zvt_config["email_username"], "report_top_stats error", msg)
return

logger.info(f"{entity_type} filter_entity_ids size: {len(filter_entity_ids)}")
filters = [kdata_schema.entity_id.in_(filter_entity_ids)]
selected = []
for i, period in enumerate(periods):
interval = period
if target_date.weekday() + 1 < interval:
interval = interval + 2
start = next_date(target_date, -interval)
positive_df, negative_df = get_top_performance_entities(
entity_type=entity_type,
start_timestamp=start,
kdata_filters=filters,
pct=1,
show_name=True,
entity_provider=entity_provider,
data_provider=data_provider,
return_type=return_type,
)

if return_type == TopType.positive:
df = positive_df
else:
df = negative_df
selected = selected + df.index[:top_count].tolist()
selected = list(dict.fromkeys(selected))

inform(
informer,
entity_ids=selected,
Expand Down
4 changes: 2 additions & 2 deletions examples/research/top_dragon_tiger.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
Expand Down
23 changes: 5 additions & 18 deletions examples/trader/dragon_and_tiger_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from zvt.contract import IntervalLevel
from zvt.contract.factor import Factor, Transformer, Accumulator
from zvt.domain import Stock, DragonAndTiger
from zvt.factors import TargetSelector
from zvt.trader import StockTrader


Expand All @@ -27,7 +26,7 @@ def __init__(
level: Union[str, IntervalLevel] = IntervalLevel.LEVEL_1DAY,
category_field: str = "entity_id",
time_field: str = "timestamp",
computing_window: int = None,
keep_window: int = None,
keep_all_timestamp: bool = False,
fill_method: str = "ffill",
effective_number: int = None,
Expand Down Expand Up @@ -56,7 +55,7 @@ def __init__(
level,
category_field,
time_field,
computing_window,
keep_window,
keep_all_timestamp,
fill_method,
effective_number,
Expand All @@ -75,30 +74,18 @@ def compute_result(self):


class MyTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="em",
)

myselector.add_factor(
return [
DragonTigerFactor(
entity_ids=entity_ids,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
)
)

self.selectors.append(myselector)
]


if __name__ == "__main__":
Expand Down
77 changes: 24 additions & 53 deletions examples/trader/keep_run_trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,70 +5,41 @@
from zvt.api.stats import get_top_fund_holding_stocks
from zvt.api.trader_info_api import clear_trader
from zvt.contract import IntervalLevel
from zvt.factors import TargetSelector, GoldCrossFactor, BullFactor
from zvt.factors import GoldCrossFactor, BullFactor
from zvt.trader import StockTrader
from zvt.utils.time_utils import split_time_interval, next_date

logger = logging.getLogger(__name__)


class MultipleLevelTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
start_timestamp = next_date(start_timestamp, -50)

# 周线策略
week_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1WEEK,
provider="joinquant",
)
week_bull_factor = BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
)
week_selector.add_factor(week_bull_factor)

# 日线策略
day_selector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
long_threshold=0.7,
level=IntervalLevel.LEVEL_1DAY,
provider="joinquant",
)
day_gold_cross_factor = GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
)
day_selector.add_factor(day_gold_cross_factor)

# 同时使用日线,周线级别
self.selectors.append(day_selector)
self.selectors.append(week_selector)
return [
BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=next_date(start_timestamp, -200),
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1WEEK,
),
GoldCrossFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
level=IntervalLevel.LEVEL_1DAY,
),
]


if __name__ == "__main__":
Expand Down
37 changes: 6 additions & 31 deletions examples/trader/ma_trader.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
# -*- coding: utf-8 -*-
from zvt.contract import IntervalLevel
from zvt.factors import CrossMaFactor
from zvt.factors.target_selector import TargetSelector
from zvt.factors.macd import BullFactor

from zvt.trader.trader import StockTrader


class MyMaTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)

myselector.add_factor(
return [
CrossMaFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
Expand All @@ -32,26 +21,14 @@ def init_selectors(
windows=[5, 10],
need_persist=False,
)
)

self.selectors.append(myselector)
]


class MyBullTrader(StockTrader):
def init_selectors(
def init_factors(
self, entity_ids, entity_schema, exchanges, codes, start_timestamp, end_timestamp, adjust_type=None
):
myselector = TargetSelector(
entity_ids=entity_ids,
entity_schema=entity_schema,
exchanges=exchanges,
codes=codes,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
provider="joinquant",
)

myselector.add_factor(
return [
BullFactor(
entity_ids=entity_ids,
entity_schema=entity_schema,
Expand All @@ -61,9 +38,7 @@ def init_selectors(
end_timestamp=end_timestamp,
adjust_type="hfq",
)
)

self.selectors.append(myselector)
]


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 5351207

Please sign in to comment.