Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 4, 2024
1 parent 6ea65fb commit f9d52c5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 48 deletions.
72 changes: 38 additions & 34 deletions NN/restorators/samplers/CSamplerWatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,15 @@ def tracked(self, name):
res = tf.concat([self._initialValue[None], res], axis=0)
return res

def _updateTracked(self, name, value, mask=None, index=None):
def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
tracked = self._tracked.get(name, None)
if tracked is None: return
value, idx = self._withIndices(value, index, mask=mask)
src, dest = self._withIndices(
value, index, mask=mask,
masked=(mask is not None) and not ('value' == name)
)

iteration = self._iteration
if (mask is not None) and not ('value' == name): # 'value' is always unmasked
mask, _ = self._withIndices(mask, index, mask=mask)
prev = self._tracked[name][iteration]
# expand mask to match the value shape by copying values from the previous iteration
indices = tf.where(mask)
sz = tf.shape(value)[0]
value = tf.tensor_scatter_nd_update(prev[idx:idx+sz], indices, value)
pass

sz = tf.shape(value)[0]
tracked[iteration, idx:idx+sz].assign(value)
self._move(value, src, tracked[iteration], dest)
return

def _onNextStep(self, iteration, kwargs):
Expand All @@ -69,36 +61,48 @@ def _onNextStep(self, iteration, kwargs):
mask = step.mask if hasattr(step, 'mask') else None
# iterate over all fields
for name in solution._fields:
self._updateTracked(name, getattr(solution, name), mask=mask, index=index)
self._updateTracked(
name, getattr(solution, name),
mask=mask, index=index, iteration=iteration
)
continue
return

def _onStart(self, value, kwargs):
index = kwargs['index']
self._iteration.assign(0)
if 'value' in self._tracked: # save initial value
value, idx = self._withIndices(value, index)
# update slice [index:index+sz] with the value
sz = tf.shape(value)[0]
self._initialValue[idx:idx+sz].assign(value)
src, dest = self._withIndices(value, index)
self._move(value, src, self._initialValue, dest)
return

def _withIndices(self, value, index, mask=None):
if self._indices is None: return value, index
# find subset of indices
sz = tf.shape(value)[0]
def _withIndices(self, value, index, mask=None, masked=False):
srcIndex = tf.range(tf.shape(value)[0])
if masked:
srcIndex = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)

validMask = tf.logical_and(index <= self._indices, self._indices < index + sz)
destIndex = index + srcIndex
if mask is not None:
maskedIndices = tf.range(sz)
maskedIndices = tf.boolean_mask(maskedIndices, mask) + index
# exclude masked indices
maskedIndices = tf.reduce_any(maskedIndices[:, None] == self._indices[None], axis=0)
validMask = tf.logical_and(validMask, maskedIndices)
srcIndex = tf.boolean_mask(srcIndex, mask)
destIndex = tf.boolean_mask(destIndex, mask)
pass

startIndex = tf.reduce_min(tf.where(validMask))
startIndex = tf.cast(startIndex, tf.int32)
indices = tf.boolean_mask(self._indices, validMask) - index
return tf.gather(value, indices, axis=0), startIndex

if self._indices is not None:
mask = tf.reduce_any(self._indices[None] == destIndex[:, None], axis=0)
tf.assert_equal(tf.shape(mask), tf.shape(self._indices))
# collect only valid indices
srcIndex = tf.boolean_mask(self._indices, mask) - index
# collect destination indices
destIndex = tf.where(mask)
pass

srcIndex = tf.reshape(srcIndex, (-1, 1))
destIndex = tf.reshape(destIndex, (-1, 1))
return srcIndex, destIndex

def _move(self, src, srcIndex, dest, destIndex):
src = tf.gather_nd(src, srcIndex) # collect only valid indices
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
dest.assign(res)
return res
# End of CSamplerWatcher
34 changes: 20 additions & 14 deletions tests/test_CSamplerWatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def fakeModel(x, T, **kwargs):
def _fake_AR(threshold, timesteps=10):
interpolant = sampler_from_config({
"name": "autoregressive",
"noise provider": "normal",
"noise provider": "zero",
"threshold": threshold,
"steps": {
"start": 1.0,
Expand Down Expand Up @@ -146,26 +146,31 @@ def test_trackSolutionWithMask():
return

def test_trackSolutionWithMask_value():
fake = _fake_AR(threshold=0.1)
fake = _fake_AR(threshold=0.5)
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(32, 3))
tracked=dict(value=(32, 3), x0=(32, 3), x1=(32, 3))
)
_ = fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_checkTracked(watcher.tracked('value'), N=11)
_checkTracked(watcher.tracked('x0'), N=10)
_checkTracked(watcher.tracked('x1'), N=10)
return

# test multiple calls with index
def test_multipleCallsWithIndex():
fake = _fake_sampler()
fake = _fake_AR(threshold=0.1)
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(32*3, 3))
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_ = fake.interpolant.sample(**arg, index=0)
_ = fake.interpolant.sample(**arg, index=32)
_ = fake.interpolant.sample(**arg, index=64)
A = fake.interpolant.sample(**arg, index=0)
B = fake.interpolant.sample(**arg, index=32)
C = fake.interpolant.sample(**arg, index=64)

tf.debugging.assert_equal(A, B)
tf.debugging.assert_equal(A, C)

collectedSteps = watcher.tracked('value')
tf.debugging.assert_equal(tf.shape(collectedSteps)[1], 96, 'Must collect 96 values')
Expand All @@ -182,16 +187,17 @@ def test_multipleCallsWithIndexAndMask():
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(3,)),
indices=[0, 32, 64, 65]
indices=[0, 32]
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_ = fake.interpolant.sample(**arg, index=0)
_ = fake.interpolant.sample(**arg, index=32)
_ = fake.interpolant.sample(**arg, index=64)
A = fake.interpolant.sample(**arg, index=0)
B = fake.interpolant.sample(**arg, index=32)
C = fake.interpolant.sample(**arg, index=64)

collectedSteps = watcher.tracked('value')
tf.assert_equal(tf.shape(collectedSteps), (11, 4, 3), 'Must be (11, 4, 3)')
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
tf.assert_equal(tf.shape(collectedSteps), (11, 2, 3), 'Must be (11, 4, 3)')
tf.assert_equal(A, B)
tf.assert_equal(A, C)
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
# tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
return

0 comments on commit f9d52c5

Please sign in to comment.