diff --git a/setup.cfg b/setup.cfg index a1ce13a4e..68171ee44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ verbosity=2 detailed-errors=1 with-ignore-docstrings=1 with-timer=1 -timer-top-n=15 +timer-filter=warning [metadata] description-file = README.rst diff --git a/zipline/assets/assets.py b/zipline/assets/assets.py index a386b5f65..282db0527 100644 --- a/zipline/assets/assets.py +++ b/zipline/assets/assets.py @@ -20,15 +20,16 @@ from logbook import Logger import numpy as np import pandas as pd +from pandas import isnull from pandas.tseries.tools import normalize_date from six import with_metaclass, string_types, viewkeys +from six.moves import map as imap import sqlalchemy as sa -from toolz import compose from zipline.errors import ( MultipleSymbolsFound, RootSymbolNotFound, - SidNotFound, + SidsNotFound, SymbolNotFound, MapAssetIdentifierIndexError, ) @@ -41,6 +42,7 @@ ASSET_DB_VERSION, asset_db_table_names, ) +from zipline.utils.control_flow import invert log = Logger('assets.py') @@ -63,13 +65,14 @@ }) -def _convert_asset_timestamp_fields(dict): +def _convert_asset_timestamp_fields(dict_): """ Takes in a dict of Asset init args and converts dates to pd.Timestamps """ - for key in (_asset_timestamp_fields & viewkeys(dict)): - value = pd.Timestamp(dict[key], tz='UTC') - dict[key] = None if pd.isnull(value) else value + for key in (_asset_timestamp_fields & viewkeys(dict_)): + value = pd.Timestamp(dict_[key], tz='UTC') + dict_[key] = None if isnull(value) else value + return dict_ class AssetFinder(object): @@ -96,105 +99,218 @@ def __init__(self, engine): # routing. # # The caches are read through, i.e. accessing an asset through - # retrieve_asset, _retrieve_equity etc. will populate the cache on - # first retrieval. + # retrieve_asset will populate the cache on first retrieval. self._asset_cache = {} - self._equity_cache = {} - self._future_cache = {} - self._asset_type_cache = {} # Populated on first call to `lifetimes`. 
self._asset_lifetimes = None - def asset_type_by_sid(self, sid): + def lookup_asset_types(self, sids): """ - Retrieve the asset type of a given sid. + Retrieve asset types for a list of sids. + + Parameters + ---------- + sids : list[int] + + Returns + ------- + types : dict[sid -> str or None] + Asset types for the provided sids. """ - try: - return self._asset_type_cache[sid] - except KeyError: - pass + found, missing = {}, set() + for sid in sids: + try: + found[sid] = self._asset_type_cache[sid] + except KeyError: + missing.add(sid) + + if not missing: + return found + + router_cols = self.asset_router.c + query = sa.select((router_cols.sid, router_cols.asset_type)).where( + self.asset_router.c.sid.in_(map(int, missing)) + ) + for sid, type_ in query.execute().fetchall(): + missing.remove(sid) + found[sid] = self._asset_type_cache[sid] = type_ + + for sid in missing: + found[sid] = self._asset_type_cache[sid] = None + + return found + + def lookup_single_asset_type(self, sid): + """Retrieve the asset type for a single asset.""" + return self.lookup_asset_types([sid])[sid] + + def group_by_type(self, sids): + """ + Group a list of sids by asset type. - asset_type = sa.select((self.asset_router.c.asset_type,)).where( - self.asset_router.c.sid == int(sid), - ).scalar() + Parameters + ---------- + sids : list[int] - if asset_type is not None: - self._asset_type_cache[sid] = asset_type - return asset_type + Returns + ------- + types : dict[str or None -> list[int]] + A dict mapping unique asset types to lists of sids drawn from sids. + If we fail to look up an asset, we assign it a key of None. + """ + return invert(self.lookup_asset_types(sids)) def retrieve_asset(self, sid, default_none=False): """ - Retrieve the Asset object of a given sid. + Retrieve the Asset for a given sid. 
""" - if isinstance(sid, Asset): - return sid + return self.retrieve_all((sid,), default_none=default_none)[0] - try: - asset = self._asset_cache[sid] - except KeyError: - asset_type = self.asset_type_by_sid(sid) - if asset_type == 'equity': - asset = self._retrieve_equity(sid) - elif asset_type == 'future': - asset = self._retrieve_futures_contract(sid) - else: - asset = None + def retrieve_all(self, sids, default_none=False): + """ + Retrieve all assets in `sids`. - # Cache the asset if it has been retrieved - if asset is not None: - self._asset_cache[sid] = asset + Parameters + ---------- + sids : interable of int + Assets to retrieve. + default_none : bool + If True, return None for failed lookups. + If False, raise `SidsNotFound`. - if (asset is not None) or default_none: - return asset - raise SidNotFound(sid=sid) + Returns + ------- + assets : list[int or None] + A list of the same length as `sids` containing Assets (or Nones) + corresponding to the requested sids. - def retrieve_all(self, sids, default_none=False): - return [self.retrieve_asset(sid, default_none) for sid in sids] + Raises + ------ + SidsNotFound + When a requested sid is not found and default_none=False. + """ + hits, missing, failures = {}, set(), [] + for sid in sids: + try: + asset = self._asset_cache[sid] + if not default_none and asset is None: + # Bail early if we've already cached that we don't know + # about an asset. + raise SidsNotFound(sids=[sid]) + hits[sid] = asset + except KeyError: + missing.add(sid) + + # All requests were cache hits. Return requested sids in order. + if not missing: + return [hits[sid] for sid in sids] + + update_hits = hits.update + + # Look up cache misses by type. 
+ type_to_assets = self.group_by_type(missing) + + # Handle failures + failures = {failure: None for failure in type_to_assets.pop(None, ())} + update_hits(failures) + self._asset_cache.update(failures) + + if failures and not default_none: + raise SidsNotFound(sids=list(failures)) + + # We don't update the asset cache here because it should already be + # updated by `self._retrieve_equities`. + update_hits(self._retrieve_equities(type_to_assets.pop('equity', ()))) + update_hits( + self._retrieve_futures_contracts(type_to_assets.pop('future', ())) + ) - def _retrieve_equity(self, sid): + # We shouldn't know about any other asset types. + if type_to_assets: + raise AssertionError( + "Found asset types: %s" % list(type_to_assets.keys()) + ) + + return [hits[sid] for sid in sids] + + def _retrieve_equities(self, sids): """ Retrieve the Equity object of a given sid. """ - return self._retrieve_asset( - sid, self._equity_cache, self.equities, Equity, - ) + return self._retrieve_assets(sids, self.equities, Equity) - def _retrieve_futures_contract(self, sid): + def _retrieve_equity(self, sid): + return self._retrieve_equities((sid,))[sid] + + def _retrieve_futures_contracts(self, sids): """ Retrieve the Future object of a given sid. 
""" - return self._retrieve_asset( - sid, self._future_cache, self.futures_contracts, Future, - ) + return self._retrieve_assets(sids, self.futures_contracts, Future) + + def _retrieve_futures_contract(self, sid): + return self._retrieve_futures_contracts((sid,))[sid] @staticmethod - def _select_asset_by_sid(asset_tbl, sid): - return sa.select([asset_tbl]).where(asset_tbl.c.sid == int(sid)) + def _select_assets_by_sid(asset_tbl, sids): + return sa.select([asset_tbl]).where( + asset_tbl.c.sid.in_(map(int, sids)) + ) @staticmethod def _select_asset_by_symbol(asset_tbl, symbol): return sa.select([asset_tbl]).where(asset_tbl.c.symbol == symbol) - def _retrieve_asset(self, sid, cache, asset_tbl, asset_type): - try: - return cache[sid] - except KeyError: - pass + def _retrieve_assets(self, sids, asset_tbl, asset_type): + """ + Internal function for loading assets from a table. - data = self._select_asset_by_sid(asset_tbl, sid).execute().fetchone() - # Convert 'data' from a RowProxy object to a dict, to allow assignment - data = dict(data.items()) - if data: - _convert_asset_timestamp_fields(data) + This function does not do any caching. It is assumed that this will be + called at most once with any given sid. - asset = asset_type(**data) - else: - asset = None + Parameters + --------- + sids : iterable of int + Asset ids to look up. + asset_tbl : sqlalchemy.Table + Table from which to query assets. + asset_type : type + Type of asset to be constructed. - cache[sid] = asset - return asset + Returns + ------- + assets : dict[int -> Asset] + Dict mapping requested sids to the retrieved assets. + """ + # Fastpath for empty request. + if not sids: + return {} + + cache = self._asset_cache + + hits = {} + # Load misses from the db. 
+ query = self._select_assets_by_sid(asset_tbl, sids) + for row in imap(dict, query.execute().fetchall()): + asset = asset_type(**_convert_asset_timestamp_fields(row)) + sid = asset.sid + hits[sid] = cache[sid] = asset + + # If we get here, it means something in our code thought that a + # particular sid was an equity/future and called this function with a + # concrete type, but we couldn't actually resolve the asset. This is + # an error in our code, not a user-input error. + misses = tuple(set(sids) - viewkeys(hits)) + if misses: + raise AssertionError( + "Couldn't resolve sids {sids} as instances of {type}.".format( + sids=misses, + type=asset_type, + ) + ) + return hits def _get_fuzzy_candidates(self, fuzzy_symbol): candidates = sa.select( @@ -272,10 +388,9 @@ def _get_best_candidate(self, candidates): return self._retrieve_equity(candidates[0]['sid']) def _get_equities_from_candidates(self, candidates): - return list(map( - compose(self._retrieve_equity, itemgetter('sid')), - candidates, - )) + sids = map(itemgetter('sid'), candidates) + results = self.retrieve_equities(sids) + return [results[sid] for sid in sids] def lookup_symbol(self, symbol, as_of_date, fuzzy=False): """ @@ -286,7 +401,6 @@ def lookup_symbol(self, symbol, as_of_date, fuzzy=False): If no Equity was active at as_of_date raises SymbolNotFound. """ - company_symbol, share_class_symbol, fuzzy_symbol = \ split_delimited_symbol(symbol) if as_of_date: @@ -376,22 +490,7 @@ def lookup_future_symbol(self, symbol): # If no data found, raise an exception if not data: raise SymbolNotFound(symbol=symbol) - - # If we find a contract, check whether it's been cached - try: - return self._future_cache[data['sid']] - except KeyError: - pass - - # Build the Future object from its parameters - data = dict(data.items()) - _convert_asset_timestamp_fields(data) - future = Future(**data) - - # Cache the Future object. 
- self._future_cache[data['sid']] = future - - return future + return self.retrieve_asset(data['sid']) def lookup_future_chain(self, root_symbol, as_of_date): """ Return the futures chain for a given root symbol. @@ -487,7 +586,8 @@ def lookup_future_chain(self, root_symbol, as_of_date): if count == 0: raise RootSymbolNotFound(root_symbol=root_symbol) - return list(map(self._retrieve_futures_contract, sids)) + contracts = self._retrieve_futures_contracts(sids) + return [contracts[sid] for sid in sids] @property def sids(self): @@ -513,7 +613,7 @@ def _lookup_generic_scalar(self, elif isinstance(asset_convertible, Integral): try: result = self.retrieve_asset(int(asset_convertible)) - except SidNotFound: + except SidsNotFound: missing.append(asset_convertible) return None matches.append(result) @@ -563,7 +663,7 @@ def lookup_generic(self, return matches[0], missing except IndexError: if hasattr(asset_convertible_or_iterable, '__int__'): - raise SidNotFound(sid=asset_convertible_or_iterable) + raise SidsNotFound(sids=[asset_convertible_or_iterable]) else: raise SymbolNotFound(symbol=asset_convertible_or_iterable) diff --git a/zipline/errors.py b/zipline/errors.py index aab1f4143..91e680754 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from zipline.utils.memoize import lazyval + class ZiplineError(Exception): msg = None - def __init__(self, *args, **kwargs): - self.args = args + def __init__(self, **kwargs): self.kwargs = kwargs self.message = str(self) @@ -231,13 +232,17 @@ class RootSymbolNotFound(ZiplineError): """.strip() -class SidNotFound(ZiplineError): +class SidsNotFound(ZiplineError): """ - Raised when a retrieve_asset() call contains a non-existent sid. + Raised when a retrieve_asset() or retrieve_all() call contains a + non-existent sid. """ - msg = """ -Asset with sid '{sid}' was not found. 
-""".strip() + @lazyval + def msg(self): + sids = self.kwargs['sids'] + if len(sids) == 1: + return "No asset found for sid: {sids[0]}." + return "No assets found for sids: {sids}." class ConsumeAssetMetaDataError(ZiplineError): diff --git a/zipline/utils/control_flow.py b/zipline/utils/control_flow.py index f45f3c0e8..891c5ff3f 100644 --- a/zipline/utils/control_flow.py +++ b/zipline/utils/control_flow.py @@ -1,6 +1,7 @@ """ Control flow utilities. """ +from six import iteritems from warnings import ( catch_warnings, filterwarnings, @@ -54,3 +55,19 @@ def ignore_nanwarnings(): {'category': RuntimeWarning, 'module': 'numpy.lib.nanfunctions'}, ) ) + + +def invert(d): + """ + Invert a dictionary into a dictionary of lists. + + >>> invert({'a': 1, 'b': 2, 'c': 1}) + {1: ['a', 'c'], 2: ['b']} + """ + out = {} + for k, v in iteritems(d): + try: + out[v].append(k) + except KeyError: + out[v] = [k] + return out