Skip to content

Commit

Permalink
TST: Use fixture's trading env for FakeDataPortal or TradingAlgo
Browse files Browse the repository at this point in the history
to avoid a new trading env needing to download data unnecessarily
  • Loading branch information
richafrank committed May 18, 2017
1 parent ca26208 commit 955862b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
6 changes: 3 additions & 3 deletions tests/pipeline/test_pipeline_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def handle_data(context, data):
)

algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
# Yes, I really do want to use the start and end dates I passed to
# TradingAlgorithm.
overwrite_sim_params=False,
Expand Down Expand Up @@ -606,7 +606,7 @@ def before_trading_start(context, data):
)

algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
overwrite_sim_params=False,
)

Expand Down Expand Up @@ -654,7 +654,7 @@ def before_trading_start(context, data):
)

algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
overwrite_sim_params=False,
)

Expand Down
22 changes: 15 additions & 7 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def test_zipline_api_resolves_dynamically(self):
initialize=lambda context: None,
handle_data=lambda context, data: None,
sim_params=self.sim_params,
env=self.env,
)

# Verify that api methods get resolved dynamically by patching them out
Expand Down Expand Up @@ -892,7 +893,8 @@ def before_trading_start(context, data):
def test_run_twice(self):
algo1 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
sids=[0, 1],
env=self.env,
)

res1 = algo1.run(self.data_portal)
Expand All @@ -901,7 +903,8 @@ def test_run_twice(self):
# use the newly instantiated environment.
algo2 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
sids=[0, 1],
env=self.env,
)

res2 = algo2.run(self.data_portal)
Expand Down Expand Up @@ -1569,15 +1572,16 @@ def make_equity_daily_bar_data(cls):

def test_noop(self):
algo = TradingAlgorithm(initialize=initialize_noop,
handle_data=handle_data_noop)
handle_data=handle_data_noop,
env=self.env)
algo.run(self.data_portal)

def test_noop_string(self):
algo = TradingAlgorithm(script=noop_algo)
algo = TradingAlgorithm(script=noop_algo, env=self.env)
algo.run(self.data_portal)

def test_no_handle_data(self):
algo = TradingAlgorithm(script=no_handle_data)
algo = TradingAlgorithm(script=no_handle_data, env=self.env)
algo.run(self.data_portal)

def test_api_calls(self):
Expand All @@ -1593,7 +1597,8 @@ def test_api_calls_string(self):
def test_api_get_environment(self):
platform = 'zipline'
algo = TradingAlgorithm(script=api_get_environment_algo,
platform=platform)
platform=platform,
env=self.env)
algo.run(self.data_portal)
self.assertEqual(algo.environment, platform)

Expand Down Expand Up @@ -1779,6 +1784,7 @@ def test_algo_record_allow_mock(self):
test_algo = TradingAlgorithm(
script=record_variables,
sim_params=self.sim_params,
env=self.env,
)
set_algo_instance(test_algo)

Expand Down Expand Up @@ -3785,6 +3791,7 @@ def analyze(context, perf):
initialize=initialize,
handle_data=handle_data,
analyze=analyze,
env=self.env,
)

with empty_trading_env() as env:
Expand Down Expand Up @@ -4642,7 +4649,7 @@ def handle_data(context, data):
self.assertEqual(expected_message, w.message)


class AlgoInputValidationTestCase(ZiplineTestCase):
class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):

def test_reject_passing_both_api_methods_and_script(self):
script = dedent(
Expand All @@ -4668,6 +4675,7 @@ def analyze(context, results):
with self.assertRaises(ValueError):
TradingAlgorithm(
script=script,
env=self.env,
**{method: lambda *args, **kwargs: None}
)

Expand Down
19 changes: 11 additions & 8 deletions tests/test_tradesimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from nose_parameterized import parameterized
from six.moves import range
from unittest import TestCase
from zipline import TradingAlgorithm
from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR

Expand All @@ -28,8 +27,12 @@
from zipline.gens.tradesimulation import AlgorithmSimulator
from zipline.sources.benchmark_source import BenchmarkSource
from zipline.test_algorithms import NoopAlgorithm
from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \
WithDataPortal
from zipline.testing.fixtures import (
WithDataPortal,
WithSimParams,
WithTradingEnvironment,
ZiplineTestCase,
)
from zipline.utils import factory
from zipline.testing.core import FakeDataPortal
from zipline.utils.calendars.trading_calendar import days_at_time
Expand All @@ -50,7 +53,7 @@ def handle_data(self, data):
FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute


class TestTradeSimulation(TestCase):
class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase):

def fake_minutely_benchmark(self, dt):
return 0.01
Expand All @@ -61,8 +64,8 @@ def test_minutely_emissions_generate_performance_stats_for_last_day(self):
emission_rate='minute')
with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = NoopAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
algo = NoopAlgorithm(sim_params=params, env=self.env)
algo.run(FakeDataPortal(self.env))
self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1)

@parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate),
Expand All @@ -82,8 +85,8 @@ def fake_benchmark(self, dt):

with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = BeforeTradingAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
algo.run(FakeDataPortal(self.env))

self.assertEqual(
len(algo.perf_tracker.sim_params.sessions),
Expand Down
5 changes: 1 addition & 4 deletions zipline/testing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar,


class FakeDataPortal(DataPortal):
def __init__(self, env=None, trading_calendar=None,
def __init__(self, env, trading_calendar=None,
first_trading_day=None):
if env is None:
env = TradingEnvironment()

if trading_calendar is None:
trading_calendar = get_calendar("NYSE")

Expand Down

0 comments on commit 955862b

Please sign in to comment.