Skip to content

Commit

Permalink
MAINT: Update tests to conform to new reader/writer structure
Browse files Browse the repository at this point in the history
  • Loading branch information
StewartDouglas authored and jfkirk committed Sep 10, 2015
1 parent 8ccdae9 commit 1ef2274
Show file tree
Hide file tree
Showing 15 changed files with 281 additions and 177 deletions.
34 changes: 22 additions & 12 deletions tests/modelling/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
BcolzDailyBarReader,
USEquityPricingLoader,
)
from zipline.finance import trading
from zipline.finance.trading import TradingEnvironment
from zipline.modelling.engine import SimpleFFCEngine
from zipline.modelling.factor import TestingFactor
Expand Down Expand Up @@ -93,7 +94,9 @@ def setUp(self):
start_date=self.dates[0],
end_date=self.dates[-1],
)
self.asset_finder = AssetFinder(self.asset_info)
trading.environment = trading.TradingEnvironment()
trading.environment.write_data(equities_df=self.asset_info)
self.asset_finder = AssetFinder(trading.environment.engine)

def test_bad_dates(self):
loader = self.loader
Expand Down Expand Up @@ -222,24 +225,30 @@ def test_numeric_factor(self):

class FrameInputTestCase(TestCase):

def setUp(self):
env = TradingEnvironment.instance()
day = env.trading_day
@classmethod
def setUpClass(cls):
cls.env = trading.TradingEnvironment()
day = cls.env.trading_day

self.assets = Int64Index([1, 2, 3])
self.dates = date_range(
cls.assets = Int64Index([1, 2, 3])
cls.dates = date_range(
'2015-01-01',
'2015-01-31',
freq=day,
tz='UTC',
)

asset_info = make_simple_asset_info(
self.assets,
start_date=self.dates[0],
end_date=self.dates[-1],
cls.assets,
start_date=cls.dates[0],
end_date=cls.dates[-1],
)
self.asset_finder = AssetFinder(asset_info)
cls.env.write_data(equities_df=asset_info)

def setUp(self):
self.asset_finder = AssetFinder(FrameInputTestCase.env.engine)
self.dates = FrameInputTestCase.dates
self.assets = FrameInputTestCase.assets

@lazyval
def base_mask(self):
Expand Down Expand Up @@ -329,7 +338,7 @@ class SyntheticBcolzTestCase(TestCase):
@classmethod
def setUpClass(cls):
cls.first_asset_start = Timestamp('2015-04-01', tz='UTC')
cls.env = TradingEnvironment.instance()
cls.env = trading.TradingEnvironment()
cls.trading_day = cls.env.trading_day
cls.asset_info = make_rotating_asset_info(
num_assets=6,
Expand All @@ -345,7 +354,8 @@ def setUpClass(cls):
freq=cls.trading_day,
)

cls.finder = AssetFinder(cls.asset_info)
cls.env.write_data(equities_df=cls.asset_info)
cls.finder = AssetFinder(cls.env.engine)

cls.temp_dir = TempDirectory()
cls.temp_dir.create()
Expand Down
10 changes: 9 additions & 1 deletion tests/modelling/test_modelling_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
USEquityPricingLoader,
)
# from zipline.modelling.factor import CustomFactor
from zipline.finance import trading
from zipline.modelling.factor.technical import VWAP
from zipline.utils.test_utils import (
make_simple_asset_info,
Expand Down Expand Up @@ -84,7 +85,9 @@ def setUpClass(cls):
Timestamp('2015'),
['AAPL', 'MSFT', 'BRK_A'],
)
cls.asset_finder = AssetFinder(asset_info)
cls.env = trading.TradingEnvironment()
cls.env.write_data(equities_df=asset_info)
cls.asset_finder = AssetFinder(cls.env.engine)
cls.tempdir = tempdir = TempDirectory()
tempdir.create()
try:
Expand Down Expand Up @@ -200,6 +203,11 @@ def handle_data(context, data):
# Do the same checks in before_trading_start
before_trading_start = handle_data

# Create fresh trading environment as the algo.run()
# method will attempt to write data to disk, and could
# violate SQL constraints.
trading.environment = trading.TradingEnvironment()

algo = TradingAlgorithm(
initialize=initialize,
handle_data=handle_data,
Expand Down
1 change: 0 additions & 1 deletion tests/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,6 @@ def test_order_method_style_forwarding(self):
def test_order_instant(self):
algo = TestOrderInstantAlgorithm(sim_params=self.sim_params,
instant_fill=True)

algo.run(self.df)

def test_minute_data(self):
Expand Down
26 changes: 22 additions & 4 deletions tests/test_algorithm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ def handle_data(self, data):


class AlgorithmGeneratorTestCase(TestCase):

@classmethod
def setUpClass(cls):
cls.env = trading.TradingEnvironment()
cls.env.write_data(equities_identifiers=[8229])

def setUp(self):
setup_logger(self)

Expand All @@ -106,6 +112,8 @@ def test_lse_algorithm(self):
end=datetime(2012, 6, 30, tzinfo=pytz.utc)
)
algo = TestAlgo(self, identifiers=[8229], sim_params=sim_params)
# This call appears inconsistent with
# the signature of create_daily_trade_source
trade_source = factory.create_daily_trade_source(
[8229],
200,
Expand All @@ -127,11 +135,15 @@ def test_generator_dates(self):
Ensure the pipeline of generators are in sync, at least as far as
their current dates.
"""
# Ensure we are pointing to the TradingEnvironment for this class
trading.environment = AlgorithmGeneratorTestCase.env

sim_params = factory.create_simulation_parameters(
start=datetime(2011, 7, 30, tzinfo=pytz.utc),
end=datetime(2012, 7, 30, tzinfo=pytz.utc)
)
algo = TestAlgo(self, identifiers=[8229], sim_params=sim_params)
algo = TestAlgo(self, sim_params=sim_params,
env=AlgorithmGeneratorTestCase.env)
trade_source = factory.create_daily_trade_source(
[8229],
sim_params
Expand All @@ -158,7 +170,8 @@ def test_handle_data_on_market(self):
period_end=datetime(2012, 7, 30, tzinfo=pytz.utc),
data_frequency='minute'
)
algo = TestAlgo(self, identifiers=[8229], sim_params=sim_params)
algo = TestAlgo(self, sim_params=sim_params,
env=AlgorithmGeneratorTestCase.env)

midnight_custom_source = [Event({
'custom_field': 42.0,
Expand Down Expand Up @@ -196,11 +209,15 @@ def test_progress(self):
Ensure the pipeline of generators are in sync, at least as far as
their current dates.
"""
# Ensure we are pointing to the TradingEnvironment for this class
trading.environment = AlgorithmGeneratorTestCase.env

sim_params = factory.create_simulation_parameters(
start=datetime(2008, 1, 1, tzinfo=pytz.utc),
end=datetime(2008, 1, 5, tzinfo=pytz.utc)
)
algo = TestAlgo(self, sim_params=sim_params)
algo = TestAlgo(self, sim_params=sim_params,
env=AlgorithmGeneratorTestCase.env)
trade_source = factory.create_daily_trade_source(
[8229],
sim_params
Expand All @@ -222,6 +239,7 @@ def test_benchmark_times_match_market_close_for_minutely_data(self):
"""
sim_params = create_simulation_parameters(num_days=1,
data_frequency='minute')
algo = TestAlgo(self, sim_params=sim_params, identifiers=[8229])
algo = TestAlgo(self, sim_params=sim_params,
env=AlgorithmGeneratorTestCase.env)
algo.run(source=[], overwrite_sim_params=False)
self.assertEqual(algo.datetime, sim_params.last_close)
Loading

0 comments on commit 1ef2274

Please sign in to comment.