Commit

some fixes
GreenWizard2015 committed Apr 3, 2024
1 parent 0b9f38b commit 0bf7637
Showing 9 changed files with 95 additions and 32 deletions.
4 changes: 2 additions & 2 deletions NN/Renderer.py
@@ -28,15 +28,15 @@ def batched(self, ittr, B, N, batchSize=None, training=False):
for i in tf.range(NBatches):
index = i * stepBy
data = ittr(index, stepBy)
V = self._restorator.reverse(**data, training=training)
V = self._restorator.reverse(**data, training=training, index=index)
C = tf.shape(V)[-1]
res = res.write(i, tf.reshape(V, (B, stepBy, C)))
continue
#################
index = NBatches * stepBy

data = ittr(index, N - index)
V = self._restorator.reverse(**data, training=training)
V = self._restorator.reverse(**data, training=training, index=index)
C = tf.shape(V)[-1]

w = N - index
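Note on this hunk: Renderer.batched now forwards the absolute offset of each chunk (index) into the restorator's reverse() call, so downstream consumers such as the sampler watcher can tell which slice of the full point set a batch covers. A minimal sketch of the same pattern with a stub restorator; names and shapes here are illustrative, not the project's actual classes:

import tensorflow as tf

class StubRestorator:
  # Illustrative stand-in; a real restorator would run the reverse process.
  def reverse(self, values, training=False, index=0):
    # 'index' is the absolute offset of this chunk within the full set of points.
    tf.print('chunk starts at', index)
    return values

def batched_reverse(restorator, values, batch_size):
  # Process 'values' in fixed-size chunks, passing the absolute offset along.
  results = []
  for start in range(0, int(values.shape[0]), batch_size):
    chunk = values[start:start + batch_size]
    results.append(restorator.reverse(chunk, training=False, index=start))
  return tf.concat(results, axis=0)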
4 changes: 2 additions & 2 deletions NN/RestorationModel/CRepeatedRestorator.py
@@ -26,14 +26,14 @@ def call(self, latents, pos, T, V, residual, training=None):
continue
return V

def reverse(self, latents, pos, reverseArgs, training, value, residual):
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
for i in range(self._N):
if tf.is_tensor(value): value = tf.stop_gradient(value)
value = self._restorator.reverse(
latents=self._withID(latents, i, training),
pos=pos, reverseArgs=reverseArgs,
residual=residual,
training=training, value=value
training=training, value=value, index=index
)
continue
return value
5 changes: 3 additions & 2 deletions NN/RestorationModel/CRestorationModel.py
@@ -49,7 +49,7 @@ def call(self, latents, pos, T, V, residual, training=None):
res = self._decoder(condition=latents, coords=EPos, timestep=t, V=V, training=training)
return self._withResidual(res, residual)

def reverse(self, latents, pos, reverseArgs, training, value, residual):
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
EPos = self._encodePos(pos, training=training, args=reverseArgs.get('decoder', {}))

def denoiser(x, t, mask=None):
@@ -65,7 +65,8 @@ def denoiser(x, t, mask=None):
return self._restorator.reverse(
value=value, denoiser=denoiser,
modelT=lambda t: self._encodeTime(t, training=training),
**reverseArgs
**reverseArgs,
index=index
)

def train_step(self, x0, latents, positions, params, xT=None):
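Note on the **reverseArgs, index=index expansion above: Python raises a TypeError ("got multiple values for keyword argument") when the same keyword arrives both from a **-unpacked dict and explicitly, so this call assumes reverseArgs never carries its own 'index' key. A tiny self-contained illustration of that pitfall, using a hypothetical function that only mirrors the call shape, not project code:

def reverse(value, denoiser=None, index=0, **kwargs):
  return index

reverseArgs = {'denoiser': None}
print(reverse(value=1, **reverseArgs, index=3))  # fine, prints 3

reverseArgs = {'denoiser': None, 'index': 7}
try:
  reverse(value=1, **reverseArgs, index=3)
except TypeError as e:
  print(e)  # got multiple values for keyword argument 'index'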
4 changes: 2 additions & 2 deletions NN/RestorationModel/CSequentialRestorator.py
@@ -15,13 +15,13 @@ def call(self, latents, pos, T, V, residual, training=None):
continue
return V

def reverse(self, latents, pos, reverseArgs, training, value, residual):
def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
for restorator in self._restorators:
if tf.is_tensor(value): value = tf.stop_gradient(value)
value = restorator.reverse(
latents=latents, pos=pos, reverseArgs=reverseArgs,
residual=residual,
training=training, value=value
training=training, value=value, index=index
)
continue
return value
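Both wrapper restorators (CRepeatedRestorator and CSequentialRestorator) are touched here only to forward the new index argument into the wrapped reverse() calls. As a hedged aside, not how the repository is actually structured: a wrapper that passes unknown keyword arguments through verbatim would let a new parameter on the innermost reverse() propagate without editing every layer. A minimal sketch of that alternative:

import tensorflow as tf

class PassThroughSequential:
  def __init__(self, restorators):
    self._restorators = restorators

  def reverse(self, value, **kwargs):
    # Forward every keyword (index, reverseArgs, training, ...) unchanged.
    for restorator in self._restorators:
      if tf.is_tensor(value):
        value = tf.stop_gradient(value)  # no gradient flow between stages
      value = restorator.reverse(value=value, **kwargs)
    return value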
4 changes: 2 additions & 2 deletions NN/restorators/CARProcess.py
@@ -39,12 +39,12 @@ def denoiser(x, t=None, **kwargs):
return model(x=x, t=T, mask=kwargs.get('mask', None))[:, :self.predictions]
return denoiser

def reverse(self, value, denoiser, modelT=None, **kwargs):
def reverse(self, value, denoiser, modelT=None, index=0, **kwargs):
if isinstance(value, tuple):
value = self._sourceDistribution.initialValueFor(value + (self.predictions, ))

denoiser = self._makeDenoiser(denoiser, modelT)
res = self._sampler.sample(value=value, model=denoiser, **kwargs)
res = self._sampler.sample(value=value, model=denoiser, index=index, **kwargs)
tf.assert_equal(tf.shape(res), tf.shape(value))
return res

5 changes: 3 additions & 2 deletions NN/restorators/samplers/CBasicInterpolantSampler.py
@@ -12,8 +12,9 @@ def __init__(self, interpolant, algorithm):
def interpolant(self): return self._interpolant

@tf.function
def sample(self, value, model, **kwargs):
kwargs = dict(**kwargs, interpolant=self._interpolant) # add interpolant to kwargs
def sample(self, value, model, index=0, **kwargs):
# add interpolant to kwargs and index
kwargs = dict(**kwargs, interpolant=self._interpolant, index=index)
# wrap algorithm with hook, if provided
algorithm = kwargs.get('algorithmInterceptor', lambda x: x)( self._algorithm )
assert isinstance(algorithm, ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'
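For context on the hunk above: sample() folds the interpolant and the new index into the kwargs handed to the sampling algorithm, and algorithmInterceptor, when provided, is a callable that receives the algorithm and returns the object actually used (the identity by default); the real method additionally asserts the result is an ISamplingAlgorithm, which a generic wrapper like the one below would not satisfy. A minimal sketch of just that hook pattern, with illustrative names only:

class LoggingWrapper:
  # Delegates every attribute lookup to the wrapped algorithm and logs it.
  def __init__(self, algorithm):
    self._algorithm = algorithm

  def __getattr__(self, name):
    print('accessing', name)
    return getattr(self._algorithm, name)

def run(algorithm, **kwargs):
  # Same hook shape as in sample(): interceptor(algorithm) -> algorithm to use.
  return kwargs.get('algorithmInterceptor', lambda x: x)(algorithm)

base = object()
wrapped = run(base, algorithmInterceptor=LoggingWrapper)  # LoggingWrapper instance
plain = run(base)                                         # base, unchanged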
56 changes: 38 additions & 18 deletions NN/restorators/samplers/CSamplerWatcher.py
@@ -39,26 +39,27 @@ def tracked(self, name):
res = tf.concat([self._initialValue[None], res], axis=0)
return res

def _updateTracked(self, name, value, mask=None):
def _updateTracked(self, name, value, mask=None, index=None):
tracked = self._tracked.get(name, None)
if tracked is None: return
value = self._withIndices(value)
value, idx = self._withIndices(value, index, mask=mask)

iteration = self._iteration
if (mask is None) or ('value' == name): # 'value' is always unmasked
tracked[iteration].assign(value)
return

mask = self._withIndices(mask)
prev = tracked[iteration - 1]
# expand mask to match the value shape by copying values from the previous iteration
indices = tf.where(mask)
value = tf.tensor_scatter_nd_update(prev, indices, value)
tf.assert_equal(tf.shape(prev), tf.shape(value), 'Must be the same shape')
tracked[iteration].assign(value)
if (mask is not None) and not ('value' == name): # 'value' is always unmasked
mask, _ = self._withIndices(mask, index, mask=mask)
prev = self._tracked[name][iteration]
# expand mask to match the value shape by copying values from the previous iteration
indices = tf.where(mask)
sz = tf.shape(value)[0]
value = tf.tensor_scatter_nd_update(prev[idx:idx+sz], indices, value)
pass

sz = tf.shape(value)[0]
tracked[iteration, idx:idx+sz].assign(value)
return

def _onNextStep(self, iteration, kwargs):
index = kwargs['index']
self._iteration.assign(iteration)
# track also solution
solution = kwargs['solution']
@@ -68,17 +69,36 @@ def _onNextStep(self, iteration, kwargs):
mask = step.mask if hasattr(step, 'mask') else None
# iterate over all fields
for name in solution._fields:
self._updateTracked(name, getattr(solution, name), mask=mask)
self._updateTracked(name, getattr(solution, name), mask=mask, index=index)
continue
return

def _onStart(self, value, kwargs):
index = kwargs['index']
self._iteration.assign(0)
if 'value' in self._tracked: # save initial value
self._initialValue.assign( self._withIndices(value) )
value, idx = self._withIndices(value, index)
# update slice [index:index+sz] with the value
sz = tf.shape(value)[0]
self._initialValue[idx:idx+sz].assign(value)
return

def _withIndices(self, value):
if self._indices is None: return value
return tf.gather(value, self._indices, axis=0)
def _withIndices(self, value, index, mask=None):
if self._indices is None: return value, index
# find subset of indices
sz = tf.shape(value)[0]

validMask = tf.logical_and(index <= self._indices, self._indices < index + sz)
if mask is not None:
maskedIndices = tf.range(sz)
maskedIndices = tf.boolean_mask(maskedIndices, mask) + index
# exclude masked indices
maskedIndices = tf.reduce_any(maskedIndices[:, None] == self._indices[None], axis=0)
validMask = tf.logical_and(validMask, maskedIndices)
pass

startIndex = tf.reduce_min(tf.where(validMask))
startIndex = tf.cast(startIndex, tf.int32)
indices = tf.boolean_mask(self._indices, validMask) - index
return tf.gather(value, indices, axis=0), startIndex
# End of CSamplerWatcher
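The heart of this file's change is _withIndices: the watcher tracks a fixed set of global sample indices, while each sample() call now covers only the window [index, index + batch), so the tracked indices are intersected with that window and shifted to window-local offsets; the returned startIndex (the position of the first selected index within the watcher's index list) appears to serve as the row offset when writing into the tracked buffers. A small numeric sketch of the selection logic, with values chosen only for illustration:

import tensorflow as tf

def select_window(tracked_indices, index, sz):
  # Keep the tracked indices that fall inside [index, index + sz)
  # and express them relative to the start of the current window.
  tracked_indices = tf.convert_to_tensor(tracked_indices, tf.int32)
  valid = tf.logical_and(index <= tracked_indices, tracked_indices < index + sz)
  return tf.boolean_mask(tracked_indices, valid) - index

print(select_window([0, 32, 64, 65], index=32, sz=32).numpy())  # [0]
print(select_window([0, 32, 64, 65], index=64, sz=32).numpy())  # [0 1]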
2 changes: 1 addition & 1 deletion huggingface/HF/NN/CInterpolantVisualization.py
@@ -158,7 +158,7 @@ def _collectSteps(self, points, initialValues, kwargs):
initialValues=initialValues,
reverseArgs=dict(
**reverseArgs,
algorithmInterceptor=watcher.interceptor(),
algorithmInterceptor=watcher,
),
)
N = watcher.iteration + 1
43 changes: 42 additions & 1 deletion tests/test_CSamplerWatcher.py
@@ -41,7 +41,7 @@ def _fake_AR(threshold, timesteps=10):
shape = (32, 3)
fakeNoise = tf.random.normal(shape)
def fakeModel(x, t, mask, **kwargs):
s = tf.boolean_mask(fakeNoise, mask)
s = tf.boolean_mask(fakeNoise, mask) if mask is not None else fakeNoise
return s + tf.cast(t, tf.float32) * x

x = tf.random.normal(shape)
@@ -153,4 +153,45 @@ def test_trackSolutionWithMask_value():
)
_ = fake.interpolant.sample(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_checkTracked(watcher.tracked('value'), N=11)
return

# test multiple calls with index
def test_multipleCallsWithIndex():
fake = _fake_sampler()
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(32*3, 3))
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_ = fake.interpolant.sample(**arg, index=0)
_ = fake.interpolant.sample(**arg, index=32)
_ = fake.interpolant.sample(**arg, index=64)

collectedSteps = watcher.tracked('value')
tf.debugging.assert_equal(tf.shape(collectedSteps)[1], 96, 'Must collect 96 values')
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
# values must be same across (0..32), (32..64), (64..96)
tf.debugging.assert_equal(collectedSteps[:, :32], collectedSteps[:, 32:64])
tf.debugging.assert_equal(collectedSteps[:, 32:64], collectedSteps[:, 64:])
return

# TODO: find out why this test fails, but previous one passes
# test multiple calls with index and mask
def test_multipleCallsWithIndexAndMask():
fake = _fake_AR(threshold=0.1)
watcher = CSamplerWatcher(
steps=10,
tracked=dict(value=(3,)),
indices=[0, 32, 64, 65]
)
arg = dict(value=fake.x, model=fake.model, algorithmInterceptor=watcher.interceptor())
_ = fake.interpolant.sample(**arg, index=0)
_ = fake.interpolant.sample(**arg, index=32)
_ = fake.interpolant.sample(**arg, index=64)

collectedSteps = watcher.tracked('value')
tf.assert_equal(tf.shape(collectedSteps), (11, 4, 3), 'Must be (11, 4, 3)')
tf.debugging.assert_equal(watcher.iteration, 10, 'Must collect 10 steps')
tf.debugging.assert_equal(collectedSteps[:, 0:1], collectedSteps[:, 1:2])
tf.debugging.assert_equal(collectedSteps[:, 1:2], collectedSteps[:, 2:3])
return
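The new tests drive three consecutive sample() calls with index = 0, 32 and 64 against a single watcher, so the tracked buffer should hold 96 rows per step, or 4 rows when only indices [0, 32, 64, 65] are tracked. A quick way to run just these tests, assuming a standard pytest setup (the repository's actual runner may differ):

import pytest
# Run only the index-related tests; -x stops at the first failure.
pytest.main(['tests/test_CSamplerWatcher.py', '-k', 'multipleCallsWithIndex', '-x'])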
