From 7e99094cb1cad829186bb32a0e072faad16515e9 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 2 Aug 2016 18:33:28 -0400 Subject: [PATCH 1/2] ENH: add __len__ and fix iteration for negative step --- zipline/utils/range.py | 58 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/zipline/utils/range.py b/zipline/utils/range.py index 138b09220..9e966bab5 100644 --- a/zipline/utils/range.py +++ b/zipline/utils/range.py @@ -32,11 +32,31 @@ def __init__(self, stop, *args): except IndexError: self.step = 1 + if self.step == 0: + raise ValueError('range step must not be zero') + def __iter__(self): + """ + Examples + -------- + >>> list(range(1)) + [0] + >>> list(range(5)) + [0, 1, 2, 3, 4] + >>> list(range(1, 5)) + [1, 2, 3, 4] + >>> list(range(0, 5, 2)) + [0, 2, 4] + >>> list(range(5, 0, -1)) + [5, 4, 3, 2, 1] + >>> list(range(5, 0, 1)) + [] + """ n = self.start stop = self.stop step = self.step - while n < stop: + cmp_ = op.lt if step > 0 else op.gt + while cmp_(n, stop): yield n n += step @@ -46,6 +66,8 @@ def __iter__(self): ) def __contains__(self, other, _ops=_ops): + # Algorithm taken from CPython + # Objects/rangeobject.c:range_contains_long start = self.start step = self.step cmp_start, cmp_stop = _ops[step > 0] @@ -57,6 +79,40 @@ def __contains__(self, other, _ops=_ops): del _ops + def __len__(self): + """ + Examples + -------- + >>> len(range(1)) + 1 + >>> len(range(5)) + 5 + >>> len(range(1, 5)) + 4 + >>> len(range(0, 5, 2)) + 3 + >>> len(range(5, 0, -1)) + 5 + >>> len(range(5, 0, 1)) + 0 + """ + # Algorithm taken from CPython + # rangeobject.c:compute_range_length + step = self.step + + if step > 0: + low = self.start + high = self.stop + else: + low = self.stop + high = self.start + step = -step + + if low >= high: + return 0 + + return (high - low - 1) // step + 1 + def __repr__(self): return '%s(%s, %s%s)' % ( type(self).__name__, From 1f10fff1c4a2d94863c8e9ffac54f5001785c9ad Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 2 Aug 2016 18:34:00 -0400 Subject: [PATCH 2/2] BUG: support querying more than 999 assets at a time --- tests/test_assets.py | 26 +++++++++++++++++++++++- zipline/assets/assets.py | 43 +++++++++++++++++++++++++++------------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/tests/test_assets.py b/tests/test_assets.py index b81fac570..13f19e75a 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -30,7 +30,7 @@ from numpy import full, int32, int64 import pandas as pd from pandas.util.testing import assert_frame_equal -from six import PY2 +from six import PY2, viewkeys import sqlalchemy as sa from zipline.assets import ( @@ -57,6 +57,7 @@ check_version_info, write_version_info, _futures_defaults, + SQLITE_MAX_VARIABLE_NUMBER, ) from zipline.assets.asset_db_schema import ASSET_DB_VERSION from zipline.assets.asset_db_migrations import ( @@ -83,6 +84,7 @@ ZiplineTestCase, WithTradingCalendar, ) +from zipline.utils.range import range @contextmanager @@ -407,6 +409,28 @@ def init_instance_fixtures(self): self._asset_writer = AssetDBWriter(conn) self.asset_finder = self.asset_finder_type(conn) + def test_blocked_lookup_symbol_query(self): + # we will try to query for more variables than sqlite supports + # to make sure we are properly chunking on the client side + as_of = pd.Timestamp('2013-01-01', tz='UTC') + # we need more sids than we can query from sqlite + nsids = SQLITE_MAX_VARIABLE_NUMBER + 10 + sids = range(nsids) + frame = pd.DataFrame.from_records( + [ + { + 'sid': sid, + 'symbol': 'TEST.%d' % sid, + 'start_date': as_of.value, + 'end_date': as_of.value, + } + for sid in sids + ] + ) + self.write_assets(equities=frame) + assets = self.asset_finder.retrieve_equities(sids) + assert_equal(viewkeys(assets), set(sids)) + def test_lookup_symbol_delimited(self): as_of = pd.Timestamp('2013-01-01', tz='UTC') frame = pd.DataFrame.from_records( diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index 4796ac8d4..b770014aa 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -23,7 +23,16 @@ from pandas import isnull from six import with_metaclass, string_types, viewkeys, iteritems import sqlalchemy as sa -from toolz import merge, compose, valmap, sliding_window, concatv, curry +from toolz import ( + compose, + concat, + concatv, + curry, + merge, + partition_all, + sliding_window, + valmap, +) from toolz.curried import operator as op from zipline.errors import ( @@ -43,6 +52,7 @@ split_delimited_symbol, asset_db_table_names, symbol_columns, + SQLITE_MAX_VARIABLE_NUMBER, ) from .asset_db_schema import ( ASSET_DB_VERSION @@ -432,21 +442,26 @@ def _select_asset_by_symbol(asset_tbl, symbol): def _lookup_most_recent_symbols(self, sids): symbol_cols = self.equity_symbol_mappings.c - symbols = { row.sid: {c: row[c] for c in symbol_columns} - for row in self.engine.execute( - sa.select( - (symbol_cols.sid,) + - tuple(map(op.getitem(symbol_cols), symbol_columns)), - ).where( - symbol_cols.sid.in_(map(int, sids)), - ).order_by( - symbol_cols.end_date.desc(), - ).group_by( - symbol_cols.sid, - ) - ).fetchall() + for row in concat( + self.engine.execute( + sa.select( + (symbol_cols.sid,) + + tuple(map(op.getitem(symbol_cols), symbol_columns)), + ).where( + symbol_cols.sid.in_(map(int, sid_group)), + ).order_by( + symbol_cols.end_date.desc(), + ).group_by( + symbol_cols.sid, + ) + ).fetchall() + for sid_group in partition_all( + SQLITE_MAX_VARIABLE_NUMBER, + sids + ), + ) } if len(symbols) != len(sids):