Skip to content

Commit

Permalink
TST: Use testing market data with run_algorithm
Browse files Browse the repository at this point in the history
so env doesn't need to download it
  • Loading branch information
richafrank committed May 18, 2017
1 parent 3ca5a15 commit 8734224
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 26 deletions.
27 changes: 17 additions & 10 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
)
from zipline.testing import (
FakeDataPortal,
copy_market_data,
create_daily_df_for_asset,
create_data_portal,
create_data_portal_from_trade_history,
Expand All @@ -99,6 +100,7 @@
tmp_trading_env,
to_utc,
trades_by_sid_to_dfs,
tmp_dir,
)
from zipline.testing import RecordBatchBlotter
from zipline.testing.fixtures import (
Expand Down Expand Up @@ -4760,13 +4762,18 @@ def check_panels():
check_panels()
price_record.loc[:] = np.nan

run_algorithm(
start=start_dt,
end=end_dt,
capital_base=1,
initialize=initialize,
handle_data=handle_data,
data_frequency=data_frequency,
data=panel
)
check_panels()
with tmp_dir() as tmpdir:
root = tmpdir.getpath('example_data/root')
copy_market_data(self.MARKET_DATA_DIR, root)

run_algorithm(
start=start_dt,
end=end_dt,
capital_base=1,
initialize=initialize,
handle_data=handle_data,
data_frequency=data_frequency,
data=panel,
environ={'ZIPLINE_ROOT': root},
)
check_panels()
8 changes: 6 additions & 2 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@

from zipline import examples
from zipline.data.bundles import register, unregister
from zipline.testing import test_resource_path
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
from zipline.testing import test_resource_path, copy_market_data
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \
WithTradingEnvironment
from zipline.testing.predicates import assert_equal
from zipline.utils.cache import dataframe_cache

Expand Down Expand Up @@ -53,6 +54,9 @@ def init_class_fixtures(cls):
serialization='pickle',
)

copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR,
cls.tmpdir.getpath('example_data/root'))

@parameterized.expand(examples.EXAMPLE_MODULES)
def test_example(self, example_name):
actual_perf = examples.run_example(
Expand Down
29 changes: 18 additions & 11 deletions zipline/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def last_modified_time(path):
return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC')


def get_data_filepath(name):
def get_data_filepath(name, environ=None):
"""
Returns a handle to data file.
Creates containing directory, if needed.
"""
dr = data_root()
dr = data_root(environ)

if not os.path.exists(dr):
os.makedirs(dr)
Expand Down Expand Up @@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date):
return (first <= first_date) and (last >= last_date)


def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC',
environ=None):
"""
Load benchmark returns and treasury yield curves for the given calendar and
benchmark symbol.
Expand Down Expand Up @@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
# We need the trading_day to figure out the close prior to the first
# date so that we can compute returns for the first date.
trading_day,
environ,
)
tc = ensure_treasury_data(
bm_symbol,
first_date,
last_date,
now,
environ,
)
benchmark_returns = br[br.index.slice_indexer(first_date, last_date)]
treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)]
return benchmark_returns, treasury_curves


def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day,
environ=None):
"""
Ensure we have benchmark data for `symbol` from `first_date` to `last_date`
Expand Down Expand Up @@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
path.
"""
filename = get_benchmark_filename(symbol)
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark',
environ)
if data is not None:
return data

Expand All @@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
first_date - trading_day,
last_date,
)
data.to_csv(get_data_filepath(filename))
data.to_csv(get_data_filepath(filename, environ))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache the new benchmark returns')
raise
Expand All @@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
return data


def ensure_treasury_data(symbol, first_date, last_date, now):
def ensure_treasury_data(symbol, first_date, last_date, now, environ=None):
"""
Ensure we have treasury data from treasury module associated with
`symbol`.
Expand Down Expand Up @@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
)
first_date = max(first_date, loader_module.earliest_possible_date())

data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
data = _load_cached_data(filename, first_date, last_date, now, 'treasury',
environ)
if data is not None:
return data

Expand All @@ -269,22 +275,23 @@ def ensure_treasury_data(symbol, first_date, last_date, now):

try:
data = loader_module.get_treasury_data(first_date, last_date)
data.to_csv(get_data_filepath(filename))
data.to_csv(get_data_filepath(filename, environ))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache treasury data')
if not has_data_for_dates(data, first_date, last_date):
logger.warn("Still don't have expected data after redownload!")
return data


def _load_cached_data(filename, first_date, last_date, now, resource_name):
def _load_cached_data(filename, first_date, last_date, now, resource_name,
environ=None):
if resource_name == 'benchmark':
from_csv = pd.Series.from_csv
else:
from_csv = pd.DataFrame.from_csv

# Path for the cache.
path = get_data_filepath(filename)
path = get_data_filepath(filename, environ)

# If the path does not exist, it means the first download has not happened
# yet, so don't try to read from 'path'.
Expand Down
4 changes: 3 additions & 1 deletion zipline/finance/trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import logbook
import pandas as pd
Expand Down Expand Up @@ -86,11 +87,12 @@ def __init__(
trading_calendar=None,
asset_db_path=':memory:',
future_chain_predicates=CHAIN_PREDICATES,
environ=None,
):

self.bm_symbol = bm_symbol
if not load:
load = load_market_data
load = partial(load_market_data, environ=environ)

if not trading_calendar:
trading_calendar = get_calendar("NYSE")
Expand Down
1 change: 1 addition & 0 deletions zipline/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
check_allclose,
check_arrays,
chrange,
copy_market_data,
create_daily_df_for_asset,
create_data_portal,
create_data_portal_from_trade_history,
Expand Down
15 changes: 15 additions & 0 deletions zipline/testing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from zipline.assets import AssetFinder, AssetDBWriter
from zipline.assets.synthetic import make_simple_equity_info
from zipline.data.data_portal import DataPortal
from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING
from zipline.data.minute_bars import (
BcolzMinuteBarReader,
BcolzMinuteBarWriter,
Expand All @@ -52,6 +53,7 @@
from zipline.utils.input_validation import expect_dimensions
from zipline.utils.numpy_utils import as_column, isnat
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
from zipline.utils.paths import ensure_directory
from zipline.utils.sentinel import sentinel

import numpy as np
Expand Down Expand Up @@ -1490,6 +1492,19 @@ def patched_read_csv(filepath_or_buffer, *args, **kwargs):
yield


def copy_market_data(src_market_data_dir, dest_root_dir):
    """Copy the cached benchmark and treasury market-data files for the
    ``'^GSPC'`` symbol from ``src_market_data_dir`` into the ``data``
    subdirectory of ``dest_root_dir``, creating that directory if needed.

    Parameters
    ----------
    src_market_data_dir : str
        Directory holding the pre-downloaded market data files.
    dest_root_dir : str
        Root directory to copy into; files land in ``<dest_root_dir>/data``.
    """
    dest_data_dir = os.path.join(dest_root_dir, 'data')
    ensure_directory(dest_data_dir)

    benchmark_symbol = '^GSPC'
    # Benchmark returns file plus the matching treasury-curve file.
    for name in (get_benchmark_filename(benchmark_symbol),
                 INDEX_MAPPING[benchmark_symbol][1]):
        shutil.copyfile(
            os.path.join(src_market_data_dir, name),
            os.path.join(dest_data_dir, name),
        )


@curry
def ensure_doctest(f, name=None):
"""Ensure that an object gets doctested. This is useful for instances
Expand Down
4 changes: 2 additions & 2 deletions zipline/utils/run_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _run(handle_data,
"invalid url %r, must begin with 'sqlite:///'" %
str(bundle_data.asset_finder.engine.url),
)
env = TradingEnvironment(asset_db_path=connstr)
env = TradingEnvironment(asset_db_path=connstr, environ=environ)
first_trading_day =\
bundle_data.equity_minute_bar_reader.first_trading_day
data = DataPortal(
Expand All @@ -152,7 +152,7 @@ def choose_loader(column):
"No PipelineLoader registered for column %s." % column
)
else:
env = None
env = TradingEnvironment(environ=environ)
choose_loader = None

perf = TradingAlgorithm(
Expand Down

0 comments on commit 8734224

Please sign in to comment.