Skip to content

Commit

Permalink
Merge pull request #1368 from quantopian/lots-of-symbols
Browse files Browse the repository at this point in the history
Lots of symbols
  • Loading branch information
llllllllll authored Aug 2, 2016
2 parents 39e7476 + 1f10fff commit 164bd06
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 16 deletions.
26 changes: 25 additions & 1 deletion tests/test_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from numpy import full, int32, int64
import pandas as pd
from pandas.util.testing import assert_frame_equal
from six import PY2
from six import PY2, viewkeys
import sqlalchemy as sa

from zipline.assets import (
Expand All @@ -57,6 +57,7 @@
check_version_info,
write_version_info,
_futures_defaults,
SQLITE_MAX_VARIABLE_NUMBER,
)
from zipline.assets.asset_db_schema import ASSET_DB_VERSION
from zipline.assets.asset_db_migrations import (
Expand All @@ -83,6 +84,7 @@
ZiplineTestCase,
WithTradingCalendar,
)
from zipline.utils.range import range


@contextmanager
Expand Down Expand Up @@ -407,6 +409,28 @@ def init_instance_fixtures(self):
self._asset_writer = AssetDBWriter(conn)
self.asset_finder = self.asset_finder_type(conn)

def test_blocked_lookup_symbol_query(self):
# we will try to query for more variables than sqlite supports
# to make sure we are properly chunking on the client side
as_of = pd.Timestamp('2013-01-01', tz='UTC')
# we need more sids than we can query from sqlite
nsids = SQLITE_MAX_VARIABLE_NUMBER + 10
sids = range(nsids)
frame = pd.DataFrame.from_records(
[
{
'sid': sid,
'symbol': 'TEST.%d' % sid,
'start_date': as_of.value,
'end_date': as_of.value,
}
for sid in sids
]
)
self.write_assets(equities=frame)
assets = self.asset_finder.retrieve_equities(sids)
assert_equal(viewkeys(assets), set(sids))

def test_lookup_symbol_delimited(self):
as_of = pd.Timestamp('2013-01-01', tz='UTC')
frame = pd.DataFrame.from_records(
Expand Down
43 changes: 29 additions & 14 deletions zipline/assets/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
from pandas import isnull
from six import with_metaclass, string_types, viewkeys, iteritems
import sqlalchemy as sa
from toolz import merge, compose, valmap, sliding_window, concatv, curry
from toolz import (
compose,
concat,
concatv,
curry,
merge,
partition_all,
sliding_window,
valmap,
)
from toolz.curried import operator as op

from zipline.errors import (
Expand All @@ -43,6 +52,7 @@
split_delimited_symbol,
asset_db_table_names,
symbol_columns,
SQLITE_MAX_VARIABLE_NUMBER,
)
from .asset_db_schema import (
ASSET_DB_VERSION
Expand Down Expand Up @@ -432,21 +442,26 @@ def _select_asset_by_symbol(asset_tbl, symbol):

def _lookup_most_recent_symbols(self, sids):
symbol_cols = self.equity_symbol_mappings.c

symbols = {
row.sid: {c: row[c] for c in symbol_columns}
for row in self.engine.execute(
sa.select(
(symbol_cols.sid,) +
tuple(map(op.getitem(symbol_cols), symbol_columns)),
).where(
symbol_cols.sid.in_(map(int, sids)),
).order_by(
symbol_cols.end_date.desc(),
).group_by(
symbol_cols.sid,
)
).fetchall()
for row in concat(
self.engine.execute(
sa.select(
(symbol_cols.sid,) +
tuple(map(op.getitem(symbol_cols), symbol_columns)),
).where(
symbol_cols.sid.in_(map(int, sid_group)),
).order_by(
symbol_cols.end_date.desc(),
).group_by(
symbol_cols.sid,
)
).fetchall()
for sid_group in partition_all(
SQLITE_MAX_VARIABLE_NUMBER,
sids
),
)
}

if len(symbols) != len(sids):
Expand Down
58 changes: 57 additions & 1 deletion zipline/utils/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,31 @@ def __init__(self, stop, *args):
except IndexError:
self.step = 1

if self.step == 0:
raise ValueError('range step must not be zero')

def __iter__(self):
"""
Examples
--------
>>> list(range(1))
[0]
>>> list(range(5))
[0, 1, 2, 3, 4]
>>> list(range(1, 5))
[1, 2, 3, 4]
>>> list(range(0, 5, 2))
[0, 2, 4]
>>> list(range(5, 0, -1))
[5, 4, 3, 2, 1]
>>> list(range(5, 0, 1))
[]
"""
n = self.start
stop = self.stop
step = self.step
while n < stop:
cmp_ = op.lt if step > 0 else op.gt
while cmp_(n, stop):
yield n
n += step

Expand All @@ -46,6 +66,8 @@ def __iter__(self):
)

def __contains__(self, other, _ops=_ops):
# Algorithm taken from CPython
# Objects/rangeobject.c:range_contains_long
start = self.start
step = self.step
cmp_start, cmp_stop = _ops[step > 0]
Expand All @@ -57,6 +79,40 @@ def __contains__(self, other, _ops=_ops):

del _ops

def __len__(self):
"""
Examples
--------
>>> len(range(1))
1
>>> len(range(5))
5
>>> len(range(1, 5))
4
>>> len(range(0, 5, 2))
3
>>> len(range(5, 0, -1))
5
>>> len(range(5, 0, 1))
0
"""
# Algorithm taken from CPython
# rangeobject.c:compute_range_length
step = self.step

if step > 0:
low = self.start
high = self.stop
else:
low = self.stop
high = self.start
step = -step

if low >= high:
return 0

return (high - low - 1) // step + 1

def __repr__(self):
return '%s(%s, %s%s)' % (
type(self).__name__,
Expand Down

0 comments on commit 164bd06

Please sign in to comment.