forked from scrtlabs/catalyst
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_column.py
79 lines (66 loc) · 2.45 KB
/
test_column.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Tests BoundColumn attributes and methods.
"""
from contextlib2 import ExitStack
from unittest import TestCase
from pandas import date_range, DataFrame
from pandas.util.testing import assert_frame_equal
from catalyst.lib.labelarray import LabelArray
from catalyst.pipeline import Pipeline
from catalyst.pipeline.data.testing import TestingDataSet as TDS
from catalyst.testing import chrange, temp_pipeline_engine
from catalyst.utils.pandas_utils import ignore_pandas_nan_categorical_warning
class LatestTestCase(TestCase):
    """
    Tests for the ``latest`` attribute of the ``TestingDataSet`` columns.

    A single temporary pipeline engine is created once for the whole class
    (see ``setUpClass``) and released in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create the class-level fixtures shared by every test.

        Sets ``cls.calendar`` (daily, UTC, from '2014' to '2015'),
        ``cls.sids`` (the integers 0 through 4), ``cls.engine`` (a temporary
        pipeline engine seeded with ``random_seed=100``) and ``cls.assets``
        (whatever the engine's asset finder returns for ``cls.sids``).
        """
        cls._stack = stack = ExitStack()
        try:
            cls.calendar = cal = date_range(
                '2014', '2015', freq='D', tz='UTC',
            )
            cls.sids = list(range(5))
            cls.engine = stack.enter_context(
                temp_pipeline_engine(
                    cal,
                    cls.sids,
                    random_seed=100,
                    symbols=chrange('A', 'E'),
                ),
            )
            cls.assets = cls.engine._finder.retrieve_all(cls.sids)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises, so
            # release anything already entered on the stack before the error
            # propagates.
            stack.close()
            raise

    @classmethod
    def tearDownClass(cls):
        """Release every context entered on the class-level stack."""
        cls._stack.close()

    def expected_latest(self, column, slice_):
        """
        Build the frame that ``column.latest`` is expected to produce.

        Parameters
        ----------
        column
            Column whose loader is looked up on ``self.engine``. Its
            ``dtype`` (and, for string-like dtypes, ``missing_value``) is
            read here.
        slice_ : slice
            Selects the rows of ``self.calendar`` (and of the loader's
            values) to include.

        Returns
        -------
        expected
            For dtype kinds 'O', 'S' and 'U', the result of
            ``LabelArray(...).as_categorical_frame(...)``; otherwise a
            ``DataFrame``. Both are indexed by the sliced calendar with
            ``self.assets`` as columns.
        """
        loader = self.engine.get_loader(column)

        index = self.calendar[slice_]
        columns = self.assets
        values = loader.values(column.dtype, self.calendar, self.sids)[slice_]

        if column.dtype.kind in ('O', 'S', 'U'):
            # For string columns, we expect a categorical in the output.
            return LabelArray(
                values,
                missing_value=column.missing_value,
            ).as_categorical_frame(
                index=index,
                columns=columns,
            )

        # Reuse the values/index/columns computed above so that both branches
        # are guaranteed to describe the same data.
        return DataFrame(values, index=index, columns=columns)

    # NOTE(review): the leading underscore keeps unittest from collecting this
    # method as a test, so it currently never runs. Presumably it was disabled
    # on purpose -- confirm before renaming it back to ``test_latest``.
    def _test_latest(self):
        """
        Run a pipeline of ``latest`` for every ``TestingDataSet`` column and
        compare each output column against ``expected_latest``.
        """
        columns = TDS.columns
        pipe = Pipeline(
            columns={c.name: c.latest for c in columns},
        )

        cal_slice = slice(20, 40)
        dates_to_test = self.calendar[cal_slice]
        result = self.engine.run_pipeline(
            pipe,
            dates_to_test[0],
            dates_to_test[-1],
        )

        for column in columns:
            # Only the unstack is wrapped; presumably it is the step that can
            # emit the pandas NaN-categorical warning for string columns.
            with ignore_pandas_nan_categorical_warning():
                col_result = result[column.name].unstack()
            expected_col_result = self.expected_latest(column, cal_slice)
            assert_frame_equal(col_result, expected_col_result)