forked from scrtlabs/catalyst
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_adjustment.py
102 lines (91 loc) · 2.81 KB
/
test_adjustment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Tests for catalyst.lib.adjustment
"""
from unittest import TestCase
from nose_parameterized import parameterized
from catalyst.lib import adjustment as adj
from catalyst.utils.numpy_utils import make_datetime64ns
class AdjustmentTestCase(TestCase):
    """Exercise ``adj.make_adjustment_from_indices``.

    Each construction test asks the factory for an adjustment spanning rows
    1-2 and columns 3-4, builds the expected concrete adjustment class
    directly with the same bounds and value, and asserts the two are equal.
    The final test checks the error raised for a value type the factory
    does not handle.
    """

    def _assert_built_as(self, kind, value, adjustment_cls):
        """Compare the factory's output with a directly built adjustment.

        Parameters
        ----------
        kind : object
            Adjustment-kind constant from ``adj`` (e.g. ``adj.OVERWRITE``),
            forwarded as ``adjustment_kind``.
        value : object
            Payload handed to both the factory and ``adjustment_cls``.
        adjustment_cls : type
            Concrete adjustment class the factory is expected to produce.
        """
        built = adj.make_adjustment_from_indices(
            1, 2, 3, 4,
            adjustment_kind=kind,
            value=value,
        )
        wanted = adjustment_cls(
            first_row=1,
            last_row=2,
            first_col=3,
            last_col=4,
            value=value,
        )
        self.assertEqual(built, wanted)

    @parameterized.expand([
        ('add', adj.ADD),
        ('multiply', adj.MULTIPLY),
        ('overwrite', adj.OVERWRITE),
    ])
    def test_make_float_adjustment(self, name, adj_type):
        """A float value yields the matching Float64 adjustment class."""
        float_classes = {
            'add': adj.Float64Add,
            'multiply': adj.Float64Multiply,
            'overwrite': adj.Float64Overwrite,
        }
        self._assert_built_as(adj_type, 0.5, float_classes[name])

    def test_make_int_adjustment(self):
        """An int overwrite value yields an ``Int64Overwrite``."""
        self._assert_built_as(adj.OVERWRITE, 1, adj.Int64Overwrite)

    def test_make_datetime_adjustment(self):
        """A datetime64[ns] overwrite value yields a ``Datetime64Overwrite``."""
        epoch = make_datetime64ns(0)
        self._assert_built_as(adj.OVERWRITE, epoch, adj.Datetime64Overwrite)

    @parameterized.expand([("some text",), ("some text".encode(),), (None,)])
    def test_make_object_adjustment(self, value):
        """str, bytes and None overwrite values yield an ``ObjectOverwrite``."""
        self._assert_built_as(adj.OVERWRITE, value, adj.ObjectOverwrite)

    def test_unsupported_type(self):
        """A value of an unsupported type raises TypeError naming the type."""
        class SomeClass(object):
            pass

        with self.assertRaises(TypeError) as ctx:
            adj.make_adjustment_from_indices(
                1, 2, 3, 4,
                adjustment_kind=adj.OVERWRITE,
                value=SomeClass(),
            )

        message = (
            "Don't know how to make overwrite adjustments for values of type "
            "%r." % SomeClass
        )
        self.assertEqual(str(ctx.exception), message)