From 483397e5540eb4a4fe9f925d36b2e3f195959c2d Mon Sep 17 00:00:00 2001 From: Gil Wassermann Date: Tue, 2 Aug 2016 16:34:32 -0400 Subject: [PATCH 1/2] ENH: Added AtLeastN filter --- tests/pipeline/test_filter.py | 77 ++++++++++++++++++++++++++- zipline/pipeline/filters/__init__.py | 3 +- zipline/pipeline/filters/smoothing.py | 16 ++++++ 3 files changed, 94 insertions(+), 2 deletions(-) diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index ecd232a2e..1861dee8f 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -30,7 +30,7 @@ from zipline.pipeline import Filter, Factor, TermGraph from zipline.pipeline.classifiers import Classifier from zipline.pipeline.factors import CustomFactor -from zipline.pipeline.filters import All, Any +from zipline.pipeline.filters import All, Any, AtLeastN from zipline.testing import check_arrays, parameter_space, permute_rows from zipline.utils.numpy_utils import float64_dtype, int64_dtype from .base import BasePipelineTestCase, with_default_shape @@ -498,6 +498,81 @@ class Input(Filter): check_arrays(results['3'], expected_3) check_arrays(results['4'], expected_4) + def test_at_least_N(self): + + # With a window_length of K, AtLeastN should return 1 + # if N or more 1's exist in the lookback window + + # This smoothing filter gives customizable "stickiness" + + data = array([[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0]], dtype=bool) + + expected_1 = array([[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 0]], dtype=bool) + + expected_2 = array([[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0]], dtype=bool) + + expected_3 = array([[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 0, 0]], dtype=bool) + + expected_4 = array([[1, 1, 1, 1, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0]], dtype=bool) + + class Input(Filter): + inputs = () + window_length = 0 + + all_but_one = AtLeastN(inputs=[Input()], + window_length=4, + N=3) + + all_but_two = AtLeastN(inputs=[Input()], + window_length=4, + N=2) + + any_equiv = AtLeastN(inputs=[Input()], + window_length=4, + N=1) + + all_equiv = AtLeastN(inputs=[Input()], + window_length=4, + N=4) + + results = self.run_graph( + TermGraph({ + 'AllButOne': all_but_one, + 'AllButTwo': all_but_two, + 'AnyEquiv': any_equiv, + 'AllEquiv': all_equiv, + 'Any': Any(inputs=[Input()], window_length=4), + 'All': All(inputs=[Input()], window_length=4) + }), + initial_workspace={Input(): data}, + mask=self.build_mask(ones(shape=data.shape)), + ) + + check_arrays(results['Any'], expected_1) + check_arrays(results['AnyEquiv'], expected_1) + check_arrays(results['AllButTwo'], expected_2) + check_arrays(results['AllButOne'], expected_3) + check_arrays(results['All'], expected_4) + check_arrays(results['AllEquiv'], expected_4) + @parameter_space(factor_len=[2, 3, 4]) def test_window_safe(self, factor_len): # all true data set of (days, securities) diff --git a/zipline/pipeline/filters/__init__.py b/zipline/pipeline/filters/__init__.py index 86ad7476c..c88a2732f 100644 --- a/zipline/pipeline/filters/__init__.py +++ b/zipline/pipeline/filters/__init__.py @@ -9,12 +9,13 @@ PercentileFilter, SingleAsset, ) -from .smoothing import All, Any +from .smoothing import All, Any, AtLeastN __all__ = [ 'All', 'Any', 'ArrayPredicate', + 'AtLeastN', 'CustomFilter', 'Filter', 'Latest', diff --git a/zipline/pipeline/filters/smoothing.py b/zipline/pipeline/filters/smoothing.py index f48e29767..17005d3f1 100644 --- a/zipline/pipeline/filters/smoothing.py +++ b/zipline/pipeline/filters/smoothing.py @@ -33,3 +33,19 @@ class Any(CustomFilter): def compute(self, today, assets, out, arg): out[:] = (arg.sum(axis=0) > 0) + + +class AtLeastN(CustomFilter): + """ + A Filter requiring that assets produce True for at least N days in the + last ``window_length`` days. + + **Default Inputs:** None + + **Default Window Length:** None + """ + + params = ('N',) + + def compute(self, today, assets, out, arg, N): + out[:] = (arg.sum(axis=0) >= N) From e09fadb7e7d68ba1ca12bd542c84482039bbef5d Mon Sep 17 00:00:00 2001 From: Gil Wassermann Date: Tue, 2 Aug 2016 16:39:24 -0400 Subject: [PATCH 2/2] DOC: added to whatsnew --- docs/source/whatsnew/1.0.2.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/whatsnew/1.0.2.txt b/docs/source/whatsnew/1.0.2.txt index 83a54071d..87fc37bc5 100644 --- a/docs/source/whatsnew/1.0.2.txt +++ b/docs/source/whatsnew/1.0.2.txt @@ -31,6 +31,11 @@ Enhancements returns True if an asset produced a True for any/all days in the previous ``window_length`` days (:issue:`1358`). +- Added new pipeline filter :class:`~zipline.pipeline.filters.AtLeastN`, + which takes another filter and an int N and returns True if an asset + produced a True on N or more days in the previous ``window_length`` + days (:issue:`1367`). + Bug Fixes ~~~~~~~~~