-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_CWatcherWithExtras.py
125 lines (105 loc) · 4.66 KB
/
test_CWatcherWithExtras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import pytest
import tensorflow as tf
from Utils.utils import CFakeObject
from NN.restorators.samplers import sampler_from_config
from NN.restorators.samplers.CSamplerWatcher import CSamplerWatcher
from NN.restorators.samplers.CWatcherWithExtras import CWatcherWithExtras
from Utils import colors
def _fake_sampler(stochasticity=1.0, timesteps=10):
    """Build a DDIM sampler, a toy model and a random starting value for tests.

    Args:
        stochasticity: Forwarded to the sampler config as ``'stochasticity'``.
        timesteps: Number of timesteps of the discrete linear beta schedule.

    Returns:
        A ``CFakeObject`` with ``x`` (random tensor of shape (32, 3)),
        ``model`` (the toy model below) and ``interpolant`` (the sampler
        built by ``sampler_from_config``).
    """
    scheduleConfig = {
        'name': 'discrete',
        'beta schedule': 'linear',
        'timesteps': timesteps,
    }
    samplerConfig = {
        'name': 'DDIM',
        'stochasticity': stochasticity,
        'noise stddev': 'zero',
        'schedule': scheduleConfig,
        'steps skip type': {'name': 'uniform', 'K': 1},
    }
    interpolant = sampler_from_config(samplerConfig)

    shape = (32, 3)
    # Noise is drawn once, so the model's output depends only on (V, T).
    # It must be drawn before the start value to keep the RNG order stable.
    constantNoise = tf.random.normal(shape)

    def fakeModel(V, T, **kwargs):
        # Toy model: the fixed noise plus the input scaled by the timestep.
        return constantNoise + tf.cast(T, tf.float32) * V

    startValue = tf.random.normal(shape)
    return CFakeObject(x=startValue, model=fakeModel, interpolant=interpolant)
def test_sameResults():
    """Wrapping a watcher in CWatcherWithExtras (residuals only) keeps sampler output unchanged."""
    fake = _fake_sampler()
    baseWatcher = CSamplerWatcher(
        steps=10,
        tracked=dict(value=(32, 3), x0=(32, 3), x1=(32, 3)),
    )

    def run(interceptor):
        # Same sampler, start value and model each time; only the interceptor varies.
        return fake.interpolant.sample(
            value=fake.x, model=fake.model, algorithmInterceptor=interceptor,
        )

    plainResults = run(baseWatcher.interceptor())
    # The residuals are drawn after the first run, matching the original RNG order.
    extrasWatcher = CWatcherWithExtras(
        watcher=baseWatcher,
        converter=None,
        residuals=tf.random.normal((32, 3)),
    )
    wrappedResults = run(extrasWatcher.interceptor())
    tf.debugging.assert_near(plainResults, wrappedResults)
def test_sameResults_convert():
    """Adding a converter via CWatcherWithExtras keeps sampler output unchanged."""
    fake = _fake_sampler()

    def run(interceptor):
        # Same sampler, start value and model each time; only the interceptor varies.
        return fake.interpolant.sample(
            value=fake.x, model=fake.model, algorithmInterceptor=interceptor,
        )

    referenceWatcher = CSamplerWatcher(
        steps=10,
        tracked=dict(value=(32, 3), x0=(32, 3), x1=(32, 3)),
    )
    referenceResults = run(referenceWatcher.interceptor())

    # A separate watcher records the converted values, hence the (32, 1) shapes.
    convertedWatcher = CSamplerWatcher(
        steps=10,
        tracked=dict(value=(32, 1), x0=(32, 1), x1=(32, 1)),
    )
    # Stand-in converter: ignores its input and returns random (32, 1) values.
    converter = CFakeObject(convertBack=lambda x: tf.random.normal((32, 1)))
    extrasWatcher = CWatcherWithExtras(
        watcher=convertedWatcher, converter=converter, residuals=None,
    )
    convertedResults = run(extrasWatcher.interceptor())
    tf.debugging.assert_near(referenceResults, convertedResults)
def test_sameResults_nested():
    """Nesting one CWatcherWithExtras inside another keeps sampler output unchanged."""
    fake = _fake_sampler()

    def run(interceptor):
        # Same sampler, start value and model each time; only the interceptor varies.
        return fake.interpolant.sample(
            value=fake.x, model=fake.model, algorithmInterceptor=interceptor,
        )

    referenceWatcher = CSamplerWatcher(
        steps=10,
        tracked=dict(value=(32, 3), x0=(32, 3), x1=(32, 3)),
    )
    referenceResults = run(referenceWatcher.interceptor())

    # A separate watcher records the converted values, hence the (32, 1) shapes.
    convertedWatcher = CSamplerWatcher(
        steps=10,
        tracked=dict(value=(32, 1), x0=(32, 1), x1=(32, 1)),
    )
    # Stand-in converter: ignores its input and returns random (32, 1) values.
    converter = CFakeObject(convertBack=lambda x: tf.random.normal((32, 1)))
    innerExtras = CWatcherWithExtras(
        watcher=convertedWatcher, converter=converter, residuals=None,
    )
    outerExtras = CWatcherWithExtras(
        watcher=innerExtras, converter=converter, residuals=None,
    )
    nestedResults = run(outerExtras.interceptor())
    tf.debugging.assert_near(referenceResults, nestedResults)
@pytest.mark.parametrize('field', ['value', 'x0', 'x1'])
def test_shiftedValues(field):
    """Values tracked through CWatcherWithExtras must be shifted by ``residuals``.

    The sampler runs twice with the same start value and model. The first run
    uses a plain watcher; the second uses that watcher wrapped with
    ``residuals``. The second recording of ``field`` is expected to equal the
    first plus ``residuals``, broadcast over the leading axis (``residuals[None]``).

    Args:
        field: Name of the tracked quantity ('value', 'x0' or 'x1').
    """
    fake = _fake_sampler()
    watcher = CSamplerWatcher(steps=10, tracked={field: (32, 3)})
    fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
    # .numpy() takes a snapshot before the second run reuses the same watcher.
    valuesA = watcher.tracked(field).numpy()

    residuals = tf.random.normal((32, 3))
    watcherB = CWatcherWithExtras(watcher=watcher, converter=None, residuals=residuals)
    fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcherB.interceptor())
    valuesB = watcherB.tracked(field).numpy()

    # abs(): a one-sided `diff < tol` check would accept arbitrarily large
    # negative errors.
    diff = tf.abs(valuesA + residuals[None] - valuesB)
    # Broadcast residuals to diff's full shape so zip() visits every element
    # (every tracked step), not just the first residuals.size entries.
    residualsFull = tf.broadcast_to(residuals[None], tf.shape(diff))
    for i, (value, res) in enumerate(zip(diff.numpy().flatten(), residualsFull.numpy().flatten())):
        assert value < 1e-6, f'Error at index {i}: {value} (residual: {res})'
@pytest.mark.parametrize('field', ['value', 'x0', 'x1'])
def test_transformedValues(field):
    """Values tracked with a converter must equal ``convertBack(plain) + residuals``.

    The sampler runs twice with the same start value and model. The first run
    uses a plain watcher; the second uses that watcher wrapped with a
    ``colors.convertRGBtoLAB()`` converter and ``residuals``. The second
    recording of ``field`` is expected to equal ``converter.convertBack`` applied
    to the first recording, plus ``residuals`` broadcast over the leading axis.

    Args:
        field: Name of the tracked quantity ('value', 'x0' or 'x1').
    """
    fake = _fake_sampler()
    watcher = CSamplerWatcher(steps=10, tracked={field: (32, 3)})
    converter = colors.convertRGBtoLAB()
    fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
    # .numpy() takes a snapshot before the second run reuses the same watcher.
    valuesA = watcher.tracked(field).numpy()

    residuals = tf.random.normal((32, 3))
    watcherB = CWatcherWithExtras(watcher=watcher, converter=converter, residuals=residuals)
    fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcherB.interceptor())
    valuesB = watcherB.tracked(field).numpy()

    # abs(): a one-sided `diff < tol` check would accept arbitrarily large
    # negative errors.
    diff = tf.abs(converter.convertBack(valuesA) + residuals[None] - valuesB)
    # Broadcast residuals to diff's full shape so zip() visits every element
    # (every tracked step), not just the first residuals.size entries.
    residualsFull = tf.broadcast_to(residuals[None], tf.shape(diff))
    for i, (value, res) in enumerate(zip(diff.numpy().flatten(), residualsFull.numpy().flatten())):
        assert value < 1e-5, f'Error at index {i}: {value} (residual: {res})'