Skip to content

Commit

Permalink
1)add multiple stock trader examples 2)marshal factors of trader 3)dr…
Browse files Browse the repository at this point in the history
…aw multiple securities with indicators and trading signals 4)some refactor

Former-commit-id: 345b5f5
  • Loading branch information
foolcage committed Jun 28, 2019
1 parent a886d75 commit afd1eca
Show file tree
Hide file tree
Showing 19 changed files with 284 additions and 137 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ marshmallow-sqlalchemy
# pyecharts == 1.1.0
ccxt == 1.17.191
dash==0.43.0
dash-daq==0.1.0
dash-daq==0.1.0
simplejson==3.16.0
21 changes: 21 additions & 0 deletions tests/api/test_technical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
from ..context import init_context

init_context()

from zvt.api.technical import get_securities, get_securities_in_blocks


def test_basic_get_securities():
items = get_securities(security_type='stock', provider='eastmoney')
print(items)
items = get_securities(security_type='index', provider='eastmoney')
print(items)
items = get_securities(security_type='coin', provider='ccxt')
print(items)


def test_get_security_blocks():
hs300 = get_securities_in_blocks(block_names=['HS300_'])
assert len(hs300) == 300
assert 'stock_sz_000338' in hs300
19 changes: 12 additions & 7 deletions tests/selectors/test_selector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from zvt.selectors.zvt_selector import TechnicalSelector
from zvt.selectors.examples.technical_selector import TechnicalSelector
from zvt.utils.pd_utils import df_is_not_null
from ..context import init_context

init_context()
Expand All @@ -17,13 +18,17 @@ def test_technical_selector():

print(selector.get_result_df())

assert 'stock_sz_000338' in selector.get_targets('2019-06-04')['security_id'].tolist()
assert 'stock_sz_000338' in selector.get_targets('2019-06-04')['security_id'].tolist()
assert 'stock_sz_002572' not in selector.get_targets('2019-06-04')['security_id'].tolist()
assert 'stock_sz_002572' not in selector.get_targets('2019-06-04')['security_id'].tolist()
targets = selector.get_targets('2019-06-04')
if df_is_not_null(targets):
assert 'stock_sz_000338' not in targets['security_id'].tolist()
assert 'stock_sz_000338' not in targets['security_id'].tolist()
assert 'stock_sz_002572' not in targets['security_id'].tolist()
assert 'stock_sz_002572' not in targets['security_id'].tolist()

selector.move_on(timeout=0)

assert 'stock_sz_000338' in selector.get_targets('2019-06-17')['security_id'].tolist()
targets = selector.get_targets('2019-06-19')
if df_is_not_null(targets):
assert 'stock_sz_000338' in targets['security_id'].tolist()

assert 'stock_sz_002572' in selector.get_targets('2019-06-17')['security_id'].tolist()
assert 'stock_sz_002572' not in targets['security_id'].tolist()
10 changes: 5 additions & 5 deletions zvt/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from zvt.domain import get_db_session, CompanyType, TradingLevel, get_store_category
from zvt.domain.coin_meta import Coin
from zvt.domain.quote import *
from zvt.utils.pd_utils import index_df_with_time
from zvt.utils.pd_utils import index_df, df_is_not_null
from zvt.utils.time_utils import to_pd_timestamp, now_pd_timestamp
from zvt.utils.time_utils import to_time_str, TIME_FORMAT_DAY, TIME_FORMAT_ISO8601

Expand Down Expand Up @@ -127,7 +127,7 @@ def get_group(provider, data_schema, column, group_func=func.count, session=None

def get_data(data_schema, security_list=None, security_id=None, codes=None, level=None, provider='eastmoney',
columns=None, return_type='df', start_timestamp=None, end_timestamp=None,
filters=None, session=None, order=None, limit=None):
filters=None, session=None, order=None, limit=None, index='timestamp', index_is_time=True):
local_session = False
if not session:
store_category = get_store_category(data_schema)
Expand Down Expand Up @@ -165,12 +165,12 @@ def get_data(data_schema, security_list=None, security_id=None, codes=None, leve

if return_type == 'df':
df = pd.read_sql(query.statement, query.session.bind)
if not df.empty:
return index_df_with_time(df, drop=False)
if df_is_not_null(df):
return index_df(df, drop=False, index=index, index_is_time=index_is_time)
elif return_type == 'domain':
return query.all()
elif return_type == 'dict':
return [item.to_json() for item in query.all()]
return [item.__dict__ for item in query.all()]
except Exception:
raise
finally:
Expand Down
88 changes: 50 additions & 38 deletions zvt/api/technical.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-
from typing import List, Union

import pandas as pd
from sqlalchemy.orm import Session

from zvt.api.common import common_filter, get_data, decode_security_id
from zvt.api.common import get_data, decode_security_id
from zvt.api.common import get_security_schema, get_kdata_schema
from zvt.domain import get_db_engine, get_db_session, TradingLevel, Provider, get_store_category
from zvt.domain import get_db_engine, TradingLevel, Provider, get_store_category, SecurityType, get_db_session, \
StoreCategory, Index
from zvt.utils.pd_utils import df_is_not_null


Expand Down Expand Up @@ -33,49 +36,58 @@ def df_to_db(df, data_schema, provider):
df.to_sql(data_schema.__tablename__, db_engine, index=False, if_exists='append')


def get_securities(security_type='stock', exchanges=None, codes=None, columns=None,
return_type='df', session=None, start_timestamp=None, end_timestamp=None,
filters=None, order=None, limit=None, provider='eastmoney'):
local_session = False
def get_securities_in_blocks(block_names=['HS300_'], block_category='concept', provider='eastmoney'):
session = get_db_session(provider=provider, store_category=StoreCategory.meta)

filters = [Index.category == block_category]
name_filters = None
for block_name in block_names:
if name_filters:
name_filters |= (Index.name == block_name)
else:
name_filters = (Index.name == block_name)
filters.append(name_filters)
blocks = get_securities(security_type='index', provider='eastmoney',
filters=filters,
return_type='domain', session=session)
securities = []
for block in blocks:
securities += [item.stock_id for item in block.stocks]

return securities


def get_securities(security_list: List[str] = None,
security_type: Union[SecurityType, str] = 'stock',
exchanges: List[str] = None,
codes: List[str] = None,
columns: List = None,
return_type: str = 'df',
session: Session = None,
start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None,
filters: List = None,
order: object = None,
limit: int = None,
provider: Union[str, Provider] = 'eastmoney',
index: str = 'code',
index_is_time: bool = False) -> object:
data_schema = get_security_schema(security_type)
store_category = get_store_category(data_schema=data_schema)

if not session:
session = get_db_session(provider=provider, store_category=store_category)
local_session = True

if not order:
order = data_schema.code.asc()

try:
if columns:
query = session.query(*columns)
if exchanges:
if filters:
filters.append(data_schema.exchange.in_(exchanges))
else:
query = session.query(data_schema)

# filters
if exchanges:
query = query.filter(data_schema.exchange.in_(exchanges))
if codes:
query = query.filter(data_schema.code.in_(codes))

query = common_filter(query, data_schema=data_schema, start_timestamp=start_timestamp,
end_timestamp=end_timestamp, filters=filters, order=order, limit=limit)

if return_type == 'df':
# TODO:add indices info
return pd.read_sql(query.statement, query.session.bind)
elif return_type == 'domain':
return query.all()
elif return_type == 'dict':
return [item.to_json() for item in query.all()]
except Exception as e:

raise
finally:
if local_session:
session.close()
filters = [data_schema.exchange.in_(exchanges)]

return get_data(data_schema=data_schema, security_list=security_list, security_id=None, codes=codes, level=None,
provider=provider,
columns=columns, return_type=return_type, start_timestamp=start_timestamp,
end_timestamp=end_timestamp, filters=filters,
session=session, order=order, limit=limit, index=index, index_is_time=index_is_time)


def get_kdata(security_id, level=TradingLevel.LEVEL_1DAY.value, provider='eastmoney', columns=None,
Expand Down
30 changes: 19 additions & 11 deletions zvt/charts/dcc_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import dash_core_components as dcc
import plotly.graph_objs as go
import simplejson

from zvt.api.common import decode_security_id
from zvt.domain import Provider, business
Expand Down Expand Up @@ -75,35 +76,42 @@ def get_trader_detail_figures(trader_domain: business.Trader,
'layout': account_layout
}))

df_orders = order_reader.get_data_df()
order_reader.move_on(timeout=0)
df_orders = order_reader.get_data_df().copy()

if df_is_not_null(df_orders):
grouped = df_orders.groupby('security_id')

for security_id, order_df in grouped:
security_type, _, _ = decode_security_id(security_id)
# TODO:just show the indicators used by the trader

indicators = []
indicators_param = []
indicator_cols = []
if trader_domain.technical_factors:
tech_factors = simplejson.loads(trader_domain.technical_factors)
for factor in tech_factors:
indicators += factor['indicators']
indicators_param += factor['indicators_param']
indicator_cols += factor['indicator_cols']

security_factor = TechnicalFactor(security_type=security_type, security_list=[security_id],
start_timestamp=trader_domain.start_timestamp,
end_timestamp=trader_domain.end_timestamp,
level=trader_domain.level, provider=trader_domain.provider,
indicators=['ma', 'ma'],
indicators_param=[{'window': 5}, {'window': 10}]
)

# if df_is_not_null(security_factor.get_data_df()):
# print(security_factor.get_data_df().tail())
indicators=indicators,
indicators_param=indicators_param)

# generate the annotation df
order_reader.move_on(timeout=0)
df = order_reader.get_data_df().copy()
df = order_df.copy()
if df_is_not_null(df):
df['value'] = df['order_price']
df['flag'] = df['order_type'].apply(lambda x: order_type_flag(x))
df['color'] = df['order_type'].apply(lambda x: order_type_color(x))
print(df.tail())

data, layout = security_factor.draw_with_indicators(render=None, annotation_df=df)
data, layout = security_factor.draw_with_indicators(render=None, annotation_df=df,
indicators=indicator_cols)

graph_list.append(
dcc.Graph(
Expand Down
2 changes: 2 additions & 0 deletions zvt/domain/business.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class Trader(BusinessBase):
kdata_use_begin_time = Column(Boolean)
# TODO:inspect selector/factors
selectors = Column(String(length=1024))
factors = Column(String(length=1024))
technical_factors = Column(String(length=1024))


# 一天只有一条记录
Expand Down
9 changes: 4 additions & 5 deletions zvt/factors/factor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import enum
import logging
from typing import List, Union

import pandas as pd
Expand Down Expand Up @@ -82,17 +81,17 @@ def get_result_df(self):
def get_depth_df(self):
return self.depth_df

def draw_depth(self, figure=go.Scatter, mode='lines', value_field='close', render='html', file_name=None,
def draw_depth(self, figures=[go.Scatter], modes=['lines'], value_fields=['close'], render='html', file_name=None,
width=None, height=None, title=None, keep_ui_state=True):
chart = Chart(category_field=self.category_field, figures=figure, modes=mode, value_fields=value_field,
chart = Chart(category_field=self.category_field, figures=figures, modes=modes, value_fields=value_fields,
render=render, file_name=file_name,
width=width, height=height, title=title, keep_ui_state=keep_ui_state)
chart.set_data_df(self.depth_df)
chart.draw()

def draw_result(self, figure=go.Scatter, mode='lines', value_field='close', render='html', file_name=None,
def draw_result(self, figures=[go.Scatter], modes=['lines'], value_fields=['score'], render='html', file_name=None,
width=None, height=None, title=None, keep_ui_state=True):
chart = Chart(category_field=self.category_field, figures=figure, modes=mode, value_fields=value_field,
chart = Chart(category_field=self.category_field, figures=figures, modes=modes, value_fields=value_fields,
render=render, file_name=file_name,
width=width, height=height, title=title, keep_ui_state=keep_ui_state)
chart.set_data_df(self.result_df)
Expand Down
2 changes: 1 addition & 1 deletion zvt/factors/finance_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def __init__(self,
end_timestamp='2018-12-31',
codes=['000338', '000778', '601318'])

factor.draw_result(value_field='op_income_growth_yoy')
factor.draw_result(value_fields=['op_income_growth_yoy', 'rota'])
44 changes: 31 additions & 13 deletions zvt/factors/technical_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@


class TechnicalFactor(FilterFactor):

def __json__(self):
return {
'indicators': self.indicators,
'indicators_param': self.indicators_param,
'indicator_cols': list(self.indicator_cols)
}

for_json = __json__ # supported by simplejson

def __init__(self,
security_list: List[str] = None,
security_type: Union[str, SecurityType] = SecurityType.stock,
Expand Down Expand Up @@ -170,17 +180,25 @@ def on_category_data_added(self, category, added_data: pd.DataFrame):
self.compute()


class BullFactor(TechnicalFactor):
def __init__(self, security_list: List[str] = None, security_type: Union[str, SecurityType] = SecurityType.stock,
exchanges: List[str] = ['sh', 'sz'], codes: List[str] = None,
the_timestamp: Union[str, pd.Timestamp] = None, start_timestamp: Union[str, pd.Timestamp] = None,
end_timestamp: Union[str, pd.Timestamp] = None, columns: List = None, filters: List = None,
provider: Union[str, Provider] = 'joinquant', level: TradingLevel = TradingLevel.LEVEL_1DAY,
real_time: bool = False, refresh_interval: int = 10, category_field: str = 'security_id',
indicators=['macd'], indicators_param=[{'slow': 26, 'fast': 12, 'n': 9}],
valid_window=26) -> None:
super().__init__(security_list, security_type, exchanges, codes, the_timestamp, start_timestamp, end_timestamp,
columns, filters, provider, level, real_time, refresh_interval, category_field, indicators,
indicators_param, valid_window)

def compute(self):
super().compute()
s = (self.depth_df['diff'] > 0) & (self.depth_df['dea'] > 0)
self.result_df = s.to_frame(name='score')


if __name__ == '__main__':
factor = TechnicalFactor(codes=['000338'], start_timestamp='2018-01-01', end_timestamp='2019-02-01',
indicators=['ma', 'ma'],
indicators_param=[{'window': 5}, {'window': 10}])
factor.draw_with_indicators()

# factor1 = CrossMaFactor(security_list=['coin_binance_EOS/USDT'],
# security_type=SecurityType.coin,
# start_timestamp='2019-01-01',
# end_timestamp='2019-06-05', level=TradingLevel.LEVEL_5MIN, provider='ccxt')
# factor1.compute()
# factor1.draw()
# factor1.draw_depth(value_field='ma10')
# factor1.draw_result(value_field='score')
factor = BullFactor(codes=['000338'], start_timestamp='2018-01-01', end_timestamp='2019-02-01')
factor.draw_result()
2 changes: 1 addition & 1 deletion zvt/recorders/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self,
self.codes = codes

self.meta_session = get_db_session(provider=self.meta_provider, store_category=self.meta_category)
# init the security listo
# init the security list
self.securities = get_securities(session=self.meta_session,
security_type=self.security_type,
exchanges=self.exchanges,
Expand Down
1 change: 1 addition & 0 deletions zvt/selectors/examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
18 changes: 18 additions & 0 deletions zvt/selectors/examples/fundamental_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
from zvt.factors.finance_factor import FinanceGrowthFactor
from zvt.selectors.selector import TargetSelector


class FundamentalSelector(TargetSelector):
def init_factors(self, security_list, security_type, exchanges, codes, the_timestamp, start_timestamp,
end_timestamp):
factor = FinanceGrowthFactor(security_list=security_list, security_type=security_type, exchanges=exchanges,
codes=codes, the_timestamp=the_timestamp, start_timestamp=start_timestamp,
end_timestamp=end_timestamp, keep_all_timestamp=True, provider=self.provider)
self.score_factors.append(factor)


if __name__ == '__main__':
selector: TargetSelector = FundamentalSelector(start_timestamp='2018-01-01', end_timestamp='2019-06-30')
selector.run()
print(selector.get_result_df())
Loading

0 comments on commit afd1eca

Please sign in to comment.