diff --git a/tests/calendars/test_trading_calendar.py b/tests/calendars/test_trading_calendar.py index 0357a645f2..94af91da42 100644 --- a/tests/calendars/test_trading_calendar.py +++ b/tests/calendars/test_trading_calendar.py @@ -441,6 +441,22 @@ def test_minute_to_session_label(self): direction="none" ) + @parameterized.expand([ + (1, 0), + (2, 0), + (2, 1), + ]) + def test_minute_index_to_session_labels(self, interval, offset): + minutes = self.calendar.minutes_for_sessions_in_range('2011-01-04', + '2011-04-04') + minutes = minutes[range(offset, len(minutes), interval)] + + np.testing.assert_array_equal( + np.array(minutes.map(self.calendar.minute_to_session_label), + dtype='datetime64[ns]'), + self.calendar.minute_index_to_session_labels(minutes) + ) + def test_next_prev_session(self): session_labels = self.answers.index[1:-2] max_idx = len(session_labels) - 1 diff --git a/zipline/utils/calendars/_calendar_helpers.pyx b/zipline/utils/calendars/_calendar_helpers.pyx index 0f7e0520dd..55316a477f 100644 --- a/zipline/utils/calendars/_calendar_helpers.pyx +++ b/zipline/utils/calendars/_calendar_helpers.pyx @@ -1,13 +1,13 @@ -from numpy cimport ndarray, long_t -from numpy import searchsorted +from numpy cimport ndarray, int64_t +from numpy import empty, searchsorted, int64 cimport cython @cython.boundscheck(False) @cython.wraparound(False) -def next_divider_idx(ndarray[long_t, ndim=1] dividers, long_t minute_val): +cpdef int next_divider_idx(ndarray[int64_t, ndim=1] dividers, int64_t minute_val): cdef int divider_idx - cdef long target + cdef int64_t target divider_idx = searchsorted(dividers, minute_val, side="right") target = dividers[divider_idx] @@ -20,8 +20,8 @@ def next_divider_idx(ndarray[long_t, ndim=1] dividers, long_t minute_val): @cython.boundscheck(False) @cython.wraparound(False) -def previous_divider_idx(ndarray[long_t, ndim=1] dividers, - long_t minute_val): +def previous_divider_idx(ndarray[int64_t, ndim=1] dividers, + int64_t minute_val): cdef int divider_idx divider_idx = searchsorted(dividers, minute_val) @@ -31,9 +31,9 @@ def previous_divider_idx(ndarray[long_t, ndim=1] dividers, return divider_idx - 1 -def is_open(ndarray[long_t, ndim=1] opens, - ndarray[long_t, ndim=1] closes, - long_t minute_val): +def is_open(ndarray[int64_t, ndim=1] opens, + ndarray[int64_t, ndim=1] closes, + int64_t minute_val): cdef open_idx, close_idx open_idx = searchsorted(opens, minute_val) @@ -51,3 +51,24 @@ def is_open(ndarray[long_t, ndim=1] opens, # this can happen if we're outside the schedule's range (like # after the last close) return False + +@cython.boundscheck(False) +@cython.wraparound(False) +def minutes_to_session_labels(ndarray[int64_t, ndim=1] minutes, + minute_to_session_label, + ndarray[int64_t, ndim=1] closes): + cdef int current_idx, next_idx, close_idx + current_idx = next_idx = close_idx = 0 + + cdef ndarray[int64_t, ndim=1] results = empty(len(minutes), dtype=int64) + + while current_idx < len(minutes): + close_idx += searchsorted(closes[close_idx:], + minutes[current_idx], side="right") + next_idx += next_divider_idx(minutes[next_idx:], closes[close_idx]) + results[current_idx:next_idx] = minute_to_session_label( + minutes[current_idx] + ) + current_idx = next_idx + + return results diff --git a/zipline/utils/calendars/trading_calendar.py b/zipline/utils/calendars/trading_calendar.py index 37c7f35307..77966c6a5d 100644 --- a/zipline/utils/calendars/trading_calendar.py +++ b/zipline/utils/calendars/trading_calendar.py @@ -29,7 +29,13 @@ from zipline.utils.calendars._calendar_helpers import ( next_divider_idx, previous_divider_idx, - is_open + is_open, + minutes_to_session_labels, +) +from zipline.utils.input_validation import ( + attrgetter, + coerce, + preprocess, ) from zipline.utils.memoize import remember_last, lazyval @@ -659,13 +665,14 @@ def all_minutes(self): return DatetimeIndex(all_minutes).tz_localize("UTC") + @preprocess(dt=coerce(pd.Timestamp, attrgetter('value'))) def minute_to_session_label(self, dt, direction="next"): """ Given a minute, get the label of its containing session. Parameters ---------- - dt : pd.Timestamp + dt : pd.Timestamp or nanosecond offset The dt for which to get the containing session. direction: str @@ -684,17 +691,17 @@ def minute_to_session_label(self, dt, direction="next"): The label of the containing session. """ - idx = searchsorted(self.market_closes_nanos, dt.value) + idx = searchsorted(self.market_closes_nanos, dt) current_or_next_session = self.schedule.index[idx] if direction == "previous": if not is_open(self.market_opens_nanos, self.market_closes_nanos, - dt.value): + dt): # if the exchange is closed, use the previous session return self.schedule.index[idx - 1] elif direction == "none": if not is_open(self.market_opens_nanos, self.market_closes_nanos, - dt.value): + dt): # if the exchange is closed, blow up raise ValueError("The given dt is not an exchange minute!") elif direction != "next": @@ -704,6 +711,30 @@ def minute_to_session_label(self, dt, direction="next"): return current_or_next_session + def minute_index_to_session_labels(self, index): + """ + Given a sorted DatetimeIndex of market minutes, return a + DatetimeIndex of the corresponding session labels. + + Parameters + ---------- + index: pd.DatetimeIndex or pd.Series + The ordered list of market minutes we want session labels for. + + Returns + ------- + pd.DatetimeIndex (UTC) + The list of session labels corresponding to the given minutes. + """ + def minute_to_session_label_nanos(dt_nanos): + return self.minute_to_session_label(dt_nanos).value + + return DatetimeIndex(minutes_to_session_labels( + index.values.astype(np.int64), + minute_to_session_label_nanos, + self.market_closes_nanos, + ).astype('datetime64[ns]'), tz='UTC') + def _special_dates(self, calendars, ad_hoc_dates, start_date, end_date): """ Union an iterable of pairs of the form (time, calendar)