-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_utils_steps.py
60 lines (49 loc) · 2.34 KB
/
test_utils_steps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
from NN.utils import make_steps_sequence
import pytest
def _test_common(config):
    """Check structural invariants of ``make_steps_sequence`` for many start steps.

    For every start step ``i`` in ``range(3, 100)`` (always with ``endStep=0``)
    the returned ``steps`` / ``prevSteps`` tensors must:

    * have the same shape;
    * be strictly decreasing and contain no duplicates;
    * start at ``i - 1`` (``steps[0]``);
    * end with ``2, 1`` (``steps``) and ``1, 0`` (``prevSteps``);
    * be chained: ``prevSteps[k] == steps[k + 1]`` for every ``k`` except the
      last, i.e. each previous step is the next step of the schedule.

    Args:
        config: schedule configuration forwarded unchanged to
            ``make_steps_sequence``. Callers in this module pass either a dict
            such as ``{'name': 'uniform', 'K': K}`` or the string ``'quadratic'``.

    Raises:
        tf.errors.InvalidArgumentError: when any check fails (behaviour of the
            ``tf.debugging.assert_*`` ops in eager mode).
    """
    # Start at 3: steps[0] is startStep - 1 and the tail checks below expect at
    # least the two steps 2 and 1, so startStep must be >= 3.
    for i in range(3, 100):
        steps, prevSteps = make_steps_sequence(startStep=i, endStep=0, config=config)
        tf.assert_equal(tf.shape(steps), tf.shape(prevSteps))
        # steps should be strictly decreasing
        tf.assert_less(steps[1:], steps[:-1], message='Steps are not strictly decreasing')
        tf.assert_less(prevSteps[1:], prevSteps[:-1], message='PrevSteps are not strictly decreasing')
        # Head and tail of the schedule.
        tf.assert_equal(steps[0], [i - 1], message='First step is not i - 1')
        tf.assert_equal(steps[-1], [1], message='Last step is not 1')
        tf.assert_equal(prevSteps[-1], [0], message='Last prevStep is not 0')
        tf.assert_equal(steps[-2], [2], message='Second to last step is not 2')
        tf.assert_equal(prevSteps[-2], [1], message='Second to last prevStep is not 1')
        # prevSteps is steps shifted left by one: dropping the first step and the
        # last prevStep must leave identical sequences.
        tf.assert_equal(steps[1:], prevSteps[:-1], message='Steps and prevSteps are not equal except for the first element and the last element')
        # tf.unique only accepts 1-D tensors, so these checks also implicitly
        # require steps / prevSteps to be 1-D.
        tf.assert_equal(tf.size(steps), tf.size(tf.unique(steps).y), message='Steps has duplicates')
        tf.assert_equal(tf.size(prevSteps), tf.size(tf.unique(prevSteps).y), message='PrevSteps has duplicates')
@pytest.mark.parametrize('K', list(range(1, 10)))
def test_uniform_K_common(K):
    """Run the common schedule invariants on the 'uniform' config for K = 1..9.

    K is forwarded as ``config['K']``. Presumably it is the spacing between
    consecutive steps (cf. ``test_uniform_steps``: 2, 5, 8 for K=3) -- confirm
    against ``make_steps_sequence``.
    """
    _test_common({'name': 'uniform', 'K': K})
def test_quadratic_common():
    """Run the common schedule invariants on the 'quadratic' config."""
    _test_common('quadratic')
def test_uniform_steps():
    """Pin the exact 'uniform' (K=3) schedule for startStep=10, endStep=0.

    ``steps`` starts at startStep - 1 = 9 and ends at 1. ``prevSteps`` is
    ``steps`` shifted left by one and ends at endStep = 0.
    """
    config = {'name': 'uniform', 'K': 3}
    steps, prevSteps = make_steps_sequence(10, 0, config=config)
    # Expected values are written as sums so the stride of K = 3 upward from 2
    # is visible.
    tf.assert_equal(steps, [9, 2 + 3 + 3, 2 + 3, 2, 1])
    tf.assert_equal(prevSteps, [2 + 3 + 3, 2 + 3, 2, 1, 0])
def test_quadratic_steps_no_duplicate():
    """Pin the exact 'quadratic' schedule for startStep=17, endStep=0.

    Judging by the name, this start step presumably produces a schedule with no
    duplicate steps to collapse -- confirm against ``make_steps_sequence``.
    """
    steps, prevSteps = make_steps_sequence(17, 0, config='quadratic')
    tf.assert_equal(steps, [16, 8, 4, 2, 1])
    tf.assert_equal(prevSteps, [8, 4, 2, 1, 0])
def test_quadratic_steps_case1():
    """Pin the exact 'quadratic' schedule for a non-zero end step (21 -> 3).

    ``steps`` starts at startStep - 1 = 20 and ends at endStep + 1 = 4.
    ``prevSteps`` ends at endStep = 3.
    """
    steps, prevSteps = make_steps_sequence(21, 3, config='quadratic')
    tf.assert_equal(steps, [20, 19, 11, 7, 5, 4])
    tf.assert_equal(prevSteps, [19, 11, 7, 5, 4, 3])
def test_steps_K1():
    """With 'uniform' K=1 the schedule visits every integer step.

    The arguments are written as ``10 + 1`` and ``0 - 1`` so that ``steps``
    covers exactly 10..0: startStep - 1 = 10 through endStep + 1 = 0.
    ``prevSteps`` ends at endStep = -1.
    """
    config = {'name': 'uniform', 'K': 1}
    steps, prevSteps = make_steps_sequence(10 + 1, 0 - 1, config=config)
    tf.assert_equal(steps, [10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
    tf.assert_equal(prevSteps, [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1])