Skip to content

Commit

Permalink
implemented a way to extract blured pixels from an image
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 9, 2024
1 parent 5861df0 commit cabb839
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 34 deletions.
35 changes: 1 addition & 34 deletions NN/utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,9 @@
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp
import tensorflow.keras.layers as L
import tensorflow as tf
import tensorflow_probability as tfp

def gaussian_kernel(size, stdsPx):
stds = tf.cast(stdsPx, tf.float32) / size
B = tf.shape(stds)[0]
stds = tf.reshape(stds, [B, 1])
x = tf.linspace(-size // 2 + 1, size // 2 + 1, size)
x = tf.cast(x ** 2, tf.float32)
x = tf.tile(x[None], [B, 1])
x = tf.nn.softmax(-x / (2.0 * (stds**2)))
x = tf.matmul(x[:, :, None], x[:, None, :])
gauss = tf.reshape(x, [B, size, size, 1])
gauss = tf.repeat(gauss, 3, axis=-1)
gauss = tf.repeat(gauss, 3, axis=0)
gauss = tf.transpose(gauss, [1, 2, 3, 0]) # [B, size, size, 1] => [size, size, 1, B]
return gauss

def masked(x, mask):
'''
very weird hack to apply a mask to the tensor and ensure that the shape is preserved
Expand Down Expand Up @@ -156,21 +140,4 @@ def is_namedtuple(obj) -> bool:
isinstance(obj, tuple) and
hasattr(obj, '_asdict') and
hasattr(obj, '_fields')
)

if '__main__' == __name__:
print('Utils test')
print('All tests passed successfully!')
import cv2
import tensorflow_addons as tfa

img = cv2.imread('d:\photo_2024-03-26_15-48-33.jpg')
img = img.astype('float32') / 255.0
gaussians = gaussian_kernel(48, tf.constant([10., 20., 30.]))
imgG = tf.nn.conv2d(img[None], gaussians, strides=[1, 1, 1, 1], padding='SAME')[0]

for i in [3, 6, 9]:
g = imgG[..., i-3:i].numpy()
g = (g * 255.0).astype('uint8')
cv2.imshow('Gaussian %i' % i, g)
cv2.waitKey(0)
)
153 changes: 153 additions & 0 deletions NN/utils_bluring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import tensorflow as tf
from NN.utils import extractInterpolated

def create1DGaussian(size, stds, shifts):
B = tf.shape(stds)[0]
tf.assert_equal(tf.shape(shifts), (B, ))
x = tf.linspace(-size // 2 + 1, size // 2 + 1, size)
x = tf.cast(x, tf.float32)
x = tf.tile(x[None], [B, 1]) + shifts[..., None]
x = tf.nn.softmax(-(x ** 2) / (2.0 * (stds ** 2)), axis=-1)
x = tf.reshape(x, [B, size])
return x

def gaussian_kernel(size, stdsPx, shifts=None):
if shifts is None:
shifts = tf.zeros((tf.shape(stdsPx)[0], 2))

stds = tf.cast(stdsPx, tf.float32)
B = tf.shape(stds)[0]
stds = tf.reshape(stds, [B, 1])

gX = create1DGaussian(size, stds, shifts[:, 0])[..., None]
gY = create1DGaussian(size, stds, shifts[:, 1])[..., None, :]
gauss = tf.matmul(gX, gY)

gauss = tf.reshape(gauss, [B, size, size, 1])
gauss = tf.tile(gauss, [1, 1, 1, 3])
tf.assert_equal(tf.shape(gauss), (B, size, size, 3))
gauss = tf.transpose(gauss, [1, 2, 3, 0]) # [B, size, size, 1] => [size, size, 1, B]
tf.assert_equal(tf.shape(gauss), (size, size, 3, B))
return gauss

############################
# trying to implement more efficient bluring
def shiftsPixels(HW, points):
d = 1.0 / tf.cast(HW, tf.float32)
return points - (tf.floor(points / d) * d + (d / 2.0))

def visibleArea(points, HW, size):
HW = tf.cast(HW, tf.float32)
HW = tf.repeat(HW, repeats=2)
HW = tf.reshape(HW, (1, 2))
points = points * HW
points = tf.floor(points)
points = tf.cast(points, tf.int32)

HW = tf.cast(HW, tf.int32)
left = tf.maximum(0, points - size)
right = tf.minimum(HW, points + size)
return left, right

def area2indices(left, right, HW, maxN):
B = tf.shape(left)[0]
LR = tf.concat([left, right], axis=-1)
tf.assert_equal(tf.shape(LR), (B, 4))

def f(lr):
l, r = lr[:2], lr[2:]
wh = r - l
w, h = wh[0], wh[1]
# tf.debugging.assert_greater(0, w)
tf.debugging.assert_less_equal(w, maxN)
# tf.debugging.assert_greater(0, h)
tf.debugging.assert_less_equal(h, maxN)
indices = l[0] + tf.range(w) # [minX, maxX]
indices = tf.reshape(indices, [1, -1])
indices = tf.tile(indices, [h, 1])
shifts = l[1] + tf.range(h) # [minY, maxY]
indices = indices + shifts[:, None] * HW
tf.assert_equal(tf.shape(indices), (h, w))

pad = maxN**2 - tf.size(indices)
indices = tf.reshape(indices, [-1])
indices = tf.pad(indices, [[0, pad]], constant_values=-1)
return indices
return tf.map_fn(f, LR, dtype=tf.int32)

def extractBluredX(img, points, R, maxR):
img = img[None]
B = tf.shape(points)[0]
tf.assert_rank(img, 4)
tf.assert_equal(tf.shape(points), (B, 2))
tf.assert_equal(tf.shape(R), (B, 1))
H, W = [tf.shape(img)[i] for i in [1, 2]]
tf.assert_equal(H, W, 'Image should be square')
gaussians = gaussian_kernel(maxR, R, shifts=shiftsPixels(H, points))
gaussians = tf.transpose(gaussians, [3, 0, 1, 2]) # [size, size, 3, B] => [B, size, size, 3]
gaussians = tf.reshape(gaussians, [B, -1, 3])
sz = tf.shape(gaussians)[1]
# extract areas around the points
# first, find the visible area for each point
left, right = visibleArea(points, H, size=maxR)
# extract the indices of the visible area
indices = area2indices(left, right, H, sz)
tf.assert_equal(tf.shape(indices), (B, sz ** 2))
# extract the visible areas from the image
flatImg = tf.reshape(img, [1, H * W, 3])
extracted = tf.gather(flatImg, indices, axis=1)[0]
tf.assert_equal(tf.shape(extracted), (B, sz ** 2, 3))

indicesLow = indices[:, 0, None]
extractedWeights = tf.gather(gaussians, indices - indicesLow, batch_dims=1)
tf.assert_equal(tf.shape(extractedWeights), tf.shape(extracted))
extracted = tf.reduce_sum(extracted * extractedWeights, axis=1)
tf.assert_equal(tf.shape(extracted), (B, 3))
return extracted
############################
def applyBluring(img, kernel):
tf.assert_rank(img, 4)
tf.assert_rank(kernel, 4)
tf.assert_equal(tf.shape(img)[0], 1)
B = tf.shape(kernel)[-1]

imgG = tf.nn.depthwise_conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME')[0]
H, W = [tf.shape(imgG)[i] for i in range(2)]
imgG = tf.reshape(imgG, [H, W, 3, -1])
imgG = tf.transpose(imgG, (3, 0, 1, 2))
imgG = tf.reshape(imgG, (B, H, W, 3))
return imgG

def extractBlured(R):
'''
R: list of bluring radiuses, (B, 1)
'''
R = tf.reshape(R, (tf.size(R), 1))
maxR = tf.reduce_max(R)
maxR = tf.cast(maxR, tf.int32) + 1
gaussians = gaussian_kernel(maxR, R, shifts=tf.zeros((tf.shape(R)[0], 2)))
gaussiansN = tf.shape(gaussians)[-1]

def f(img, points, ptR):
img = img[None]
tf.assert_rank(img, 4)
B = tf.shape(points)[0]
tf.assert_equal(tf.shape(points), (B, 2))
tf.assert_equal(tf.shape(ptR), (B, 1))
blured = applyBluring(img, gaussians)
tf.assert_equal(tf.shape(blured), (gaussiansN, tf.shape(img)[1], tf.shape(img)[2], 3))
# extract the blured values
blured = extractInterpolated(blured, points[None])
tf.assert_equal(tf.shape(blured), (gaussiansN, B, 3))

# blured contains the blured values for each point in each gaussian
# we need to select the blured value for each point based on its radius
correspondingG = tf.transpose(ptR == R[..., 0])
tf.assert_equal(tf.shape(correspondingG), (gaussiansN, B))

idx = tf.where(correspondingG)
tf.assert_equal(tf.shape(idx), (B, 2))
blured = tf.gather_nd(blured, idx)
tf.assert_equal(tf.shape(blured), (B, 3))
return blured
return f
95 changes: 95 additions & 0 deletions tests/test_utils_bluring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest
import numpy as np
import tensorflow as tf
from NN.utils_bluring import shiftsPixels, visibleArea, area2indices, \
gaussian_kernel, applyBluring, extractBlured
from NN.utils import extractInterpolated

# test shiftsPixels
def test_shiftsPixels():
H = 3
d = (1.0 / H)
points = np.array([
[d * 0.9, d * 0.5], [d * 1.1, d * 0.6], [d * 2.3, d * 0.7], [d * 2.8, d * 10.],
]).astype(np.float32)
correct = np.array([
[0.4, 0.0], [-0.4, 0.1], [-0.2, 0.2], [0.3, 0.5],
]).astype(np.float32)

shifts = shiftsPixels(H, points).numpy() / d
diff = np.abs(shifts - correct)
assert diff.max() < 1e-5, '%s != %s' % (shifts, correct)
return

# test visibleArea
def test_visibleArea():
HW = 10
size = 7
points = np.array([
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2]
]).astype(np.float32)
left, right = visibleArea(points, HW, size)
left = left.numpy()
right = right.numpy()
correctLeft = np.array([
[2, 2], [0, 0], [6, 6], [0, 6], [6, 0], [4, 0],
])
correctRight = np.array([
[ 8, 8], [ 4, 4], [10, 10], [ 4, 10], [10, 4], [10, 5]
])
assert np.allclose(left, correctLeft), '%s != %s' % (left, correctLeft)
assert np.allclose(right, correctRight), '%s != %s' % (right, correctRight)
return

# test area2indices
def test_area2indices():
HW = 10
indices = area2indices([[0, 0], [2, 1]], [[2, 3], [4, 5]], HW)
correct = np.array([
[ 0, 1,
0 + HW * 1, 1 + HW * 1,
0 + HW * 2, 1 + HW * 2,
] + (HW ** 2 - 6) * [-1],
[ 2 + HW * 1, 3 + HW * 1,
2 + HW * 2, 3 + HW * 2,
2 + HW * 3, 3 + HW * 3,
2 + HW * 4, 3 + HW * 4,
] + (HW ** 2 - (4 - 2) * (5 - 1)) * [-1],
])
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct):
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b)
continue
return

# test area2indices full
def test_area2indices_full():
hwOld = 10
indices = area2indices([[0, 0]], [[hwOld, hwOld]], hwOld)
correct = np.array([range(hwOld ** 2)])
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct):
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b)
continue
return
######################################################################
# extractBlured same as extractInterpolated after applying gaussians
# TODO: find out why the results are so numerically instable
def test_extractBlured_same():
H = 10
img = np.random.rand(H, H, 3).astype(np.float32)
points = np.array([
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2]
]).astype(np.float32)
R = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32)
imgBlured = applyBluring(img[None], gaussian_kernel(H, R))
assert imgBlured.shape == (6, H, H, 3)

blur = extractBlured(R)
blured = blur(img, points, ptR=np.full((6, 1), 1.0))

extracted = extractInterpolated(imgBlured[0, None], points[None])[0]
assert extracted.shape == blured.shape
for i, a, b in zip(range(extracted.shape[0]), extracted, blured):
diff = np.abs(a - b).max()
assert diff < 5e-2, '%d | %s != %s' % (i, a, b)
continue
return

0 comments on commit cabb839

Please sign in to comment.