Skip to content

Commit

Permalink
xxxxxxx
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 5, 2024
1 parent f9d52c5 commit c33735f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 19 deletions.
5 changes: 3 additions & 2 deletions NN/RestorationModel/CRestorationModel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
from NN.utils import masked

'''
wrapper which combines the encoders, decoder and restorator together
Expand Down Expand Up @@ -56,8 +57,8 @@ def denoiser(x, t, mask=None):
args = dict(condition=latents, coords=EPos, timestep=t, V=x)
residuals = residual
if mask is not None:
args = {k: tf.boolean_mask(v, mask) for k, v in args.items()}
residuals = tf.boolean_mask(residual, mask)
args = {k: masked(v, mask) for k, v in args.items()}
residuals = masked(residual, mask)

res = self._decoder(**args, training=training)
return self._withResidual(res, residuals)
Expand Down
2 changes: 1 addition & 1 deletion NN/restorators/samplers/CARSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _currentValueHandler(threshold):
return lambda step: step.xt

def currentValueF(step, **kwargs):
return tf.boolean_mask(step.xt, step.mask, axis=0)
return NNU.masked(step.xt, step.mask)
return currentValueF

# crate a closure that will be used to postprocess the value
Expand Down
49 changes: 33 additions & 16 deletions NN/restorators/samplers/CSamplerWatcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf
from NN.utils import is_namedtuple
import NN.utils as NNU
from .CSamplingInterceptor import CSamplingInterceptor
from .ISamplerWatcher import ISamplerWatcher

Expand Down Expand Up @@ -44,9 +45,11 @@ def _updateTracked(self, name, value, mask=None, index=None, iteration=None):
if tracked is None: return
src, dest = self._withIndices(
value, index, mask=mask,
masked=(mask is not None) and not ('value' == name)
masked=not ('value' == name)
)

tf.print('-'*80)
tf.print(name, src, dest, summarize=-1)
tf.print('N', tf.reduce_sum(tf.cast(mask, tf.int32)))
self._move(value, src, tracked[iteration], dest)
return

Expand Down Expand Up @@ -77,30 +80,44 @@ def _onStart(self, value, kwargs):
return

def _withIndices(self, value, index, mask=None, masked=False):
srcIndex = tf.range(tf.shape(value)[0])
if masked:
srcIndex = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask)

N = tf.shape(value)[0]
srcIndex = tf.range(N)
destIndex = index + srcIndex
if mask is not None:
srcIndex = tf.boolean_mask(srcIndex, mask)
destIndex = tf.boolean_mask(destIndex, mask)
if mask is not None: # use mask
N = tf.reduce_sum(tf.cast(mask, tf.int32))
destIndex = index + tf.cast(tf.where(mask), tf.int32)
destIndex = tf.reshape(destIndex, (N,))
if masked:
rng = tf.range(tf.shape(mask)[0])
srcIndex = NNU.masked(rng, mask)
pass
pass

if self._indices is not None:
mask = tf.reduce_any(self._indices[None] == destIndex[:, None], axis=0)
tf.assert_equal(tf.shape(mask), tf.shape(self._indices))
# collect only valid indices
srcIndex = tf.boolean_mask(self._indices, mask) - index
indices = tf.reshape(self._indices, (1, -1))
destIndex = tf.reshape(destIndex, (-1, 1))
correspondence = indices == destIndex
tf.assert_rank(correspondence, 2)
mask_ = tf.reduce_any(correspondence, axis=0)
tf.assert_equal(tf.shape(mask_), tf.shape(self._indices))
# collect destination indices
destIndex = tf.where(mask)
destIndex = tf.where(mask_)
destIndex = tf.cast(destIndex, tf.int32)
# find corresponding source indices
mask_ = tf.reduce_any(correspondence, axis=-1)
tf.assert_equal(tf.shape(mask), tf.shape(srcIndex))
srcIndex = tf.where(mask_)
srcIndex = tf.cast(srcIndex, tf.int32)
N = tf.reduce_sum(tf.cast(mask_, tf.int32))
pass

srcIndex = tf.reshape(srcIndex, (-1, 1))
destIndex = tf.reshape(destIndex, (-1, 1))
srcIndex = tf.reshape(srcIndex, (N, 1))
destIndex = tf.reshape(destIndex, (N, 1))
return srcIndex, destIndex

def _move(self, src, srcIndex, dest, destIndex):
tf.print(tf.shape(src), tf.shape(srcIndex), tf.shape(dest), tf.shape(destIndex))
tf.print(srcIndex, destIndex, summarize=-1)
src = tf.gather_nd(src, srcIndex) # collect only valid indices
res = tf.tensor_scatter_nd_update(dest, destIndex, src)
dest.assign(res)
Expand Down
8 changes: 8 additions & 0 deletions NN/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
import tensorflow_probability as tfp
import tensorflow.keras.layers as L

def masked(x, mask):
'''
very weird hack to apply a mask to the tensor and ensure that the shape is preserved
'''
N = tf.reduce_sum(tf.cast(mask, tf.int32))
x = tf.boolean_mask(x, mask, axis=0)
return tf.reshape(x, tf.concat([[N], tf.shape(x)[1:]], axis=0))

def shuffleBatch(batch):
indices = tf.range(tf.shape(batch)[0])
indices = tf.random.shuffle(indices)
Expand Down

0 comments on commit c33735f

Please sign in to comment.