Skip to content

Commit

Permalink
Finally its works!
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 5, 2024
1 parent c33735f commit 3809acc
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 49 deletions.
78 changes: 39 additions & 39 deletions NN/restorators/samplers/CSamplerWatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
class CSamplerWatcher(ISamplerWatcher):
def __init__(self, steps, tracked, indices=None):
super().__init__()
self._indices = tf.constant(indices, dtype=tf.int32) if not(indices is None) else None
self._tracked = {}
prefix = [steps]
if not(self._indices is None):
self._indices = None
if not(indices is None):
self._indices = tf.reshape(tf.constant(indices, dtype=tf.int32), (1, -1))
prefix = [steps, tf.size(self._indices)]

for name, shape in tracked.items():
Expand All @@ -20,6 +21,7 @@ def __init__(self, steps, tracked, indices=None):

if 'value' in self._tracked: # value has steps + 1 shape, so we need extra variable
shp = prefix + list(tracked['value'])
print('Initial value shape:', shp)
self._initialValue = tf.Variable(tf.zeros(shp[1:]), trainable=False)
pass

Expand All @@ -43,13 +45,13 @@ def tracked(self, name):
def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
tracked = self._tracked.get(name, None)
if tracked is None: return
src, dest = self._withIndices(
src, dest, unchangedIdx = self._withIndices(
value, index, mask=mask,
masked=not ('value' == name)
)
tf.print('-'*80)
tf.print(name, src, dest, summarize=-1)
tf.print('N', tf.reduce_sum(tf.cast(mask, tf.int32)))
prev = self.tracked(name)[iteration - 1]
self._move(prev, unchangedIdx, tracked[iteration], unchangedIdx)
tf.print(unchangedIdx)
self._move(value, src, tracked[iteration], dest)
return

Expand All @@ -75,51 +77,49 @@ def _onStart(self, value, kwargs):
index = kwargs['index']
self._iteration.assign(0)
if 'value' in self._tracked: # save initial value
src, dest = self._withIndices(value, index)
src, dest, _ = self._withIndices(value, index)
self._move(value, src, self._initialValue, dest)
return

def _withIndices(self, value, index, mask=None, masked=False):
N = tf.shape(value)[0]
srcIndex = tf.range(N)
destIndex = index + srcIndex
unchanged = tf.constant([], dtype=tf.int32)
srcIndex = tf.range(tf.shape(value)[0])
if mask is not None: # use mask
N = tf.reduce_sum(tf.cast(mask, tf.int32))
destIndex = index + tf.cast(tf.where(mask), tf.int32)
destIndex = tf.reshape(destIndex, (N,))
if masked:
rng = tf.range(tf.shape(mask)[0])
srcIndex = NNU.masked(rng, mask)
pass
pass
unchanged = tf.logical_not(mask)
unchanged = tf.cast(tf.where(unchanged), tf.int32) + index

if not masked:
whereIdx = tf.where(mask)
srcIndex = tf.cast(whereIdx, tf.int32)
pass
destIndex = index + srcIndex

if self._indices is not None:
indices = tf.reshape(self._indices, (1, -1))
destIndex = tf.reshape(destIndex, (-1, 1))
correspondence = indices == destIndex
tf.assert_rank(correspondence, 2)
mask_ = tf.reduce_any(correspondence, axis=0)
tf.assert_equal(tf.shape(mask_), tf.shape(self._indices))
# collect destination indices
destIndex = tf.where(mask_)
destIndex = tf.cast(destIndex, tf.int32)
# find corresponding source indices
mask_ = tf.reduce_any(correspondence, axis=-1)
tf.assert_equal(tf.shape(mask), tf.shape(srcIndex))
srcIndex = tf.where(mask_)
srcIndex = tf.cast(srcIndex, tf.int32)
N = tf.reduce_sum(tf.cast(mask_, tf.int32))
unchanged = self._index2index(unchanged, axis=0)
srcIndex = self._index2index(destIndex, axis=1)
destIndex = self._index2index(destIndex, axis=0)
pass

srcIndex = tf.reshape(srcIndex, (N, 1))
destIndex = tf.reshape(destIndex, (N, 1))
return srcIndex, destIndex
return srcIndex, destIndex, unchanged

def _move(self, src, srcIndex, dest, destIndex):
tf.print(tf.shape(src), tf.shape(srcIndex), tf.shape(dest), tf.shape(destIndex))
tf.print(srcIndex, destIndex, summarize=-1)
# tensor_scatter_nd_update can't handle empty indices, so we need to check it
if tf.size(srcIndex) == 0: return dest

srcIndex = tf.reshape(srcIndex, (-1, 1))
destIndex = tf.reshape(destIndex, (-1, 1))
src = tf.gather_nd(src, srcIndex) # collect only valid indices
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
dest.assign(res)
return res

def _index2index(self, indices, axis=0):
N = tf.size(indices)
NN = tf.size(self._indices)
indices = tf.reshape(indices, (-1, 1))
indices = tf.cast(indices, tf.int32)
correspondence = self._indices == indices
tf.assert_equal(tf.shape(correspondence), (N, NN))
mask = tf.reduce_any(correspondence, axis=axis)
res = tf.where(mask)
return tf.cast(res, tf.int32)
# End of CSamplerWatcher
44 changes: 34 additions & 10 deletions tests/test_CSamplerWatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from Utils.utils import CFakeObject
from NN.restorators.samplers import sampler_from_config
from NN.restorators.samplers.CSamplerWatcher import CSamplerWatcher
from NN.utils import masked

def _fake_sampler(stochasticity=1.0, timesteps=10):
interpolant = sampler_from_config({
Expand All @@ -24,7 +25,7 @@ def fakeModel(x, T, **kwargs):
x = tf.random.normal(shape)
return CFakeObject(x=x, model=fakeModel, interpolant=interpolant)

def _fake_AR(threshold, timesteps=10):
def _fake_AR(threshold, timesteps=10, scale=1.0):
interpolant = sampler_from_config({
"name": "autoregressive",
"noise provider": "zero",
Expand All @@ -41,8 +42,13 @@ def _fake_AR(threshold, timesteps=10):
shape = (32, 3)
fakeNoise = tf.random.normal(shape)
def fakeModel(x, t, mask, **kwargs):
s = tf.boolean_mask(fakeNoise, mask) if mask is not None else fakeNoise
return s + tf.cast(t, tf.float32) * x
s = fakeNoise
if mask is not None:
s = masked(fakeNoise, mask)
t = masked(t, mask)
x = masked(x, mask)

return s + tf.cast(t, tf.float32) * x * scale

x = tf.random.normal(shape)
return CFakeObject(x=x, model=fakeModel, interpolant=interpolant)
Expand Down Expand Up @@ -159,7 +165,7 @@ def test_trackSolutionWithMask_value():

# test multiple calls with index
def test_multipleCallsWithIndex():
fake = _fake_AR(threshold=0.1)
fake = _fake_sampler()
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(32*3, 3))
Expand All @@ -180,24 +186,42 @@ def test_multipleCallsWithIndex():
tf.debugging.assert_equal(collectedSteps[:, 32:64], collectedSteps[:, 64:])
return

# TODO: find out why this test fails, but previous one passes
# test multiple calls with index and mask
def test_multipleCallsWithIndexAndMask():
fake = _fake_AR(threshold=0.1)
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(3,)),
indices=[0, 32]
indices=[0, 32, 64, 65]
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
A = fake.interpolant.sample(**arg, index=0)
B = fake.interpolant.sample(**arg, index=32)
C = fake.interpolant.sample(**arg, index=64)

collectedSteps = watcher.tracked('value')
tf.assert_equal(tf.shape(collectedSteps), (11, 2, 3), 'Must be (11, 4, 3)')
tf.assert_equal(A, B)
tf.assert_equal(A, C)

collectedSteps = watcher.tracked('value')[:watcher.iteration]
tf.assert_equal(tf.shape(collectedSteps)[1:], (4, 3), 'Must be (4, 3)')
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
# tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
return

# test that masked values aren't zeroed
def test_maskedValues():
fake = _fake_AR(threshold=1e+5, scale=0.0)
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(32, 3)),
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_ = fake.interpolant.sample(**arg, index=0)

collectedSteps = watcher.tracked('value')
afterMask = collectedSteps[3:]
beforeMask = collectedSteps[2]

tf.debugging.assert_greater(3, watcher.iteration, 'Must collect 3 steps')
for i in range(3, watcher.iteration + 1):
tf.debugging.assert_equal(afterMask[i], beforeMask, 'Must be equal')
return

0 comments on commit 3809acc

Please sign in to comment.