diff --git a/utils/generator.py b/utils/generator.py index e5326c1..87a8936 100644 --- a/utils/generator.py +++ b/utils/generator.py @@ -131,6 +131,7 @@ def regenerate_cache(self): (self.max_size, self.max_size), method=tf.image.ResizeMethod.BICUBIC) + pattern = pattern / 255 pattern = tf.cast(tf.math.less(pattern, self.density), tf.uint8) self.pattern = pattern @@ -148,7 +149,7 @@ def __call__(self, inputs, density_std=0.05): y = self.rng.randint(0, self.max_size - height + 1, size=len(idx)) for i, lx, ly in zip(idx, x, y): res[i] = self.pattern[lx: lx + width, ly: ly + height][None] - coverage = float(res[i, 0].mean()) + coverage = float(res[i, :, :, 0].mean()) if not (self.density - density_std < coverage < self.density + density_std): nw_idx.append(i)