Skip to content

Commit

Permalink
blured impl
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 10, 2024
1 parent 1eda483 commit f89a872
Show file tree
Hide file tree
Showing 13 changed files with 176 additions and 172 deletions.
44 changes: 35 additions & 9 deletions NN/RestorationModel/CRestorationModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class CRestorationModel(tf.keras.Model):
def __init__(self,
decoder, restorator,
posEncoder, timeEncoder,
blurRadiusEncoder=None,
residualCondition=False,
**kwargs
):
Expand All @@ -20,6 +21,9 @@ def __init__(self,
self._posEncoder = posEncoder
self._timeEncoder = timeEncoder
self._residualCondition = residualCondition
self._blurRadiusEncoder = blurRadiusEncoder
print('[CRestorationModel] Residuals: ', residualCondition)
print('[CRestorationModel] Blur: ', blurRadiusEncoder is not None)
return

def _encodeTime(self, t, training):
Expand All @@ -45,15 +49,39 @@ def _withResidual(self, value, residual):
res = tf.concat([mainRes + residual, rest], axis=-1)
tf.assert_equal(tf.shape(res), shp)
return res

def _addResiduals(self, latents, residuals):
if self._residualCondition:
latents = tf.concat([latents, residuals], axis=-1)

return latents

def _addRadius(self, latents, R=None, fakeR=1e-5, training=False):
B = tf.shape(latents)[0]
if self._blurRadiusEncoder is not None:
if R is None:
R = tf.constant([[fakeR]], dtype=tf.float32)
encodedR = self._blurRadiusEncoder(R, training=training)
encodedR = tf.tile(encodedR, [B, 1])
else:
encodedR = self._blurRadiusEncoder(R, training=training)

def call(self, latents, pos, T, V, residual, training=None):
latents = tf.concat([latents, encodedR], axis=-1)

return latents

def call(self, latents, pos, T, V, residual, R=None, training=False):
EPos = self._encodePos(pos, training=training, args={})
t = self._encodeTime(T, training=training)
latents = self._addResiduals(latents, residual)
latents = self._addRadius(latents, R=R, training=training)
res = self._decoder(condition=latents, coords=EPos, timestep=t, V=V, training=training)
return self._withResidual(res, residual)

def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
EPos = self._encodePos(pos, training=training, args=reverseArgs.get('decoder', {}))
latents = self._addResiduals(latents, residual)
latents = self._addRadius(latents, R=None, training=training)

def denoiser(x, t, mask=None):
args = dict(condition=latents, coords=EPos, timestep=t, V=x)
Expand All @@ -62,13 +90,6 @@ def denoiser(x, t, mask=None):
args = {k: masked(v, mask) for k, v in args.items()}
residuals = masked(residual, mask)

if self._residualCondition: # add residuals to the condition
B = tf.shape(residuals)[0]
args['condition'] = tf.reshape( # ensure that we know the shape of the condition
tf.concat([args['condition'], residuals], axis=-1),
[B, -1]
)

res = self._decoder(**args, training=training)
return self._withResidual(res, residuals)

Expand All @@ -81,14 +102,19 @@ def denoiser(x, t, mask=None):

def train_step(self, x0, latents, positions, params, xT=None):
residual = params['residual']
params = {k: v for k, v in params.items() if k not in ['residual']}
R = None
if self._blurRadiusEncoder is not None:
x0 = params['blured'] # replace the input with the blured values
R = params['blur R']
params = {k: v for k, v in params.items() if k not in ['residual', 'blured', 'blur R']}
# defer training to the restorator
return self._restorator.train_step(
x0=x0,
xT=xT,
model=lambda T, V: self(
latents=latents, pos=positions,
T=T, V=V, residual=residual,
R=R,
training=True
),
**params
Expand Down
5 changes: 5 additions & 0 deletions NN/RestorationModel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ def restorationModel_from_config(config):

if 'basic' == name:
restorator = restorator_from_config(config['restorator'])
blurRadiusEncoder = None
if 'blur radius encoding' in config:
blurRadiusEncoder = encoding_from_config(config['blur radius encoding'])

return CRestorationModel(
decoder=decoder_from_config(config['decoder'], channels=restorator.channels),
restorator=restorator,
posEncoder=encoding_from_config(config['position encoding']),
timeEncoder=encoding_from_config(config['time encoding']),
blurRadiusEncoder=blurRadiusEncoder,
residualCondition=config.get('residual condition', False),
)

Expand Down
93 changes: 18 additions & 75 deletions NN/utils_bluring.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,81 +30,6 @@ def gaussian_kernel(size, stdsPx, shifts=None):
tf.assert_equal(tf.shape(gauss), (size, size, 3, B))
return gauss

############################
# trying to implement more efficient bluring
def shiftsPixels(HW, points):
d = 1.0 / tf.cast(HW, tf.float32)
return points - (tf.floor(points / d) * d + (d / 2.0))

def visibleArea(points, HW, size):
HW = tf.cast(HW, tf.float32)
HW = tf.repeat(HW, repeats=2)
HW = tf.reshape(HW, (1, 2))
points = points * HW
points = tf.floor(points)
points = tf.cast(points, tf.int32)

HW = tf.cast(HW, tf.int32)
left = tf.maximum(0, points - size)
right = tf.minimum(HW, points + size)
return left, right

def area2indices(left, right, HW, maxN):
B = tf.shape(left)[0]
LR = tf.concat([left, right], axis=-1)
tf.assert_equal(tf.shape(LR), (B, 4))

def f(lr):
l, r = lr[:2], lr[2:]
wh = r - l
w, h = wh[0], wh[1]
# tf.debugging.assert_greater(0, w)
tf.debugging.assert_less_equal(w, maxN)
# tf.debugging.assert_greater(0, h)
tf.debugging.assert_less_equal(h, maxN)
indices = l[0] + tf.range(w) # [minX, maxX]
indices = tf.reshape(indices, [1, -1])
indices = tf.tile(indices, [h, 1])
shifts = l[1] + tf.range(h) # [minY, maxY]
indices = indices + shifts[:, None] * HW
tf.assert_equal(tf.shape(indices), (h, w))

pad = maxN**2 - tf.size(indices)
indices = tf.reshape(indices, [-1])
indices = tf.pad(indices, [[0, pad]], constant_values=-1)
return indices
return tf.map_fn(f, LR, dtype=tf.int32)

def extractBluredX(img, points, R, maxR):
img = img[None]
B = tf.shape(points)[0]
tf.assert_rank(img, 4)
tf.assert_equal(tf.shape(points), (B, 2))
tf.assert_equal(tf.shape(R), (B, 1))
H, W = [tf.shape(img)[i] for i in [1, 2]]
tf.assert_equal(H, W, 'Image should be square')
gaussians = gaussian_kernel(maxR, R, shifts=shiftsPixels(H, points))
gaussians = tf.transpose(gaussians, [3, 0, 1, 2]) # [size, size, 3, B] => [B, size, size, 3]
gaussians = tf.reshape(gaussians, [B, -1, 3])
sz = tf.shape(gaussians)[1]
# extract areas around the points
# first, find the visible area for each point
left, right = visibleArea(points, H, size=maxR)
# extract the indices of the visible area
indices = area2indices(left, right, H, sz)
tf.assert_equal(tf.shape(indices), (B, sz ** 2))
# extract the visible areas from the image
flatImg = tf.reshape(img, [1, H * W, 3])
extracted = tf.gather(flatImg, indices, axis=1)[0]
tf.assert_equal(tf.shape(extracted), (B, sz ** 2, 3))

indicesLow = indices[:, 0, None]
extractedWeights = tf.gather(gaussians, indices - indicesLow, batch_dims=1)
tf.assert_equal(tf.shape(extractedWeights), tf.shape(extracted))
extracted = tf.reduce_sum(extracted * extractedWeights, axis=1)
tf.assert_equal(tf.shape(extracted), (B, 3))
return extracted
############################
def applyBluring(img, kernel):
tf.assert_rank(img, 4)
tf.assert_rank(kernel, 4)
Expand Down Expand Up @@ -150,4 +75,22 @@ def f(img, points, ptR):
blured = tf.gather_nd(blured, idx)
tf.assert_equal(tf.shape(blured), (B, 3))
return blured
return f

def extractBluredOne():
def f(img, points, R):
img = img[None]
tf.assert_rank(img, 4)
B = tf.shape(points)[0]
tf.assert_equal(tf.shape(points), (B, 2))
R = tf.reshape(R, (1, ))
maxR = tf.cast(R, tf.int32) + 1
gaussians = gaussian_kernel(maxR[0], R)
blured = applyBluring(img, gaussians)
gaussiansN = tf.shape(gaussians)[-1]
tf.assert_equal(tf.shape(blured), (gaussiansN, tf.shape(img)[1], tf.shape(img)[2], 3))
# extract the blured values
blured = extractInterpolated(blured, points[None])[0]
tf.assert_equal(tf.shape(blured), (B, 3))
return blured
return f
15 changes: 8 additions & 7 deletions Utils/CImageProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,11 @@ def _dest(self, cropped):
assert 'sampled' in cropped, 'Invalid cropped: %s' % cropped
assert 'positions' in cropped, 'Invalid cropped: %s' % cropped
sampled = cropped['sampled']
res = dict(
sampled=sampled,
positions=cropped['positions']
)
res = dict(**cropped)
for extra in self._extras:
if 'grayscale' == extra:
res['grayscale'] = tf.image.rgb_to_grayscale(sampled)
continue
if 'sobel' == extra: # copy sobel edges
res['sobel'] = cropped['sobel']
continue
return res

Expand All @@ -68,12 +63,18 @@ def process(self, config_or_image):
'shared crops': True,
'crop size': None,
'subsample': False,
'blur': False,
}
extras = self._extras
if isConfig:
args = dict(args, **config_or_image)
blur = args.get('blur', False)
if blur:
blur['name'] = 'blured'
extras = extras + [blur]

cropper = CroppingAugm.configToCropper(
args, dest_size=self._imageSize, extras=self._extras,
args, dest_size=self._imageSize, extras=extras
)
def _process(img):
self._checkInput(img)
Expand Down
35 changes: 35 additions & 0 deletions Utils/CroppingAugm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from Utils.utils import CFakeObject
from Utils.PositionsSampler import PositionsSampler
from NN.utils import extractInterpolated
from NN.utils_bluring import extractBlured, extractBluredOne

def CropsProcessor(F, signature):
return CFakeObject(F=F, signature=signature)
Expand All @@ -27,6 +28,22 @@ def SubsampleProcessor(target_crop_size, N, extras=[], sampler='uniform'):
assert N > 0, 'Invalid N: %s' % N
sampler = PositionsSampler(sampler)
resizer = _resizeTo(target_crop_size)(None)

blurConfig = next(
(e for e in extras if isinstance(e, dict) and ('blured' == e['name'])),
None
)
withBlur = blurConfig is not None
if withBlur:
blurRange = blurConfig['min'] + tf.linspace(0.0, blurConfig['max'], blurConfig['N'])
blurN = tf.size(blurRange)
blurShared = blurConfig.get('shared', False)
if blurShared:
blur = extractBluredOne()
else:
blur = extractBlured(blurRange)
pass

def _F(dest_size=None): # dest_size is ignored
def _FF(img):
img = tf.cast(img, tf.float32)
Expand All @@ -43,12 +60,30 @@ def _FF(img):
sobel = extractInterpolated(sobel, positions)
res['sobel'] = tf.reshape(sobel, [N, 6])

if withBlur and blurShared:
idx = tf.random.uniform((1,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (1,))
res['blured'] = blur(src, positions[0], R)
res['blur R'] = tf.fill([N, 1], R[0])
pass

if withBlur and not blurShared:
idx = tf.random.uniform((N,), minval=0, maxval=blurN, dtype=tf.int32)
R = tf.gather(blurRange, idx)
R = tf.reshape(R, (N, 1))
res['blured'] = blur(src, positions[0], R)
res['blur R'] = R
pass
return res
return _FF

signature = dict(src=tf.float32, sampled=tf.float32, positions=tf.float32)
if 'sobel' in extras:
signature['sobel'] = tf.float32
if withBlur:
signature['blured'] = tf.float32
signature['blur R'] = tf.float32
return CropsProcessor(F=_F, signature=signature)
#############
# Cropping methods
Expand Down
12 changes: 12 additions & 0 deletions configs/experiments/blur/basic-shared.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"dataset": {
"train": {
"blur": {
"min": 1e-5,
"max": 8,
"N": 100,
"shared": true
}
}
}
}
12 changes: 12 additions & 0 deletions configs/experiments/blur/basic.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"dataset": {
"train": {
"blur": {
"min": 1e-5,
"max": 8,
"N": 100,
"shared": false
}
}
}
}
7 changes: 7 additions & 0 deletions configs/experiments/encoder/with-positions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model": {
"encoder": {
"positions": true
}
}
}
7 changes: 7 additions & 0 deletions configs/experiments/restorator/use-blur.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model": {
"restorator": {
"blur radius encoding": "from:configs/models/encodings.json"
}
}
}
7 changes: 7 additions & 0 deletions configs/experiments/restorator/use-residuals.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"model": {
"restorator": {
"residual condition": true
}
}
}
8 changes: 8 additions & 0 deletions configs/models/celeba.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
"image size": 64,
"channels": 1
},
"decoder": {
"inherit": "models/decoder/mlp.json",
"channels": 3
},
"restorator": {
"inherit": "models/restorator/single-pass.json",
"channels": 3
},
"renderer": "from:renderer.json",
"nerf": {
"name": "basic",
Expand Down
Loading

0 comments on commit f89a872

Please sign in to comment.