Skip to content

Commit

Permalink
make R starting from 0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 10, 2024
1 parent 5ca411a commit 24d88d0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion NN/RestorationModel/CRestorationModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _addResiduals(self, latents, residuals):

return latents

def _addRadius(self, latents, R=None, fakeR=1e-5, training=False):
def _addRadius(self, latents, R=None, fakeR=0.0, training=False):
B = tf.shape(latents)[0]
if self._blurRadiusEncoder is not None:
if R is None:
Expand Down
28 changes: 15 additions & 13 deletions Utils/CroppingAugm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def SubsampleProcessor(target_crop_size, N, extras=[], sampler='uniform'):
withBlur = blurConfig is not None
if withBlur:
blurRange = blurConfig['min'] + tf.linspace(0.0, blurConfig['max'], blurConfig['N'])
minR = tf.reduce_min(blurRange)
blurN = tf.size(blurRange)
blurShared = blurConfig.get('shared', False)
if blurShared:
Expand All @@ -60,20 +61,21 @@ def _FF(img):
sobel = extractInterpolated(sobel, positions)
res['sobel'] = tf.reshape(sobel, [N, 6])

if withBlur and blurShared:
idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (1,))
if withBlur:
if blurShared:
idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (1,))
R = tf.fill([N, 1], R[0])
else:
idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (N, 1))
pass

tf.assert_equal(tf.shape(R), (N, 1))
res['blured'] = blur(src, positions[0], R)
res['blur R'] = tf.fill([N, 1], R[0])
pass

if withBlur and not blurShared:
idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (N, 1))
res['blured'] = blur(src, positions[0], R)
res['blur R'] = R
res['blur R'] = R - minR # ensure that R is starting from 0.0
pass
return res
return _FF
Expand Down

0 comments on commit 24d88d0

Please sign in to comment.