Skip to content

Commit

Permalink
Add test stubs for transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Dec 5, 2022
1 parent 3629669 commit bda7bad
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/_program/transformation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def get_output_channels(self, input_channels):
def get_input_channels(self, output_channels):
raise NotImplementedError()

def get_constant_output_channels(self, input_channels):
raise NotImplementedError()

@property
def compare_key(self):
return id(self)
Expand Down Expand Up @@ -169,6 +172,9 @@ def test_constant_propagation(self):
trafo = LinearTransformation(matrix, in_chs, out_chs)
self.assertTrue(trafo.is_constant_invariant())

def test_time_dependence(self):
raise NotImplementedError()


class IdentityTransformationTests(unittest.TestCase):
def test_compare_key(self):
Expand Down Expand Up @@ -305,6 +311,8 @@ def test_init(self):
self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y', 'Z'})
self.assertEqual(trafo.get_output_channels({'X', 'Z', 'K'}), {'X', 'Y', 'Z', 'K'})

self.assertEqual(trafo.get_constant_output_channels({'X', 'Y', 'Z', 'K'}), {'X', 'Y', 'K'})

def test_trafo(self):
channels = {'X': 2, 'Y': 4.4, 'Z': ExpressionScalar('t')}
trafo = ParallelChannelTransformation(channels)
Expand Down Expand Up @@ -345,6 +353,9 @@ def test_constant_propagation(self):
trafo = ParallelChannelTransformation(channels)
self.assertTrue(trafo.is_constant_invariant())

def test_time_dependence(self):
raise NotImplementedError()


class TestChaining(unittest.TestCase):
def test_identity_result(self):
Expand Down Expand Up @@ -377,6 +388,9 @@ def test_chaining(self):

self.assertEqual(result, expected)

def test_constant_propagation(self):
raise NotImplementedError()


class TestOffsetTransformation(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -430,6 +444,9 @@ def test_constant_propagation(self):
constant_trafo = OffsetTransformation({'a': 7, 'b': 8.})
self.assertTrue(constant_trafo.is_constant_invariant())

def test_time_dependence(self):
raise NotImplementedError()


class TestScalingTransformation(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -482,3 +499,6 @@ def test_constant_propagation(self):
const_trafo = ScalingTransformation(self.constant_scales)
self.assertFalse(trafo.is_constant_invariant())
self.assertTrue(const_trafo.is_constant_invariant())

def test_time_dependence(self):
raise NotImplementedError()

0 comments on commit bda7bad

Please sign in to comment.