Skip to content

Commit

Permalink
BUG: fixed issue scrtlabs#111 related to positions update after resto…
Browse files Browse the repository at this point in the history
…ring algo state
  • Loading branch information
fredfortier committed Jan 7, 2018
1 parent 13de3e6 commit 5d97089
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 36 deletions.
6 changes: 3 additions & 3 deletions catalyst/examples/mean_reversion_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def initialize(context):
context.base_price = None
context.current_day = None

context.RSI_OVERSOLD = 55
context.RSI_OVERBOUGHT = 65
context.RSI_OVERSOLD = 35
context.RSI_OVERBOUGHT = 50
context.CANDLE_SIZE = '5T'

context.start_time = time.time()
Expand Down Expand Up @@ -248,7 +248,7 @@ def analyze(context=None, perf=None):

if live:
run_algorithm(
capital_base=0.03,
capital_base=0.025,
initialize=initialize,
handle_data=handle_data,
analyze=analyze,
Expand Down
99 changes: 66 additions & 33 deletions catalyst/exchange/exchange_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,16 +391,20 @@ def interrupt_algorithm(self):
'before exiting the algorithm.')

algo_folder = get_algo_folder(self.algo_namespace)
folder = join(algo_folder, 'daily_perf')
folder = join(algo_folder, 'daily_performance')
files = [f for f in listdir(folder) if isfile(join(folder, f))]

daily_perf_list = []
for item in files:
filename = join(folder, item)

with open(filename, 'rb') as handle:
daily_perf_list.append(pickle.load(handle))
perf_period = pickle.load(handle)
perf_period_dict = perf_period.to_dict()
daily_perf_list.append(perf_period_dict)

stats = pd.DataFrame(daily_perf_list)
stats.set_index('period_close', drop=False, inplace=True)

self.analyze(stats)

Expand Down Expand Up @@ -460,44 +464,62 @@ def _create_clock(self):

return self._clock

def get_generator(self):
if self.trading_client is not None:
return self.trading_client.transform()
def _init_trading_client(self):
"""
This replaces Ziplines `_create_generator` method. The main difference
is that we are restoring performance tracker objects if available.
This allows us to stop/start algos without loosing their state.
perf = None
"""
if self.perf_tracker is None:
# Note from the Zipline dev:
# HACK: When running with the `run` method, we set perf_tracker to
# None so that it will be overwritten here.
tracker = self.perf_tracker = PerformanceTracker(
sim_params=self.sim_params,
trading_calendar=self.trading_calendar,
env=self.trading_environment,
)

# Set the dt initially to the period start by forcing it to change.
self.on_dt_changed(self.sim_params.start_session)

new_position_tracker = tracker.position_tracker
tracker.position_tracker = None

# Unpacking the perf_tracker and positions if available
perf = get_algo_object(
cum_perf = get_algo_object(
algo_name=self.algo_namespace,
key='cumulative_performance',
)
if cum_perf is not None:
tracker.cumulative_performance = cum_perf
# Ensure single common position tracker
tracker.position_tracker = cum_perf.position_tracker

today = pd.Timestamp.utcnow().floor('1D')
todays_perf = get_algo_object(
algo_name=self.algo_namespace,
key=today.strftime('%Y-%m-%d'),
rel_path='daily_performance',
)
if todays_perf is not None:
# Ensure single common position tracker
if tracker.position_tracker is not None:
todays_perf.position_tracker = tracker.position_tracker
else:
tracker.position_tracker = todays_perf.position_tracker

tracker.todays_performance = todays_perf

if tracker.position_tracker is None:
# Use a new position_tracker if not is found in the state
tracker.position_tracker = new_position_tracker

if not self.initialized:
# Calls the initialize function of the algorithm
self.initialize(*self.initialize_args, **self.initialize_kwargs)
self.initialized = True

# Call the simulation trading algorithm for side-effects:
# it creates the perf tracker
# TradingAlgorithm._create_generator(self, self.sim_params)
if perf is not None:
tracker.cumulative_performance = perf

period = self.perf_tracker.todays_performance
period.starting_cash = perf.ending_cash
period.starting_exposure = perf.ending_exposure
period.starting_value = perf.ending_value
# This does not seem to get updated correctly
period.position_tracker = perf.position_tracker

self.trading_client = ExchangeAlgorithmExecutor(
algo=self,
sim_params=self.sim_params,
Expand All @@ -507,6 +529,11 @@ def get_generator(self):
restrictions=self.restrictions,
universe_func=self._calculate_universe,
)

def get_generator(self):
if self.trading_client is None:
self._init_trading_client()

return self.trading_client.transform()

def updated_portfolio(self):
Expand Down Expand Up @@ -677,11 +704,12 @@ def handle_data(self, data):
self.frame_stats = list()

self.performance_needs_update = False
new_orders = self.perf_tracker.todays_performance.orders_by_id.keys()
if new_orders != self._last_orders:
orders = self.perf_tracker.todays_performance.orders_by_id.keys()
if orders != self._last_orders:
self.performance_needs_update = True

self._last_orders = copy.deepcopy(new_orders)
# Saving current orders to detect changes in the next frame
self._last_orders = copy.deepcopy(orders)

if self.performance_needs_update:
self.perf_tracker.update_performance()
Expand All @@ -698,7 +726,7 @@ def handle_data(self, data):
self.portfolio_needs_update = False

log.info(
'got totals from exchanges, cash: {} positions: {}'.format(
'portfolio balances, cash: {}, positions: {}'.format(
cash, positions_value
)
)
Expand All @@ -710,18 +738,29 @@ def handle_data(self, data):
# every bar no matter if the algorithm places an order or not.
self.validate_account_controls()

self._save_algo_state(data)
self.current_day = data.current_dt.floor('1D')

def _save_algo_state(self, data):
today = data.current_dt.floor('1D')
try:
self._save_stats_csv(self._process_stats(data))
except Exception as e:
log.warn('unable to calculate performance: {}'.format(e))

log.debug('saving cumulative performance object')
save_algo_object(
algo_name=self.algo_namespace,
key='cumulative_performance',
obj=self.perf_tracker.cumulative_performance,
)

self.current_day = data.current_dt.floor('1D')
log.debug('saving todays performance object')
save_algo_object(
algo_name=self.algo_namespace,
key=today.strftime('%Y-%m-%d'),
obj=self.perf_tracker.todays_performance,
rel_path='daily_performance'
)

def _process_stats(self, data):
today = data.current_dt.floor('1D')
Expand Down Expand Up @@ -764,12 +803,6 @@ def _process_stats(self, data):
start_dt=today,
end_dt=data.current_dt
)
save_algo_object(
algo_name=self.algo_namespace,
key=today.strftime('%Y-%m-%d'),
obj=daily_stats,
rel_path='daily_perf'
)

return recorded_cols

Expand Down

0 comments on commit 5d97089

Please sign in to comment.