Skip to content

Commit

Permalink
save adjust_type for trader_info
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Dec 26, 2020
1 parent 45529cb commit 7a76cb1
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 23 deletions.
1 change: 1 addition & 0 deletions zvt/domain/trader_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TraderInfo(TraderBase, Mixin):
level = Column(String(length=32))
real_time = Column(Boolean)
kdata_use_begin_time = Column(Boolean)
kdata_adjust_type = Column(String(length=32))


# account stats of every day
Expand Down
13 changes: 7 additions & 6 deletions zvt/factors/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ def __init__(self, slow=26, fast=12, n=9, normal=False) -> None:
self.indicators.append('dea')
self.indicators.append('macd')

# def transform(self, input_df) -> pd.DataFrame:
# macd_df = input_df.groupby(level=0)['close'].apply(
# lambda x: macd(x, slow=self.slow, fast=self.fast, n=self.n, return_type='df', normal=self.normal))
# input_df = pd.concat([input_df, macd_df], axis=1, sort=False)
# return input_df
def transform(self, input_df) -> pd.DataFrame:
macd_df = input_df.groupby(level=0)['close'].apply(
lambda x: macd(x, slow=self.slow, fast=self.fast, n=self.n, return_type='df', normal=self.normal))
input_df = pd.concat([input_df, macd_df], axis=1, sort=False)
return input_df

def transform_one(self, entity_id, df: pd.DataFrame) -> pd.DataFrame:
print(f'transform_one {entity_id} {df}')
Expand Down Expand Up @@ -301,4 +301,5 @@ def calculate_score(df, factor_name, quantile):


# the __all__ is generated
__all__ = ['ma', 'ema', 'macd', 'point_in_range', 'intersect_ranges', 'intersect', 'RankScorer', 'consecutive_count', 'MaTransformer', 'IntersectTransformer', 'MaAndVolumeTransformer', 'MacdTransformer', 'QuantileScorer']
__all__ = ['ma', 'ema', 'macd', 'point_in_range', 'intersect_ranges', 'intersect', 'RankScorer', 'consecutive_count',
'MaTransformer', 'IntersectTransformer', 'MaAndVolumeTransformer', 'MacdTransformer', 'QuantileScorer']
24 changes: 11 additions & 13 deletions zvt/trader/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def on_trading_open(self, timestamp):
if is_same_date(timestamp, self.start_timestamp):
return
self.account = self.load_account()
self.logger.info('on_trading_open:{},current_account:{}'.format(timestamp, self.account))
self.logger.info('on_trading_open:{},current_account:{}'.format(timestamp, self.account.__dict__))

def on_trading_error(self, timestamp, error):
pass
Expand Down Expand Up @@ -464,11 +464,9 @@ def order(self, entity_id, current_price, current_timestamp, order_amount=0, ord
# 买的数量
order_amount = order_money // cost

if order_amount < 100:
if self.rich_mode:
self.input_money()
else:
raise NotEnoughMoneyError()
if order_amount < 1:
self.logger.error(f'invalid order_money:{order_money}, cost:{cost}, order_amount:{order_amount}')
return

self.update_position(current_position, order_amount, current_price, order_type,
current_timestamp)
Expand All @@ -487,11 +485,9 @@ def order(self, entity_id, current_price, current_timestamp, order_amount=0, ord

order_amount = order_money // cost

if order_amount < 100:
if self.rich_mode:
self.input_money()
else:
raise NotEnoughMoneyError()
if order_amount < 1:
self.logger.error(f'invalid order_money:{order_money}, cost:{cost}, order_amount:{order_amount}')
return
self.update_position(current_position, order_amount, current_price, order_type,
current_timestamp)
else:
Expand Down Expand Up @@ -538,9 +534,10 @@ def order(self, entity_id, current_price, current_timestamp, order_amount=0, ord
# 买的数量
order_amount = want_pay // cost

if order_amount < 100:
if order_amount < 1:
if self.rich_mode:
self.input_money()
order_amount = (self.account.cash * order_pct) // cost
else:
raise NotEnoughMoneyError()
self.update_position(current_position, order_amount, current_price, order_type,
Expand All @@ -555,9 +552,10 @@ def order(self, entity_id, current_price, current_timestamp, order_amount=0, ord

order_amount = want_pay // cost

if order_amount < 100:
if order_amount < 1:
if self.rich_mode:
self.input_money()
order_amount = (self.account.cash * order_pct) // cost
else:
raise NotEnoughMoneyError()

Expand Down
6 changes: 5 additions & 1 deletion zvt/trader/trader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(self,
self.kdata_use_begin_time = kdata_use_begin_time
self.draw_result = draw_result
self.rich_mode = rich_mode

if type(adjust_type) is str:
adjust_type = AdjustType(adjust_type)
self.adjust_type = adjust_type

self.account_service = SimAccountService(entity_schema=self.entity_schema,
Expand Down Expand Up @@ -145,7 +148,8 @@ def on_start(self):
provider=self.provider,
level=self.level.value,
real_time=self.real_time,
kdata_use_begin_time=self.kdata_use_begin_time)
kdata_use_begin_time=self.kdata_use_begin_time,
kdata_adjust_type=self.adjust_type.value)
self.session.add(sim_account)
self.session.commit()

Expand Down
3 changes: 2 additions & 1 deletion zvt/ui/apps/trader_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,6 @@ def update_target_signals(entity_id, start_date, end_date, trader_index):
return dcc.Graph(
id=f'{entity_id}-signals',
figure=get_trading_signals_figure(order_reader=order_readers[trader_index], entity_id=entity_id,
start_timestamp=start_date, end_timestamp=end_date))
start_timestamp=start_date, end_timestamp=end_date,
adjust_type=traders[trader_index].kdata_adjust_type))
raise dash.PreventUpdate()
5 changes: 3 additions & 2 deletions zvt/ui/components/dcc_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ def order_type_flag(order_type):
def get_trading_signals_figure(order_reader: OrderReader,
entity_id: str,
start_timestamp=None,
end_timestamp=None):
end_timestamp=None,
adjust_type=None):
entity_type, _, _ = decode_entity_id(entity_id)

data_schema = get_kdata_schema(entity_type=entity_type, level=order_reader.level)
data_schema = get_kdata_schema(entity_type=entity_type, level=order_reader.level, adjust_type=adjust_type)
if not start_timestamp:
start_timestamp = order_reader.start_timestamp
if not end_timestamp:
Expand Down

0 comments on commit 7a76cb1

Please sign in to comment.