Skip to content

Commit

Permalink
TST: Don't modify master security lists directory during tests
Browse files Browse the repository at this point in the history
Rather than drop files temporarily into the master security lists
directory during unit tests, create temporary directories for the
tests. This avoids issues when the tests are being run at the same
time as other code that uses the real security lists data.
  • Loading branch information
Jonathan Kamens committed Apr 30, 2015
1 parent 00ea7b0 commit ca0f906
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
22 changes: 6 additions & 16 deletions tests/test_security_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zipline.errors import TradingControlViolation
from zipline.sources import SpecificEquityTrades
from zipline.utils.test_utils import (
setup_logger, add_security_data, remove_security_data_directory)
setup_logger, security_list_copy, add_security_data)
from zipline.utils import factory
from zipline.utils.security_list import (
SecurityListSet, load_from_directory)
Expand Down Expand Up @@ -105,26 +105,22 @@ def get_datetime():
def test_security_add(self):
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
try:
with security_list_copy():
add_security_data(['AAPL', 'GOOG'], [])
rl = SecurityListSet(get_datetime)
self.assertIn("AAPL", rl.leveraged_etf_list)
self.assertIn("GOOG", rl.leveraged_etf_list)
self.assertIn("BZQ", rl.leveraged_etf_list)
self.assertIn("URTY", rl.leveraged_etf_list)
finally:
remove_security_data_directory()

def test_security_add_delete(self):
try:
with security_list_copy():
def get_datetime():
return datetime(2015, 1, 27, tzinfo=pytz.utc)
add_security_data([], ['BZQ', 'URTY'])
rl = SecurityListSet(get_datetime)
self.assertNotIn("BZQ", rl.leveraged_etf_list)
self.assertNotIn("URTY", rl.leveraged_etf_list)
finally:
remove_security_data_directory()

def test_algo_without_rl_violation_via_check(self):
sim_params = factory.create_simulation_parameters(
Expand Down Expand Up @@ -228,7 +224,7 @@ def test_algo_with_rl_violation_cumulative(self):
start=list(
LEVERAGED_ETFS.keys())[0] + timedelta(days=7), num_days=4)

try:
with security_list_copy():
add_security_data(['AAPL'], [])
trade_history = factory.create_trade_history(
'BZQ',
Expand All @@ -244,11 +240,9 @@ def test_algo_with_rl_violation_cumulative(self):
algo.run(self.source)

self.check_algo_exception(algo, ctx, 0)
finally:
remove_security_data_directory()

def test_algo_without_rl_violation_after_delete(self):
try:
with security_list_copy():
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
Expand All @@ -266,11 +260,9 @@ def test_algo_without_rl_violation_after_delete(self):
algo = RestrictedAlgoWithoutCheck(
sid='BZQ', sim_params=sim_params)
algo.run(self.source)
finally:
remove_security_data_directory()

def test_algo_with_rl_violation_after_add(self):
try:
with security_list_copy():
add_security_data(['AAPL'], [])
sim_params = factory.create_simulation_parameters(
start=self.trading_day_before_first_kd, num_days=4)
Expand All @@ -288,8 +280,6 @@ def test_algo_with_rl_violation_after_add(self):
algo.run(self.source)

self.check_algo_exception(algo, ctx, 2)
finally:
remove_security_data_directory()

def check_algo_exception(self, algo, ctx, expected_order_count):
self.assertEqual(algo.order_count, expected_order_count)
Expand Down
35 changes: 24 additions & 11 deletions zipline/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from contextlib import contextmanager
from logbook import FileHandler
from mock import patch
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils.security_list import SECURITY_LISTS_DIR
from zipline.utils import security_list

from six import itervalues

import os
import pandas as pd
import shutil
import tempfile


def to_utc(time_str):
Expand Down Expand Up @@ -115,15 +117,34 @@ def nullctx():
Null context manager. Useful for conditionally adding a contextmanager in
a single line, e.g.:
with SomeContextManager() if some_expr else nullcontext:
with SomeContextManager() if some_expr else nullctx():
do_stuff()
"""
yield


@contextmanager
def security_list_copy():
old_dir = security_list.SECURITY_LISTS_DIR
new_dir = tempfile.mkdtemp()
try:
for subdir in os.listdir(old_dir):
shutil.copytree(os.path.join(old_dir, subdir),
os.path.join(new_dir, subdir))
with patch.object(security_list, 'SECURITY_LISTS_DIR', new_dir), \
patch.object(security_list, 'using_copy', True,
create=True):
yield
finally:
shutil.rmtree(new_dir, True)


def add_security_data(adds, deletes):
if not hasattr(security_list, 'using_copy'):
raise Exception('add_security_data must be used within '
'security_list_copy context')
directory = os.path.join(
SECURITY_LISTS_DIR,
security_list.SECURITY_LISTS_DIR,
"leveraged_etf_list/20150127/20150125"
)
if not os.path.exists(directory):
Expand All @@ -138,11 +159,3 @@ def add_security_data(adds, deletes):
for sym in adds:
f.write(sym)
f.write('\n')


def remove_security_data_directory():
directory = os.path.join(
SECURITY_LISTS_DIR,
"leveraged_etf_list/20150127/"
)
shutil.rmtree(directory)

0 comments on commit ca0f906

Please sign in to comment.