From 24d88d04b6a56e2293afdc59ae4623681fd7aaa3 Mon Sep 17 00:00:00 2001 From: GreenWizard Date: Wed, 10 Apr 2024 12:29:15 +0200 Subject: [PATCH] make R starting from 0.0 --- NN/RestorationModel/CRestorationModel.py | 2 +- Utils/CroppingAugm.py | 28 +++++++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/NN/RestorationModel/CRestorationModel.py b/NN/RestorationModel/CRestorationModel.py index 4ed4df6..6aec5c8 100644 --- a/NN/RestorationModel/CRestorationModel.py +++ b/NN/RestorationModel/CRestorationModel.py @@ -56,7 +56,7 @@ def _addResiduals(self, latents, residuals): return latents - def _addRadius(self, latents, R=None, fakeR=1e-5, training=False): + def _addRadius(self, latents, R=None, fakeR=0.0, training=False): B = tf.shape(latents)[0] if self._blurRadiusEncoder is not None: if R is None: diff --git a/Utils/CroppingAugm.py b/Utils/CroppingAugm.py index 395994e..d5cddf6 100644 --- a/Utils/CroppingAugm.py +++ b/Utils/CroppingAugm.py @@ -36,6 +36,7 @@ def SubsampleProcessor(target_crop_size, N, extras=[], sampler='uniform'): withBlur = blurConfig is not None if withBlur: blurRange = blurConfig['min'] + tf.linspace(0.0, blurConfig['max'], blurConfig['N']) + minR = tf.reduce_min(blurRange) blurN = tf.size(blurRange) blurShared = blurConfig.get('shared', False) if blurShared: @@ -60,20 +61,21 @@ def _FF(img): sobel = extractInterpolated(sobel, positions) res['sobel'] = tf.reshape(sobel, [N, 6]) - if withBlur and blurShared: - idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32) - R = tf.gather(blurRange, idx) - R = tf.reshape(R, (1,)) + if withBlur: + if blurShared: + idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32) + R = tf.gather(blurRange, idx) + R = tf.reshape(R, (1,)) + R = tf.fill([N, 1], R[0]) + else: + idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32) + R = tf.gather(blurRange, idx) + R = tf.reshape(R, (N, 1)) + pass + + tf.assert_equal(tf.shape(R), (N, 1)) res['blured'] = blur(src, positions[0], R) - res['blur R'] = tf.fill([N, 1], R[0]) - pass - - if withBlur and not blurShared: - idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32) - R = tf.gather(blurRange, idx) - R = tf.reshape(R, (N, 1)) - res['blured'] = blur(src, positions[0], R) - res['blur R'] = R + res['blur R'] = R - minR # ensure that R is starting from 0.0 pass return res return _FF