-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implemented a way to extract blured pixels from an image
- Loading branch information
1 parent
5861df0
commit cabb839
Showing
3 changed files
with
249 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import tensorflow as tf | ||
from NN.utils import extractInterpolated | ||
|
||
def create1DGaussian(size, stds, shifts): | ||
B = tf.shape(stds)[0] | ||
tf.assert_equal(tf.shape(shifts), (B, )) | ||
x = tf.linspace(-size // 2 + 1, size // 2 + 1, size) | ||
x = tf.cast(x, tf.float32) | ||
x = tf.tile(x[None], [B, 1]) + shifts[..., None] | ||
x = tf.nn.softmax(-(x ** 2) / (2.0 * (stds ** 2)), axis=-1) | ||
x = tf.reshape(x, [B, size]) | ||
return x | ||
|
||
def gaussian_kernel(size, stdsPx, shifts=None): | ||
if shifts is None: | ||
shifts = tf.zeros((tf.shape(stdsPx)[0], 2)) | ||
|
||
stds = tf.cast(stdsPx, tf.float32) | ||
B = tf.shape(stds)[0] | ||
stds = tf.reshape(stds, [B, 1]) | ||
|
||
gX = create1DGaussian(size, stds, shifts[:, 0])[..., None] | ||
gY = create1DGaussian(size, stds, shifts[:, 1])[..., None, :] | ||
gauss = tf.matmul(gX, gY) | ||
|
||
gauss = tf.reshape(gauss, [B, size, size, 1]) | ||
gauss = tf.tile(gauss, [1, 1, 1, 3]) | ||
tf.assert_equal(tf.shape(gauss), (B, size, size, 3)) | ||
gauss = tf.transpose(gauss, [1, 2, 3, 0]) # [B, size, size, 1] => [size, size, 1, B] | ||
tf.assert_equal(tf.shape(gauss), (size, size, 3, B)) | ||
return gauss | ||
|
||
############################ | ||
# trying to implement more efficient bluring | ||
def shiftsPixels(HW, points): | ||
d = 1.0 / tf.cast(HW, tf.float32) | ||
return points - (tf.floor(points / d) * d + (d / 2.0)) | ||
|
||
def visibleArea(points, HW, size): | ||
HW = tf.cast(HW, tf.float32) | ||
HW = tf.repeat(HW, repeats=2) | ||
HW = tf.reshape(HW, (1, 2)) | ||
points = points * HW | ||
points = tf.floor(points) | ||
points = tf.cast(points, tf.int32) | ||
|
||
HW = tf.cast(HW, tf.int32) | ||
left = tf.maximum(0, points - size) | ||
right = tf.minimum(HW, points + size) | ||
return left, right | ||
|
||
def area2indices(left, right, HW, maxN): | ||
B = tf.shape(left)[0] | ||
LR = tf.concat([left, right], axis=-1) | ||
tf.assert_equal(tf.shape(LR), (B, 4)) | ||
|
||
def f(lr): | ||
l, r = lr[:2], lr[2:] | ||
wh = r - l | ||
w, h = wh[0], wh[1] | ||
# tf.debugging.assert_greater(0, w) | ||
tf.debugging.assert_less_equal(w, maxN) | ||
# tf.debugging.assert_greater(0, h) | ||
tf.debugging.assert_less_equal(h, maxN) | ||
indices = l[0] + tf.range(w) # [minX, maxX] | ||
indices = tf.reshape(indices, [1, -1]) | ||
indices = tf.tile(indices, [h, 1]) | ||
shifts = l[1] + tf.range(h) # [minY, maxY] | ||
indices = indices + shifts[:, None] * HW | ||
tf.assert_equal(tf.shape(indices), (h, w)) | ||
|
||
pad = maxN**2 - tf.size(indices) | ||
indices = tf.reshape(indices, [-1]) | ||
indices = tf.pad(indices, [[0, pad]], constant_values=-1) | ||
return indices | ||
return tf.map_fn(f, LR, dtype=tf.int32) | ||
|
||
def extractBluredX(img, points, R, maxR): | ||
img = img[None] | ||
B = tf.shape(points)[0] | ||
tf.assert_rank(img, 4) | ||
tf.assert_equal(tf.shape(points), (B, 2)) | ||
tf.assert_equal(tf.shape(R), (B, 1)) | ||
H, W = [tf.shape(img)[i] for i in [1, 2]] | ||
tf.assert_equal(H, W, 'Image should be square') | ||
gaussians = gaussian_kernel(maxR, R, shifts=shiftsPixels(H, points)) | ||
gaussians = tf.transpose(gaussians, [3, 0, 1, 2]) # [size, size, 3, B] => [B, size, size, 3] | ||
gaussians = tf.reshape(gaussians, [B, -1, 3]) | ||
sz = tf.shape(gaussians)[1] | ||
# extract areas around the points | ||
# first, find the visible area for each point | ||
left, right = visibleArea(points, H, size=maxR) | ||
# extract the indices of the visible area | ||
indices = area2indices(left, right, H, sz) | ||
tf.assert_equal(tf.shape(indices), (B, sz ** 2)) | ||
# extract the visible areas from the image | ||
flatImg = tf.reshape(img, [1, H * W, 3]) | ||
extracted = tf.gather(flatImg, indices, axis=1)[0] | ||
tf.assert_equal(tf.shape(extracted), (B, sz ** 2, 3)) | ||
|
||
indicesLow = indices[:, 0, None] | ||
extractedWeights = tf.gather(gaussians, indices - indicesLow, batch_dims=1) | ||
tf.assert_equal(tf.shape(extractedWeights), tf.shape(extracted)) | ||
extracted = tf.reduce_sum(extracted * extractedWeights, axis=1) | ||
tf.assert_equal(tf.shape(extracted), (B, 3)) | ||
return extracted | ||
############################ | ||
def applyBluring(img, kernel): | ||
tf.assert_rank(img, 4) | ||
tf.assert_rank(kernel, 4) | ||
tf.assert_equal(tf.shape(img)[0], 1) | ||
B = tf.shape(kernel)[-1] | ||
|
||
imgG = tf.nn.depthwise_conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME')[0] | ||
H, W = [tf.shape(imgG)[i] for i in range(2)] | ||
imgG = tf.reshape(imgG, [H, W, 3, -1]) | ||
imgG = tf.transpose(imgG, (3, 0, 1, 2)) | ||
imgG = tf.reshape(imgG, (B, H, W, 3)) | ||
return imgG | ||
|
||
def extractBlured(R): | ||
''' | ||
R: list of bluring radiuses, (B, 1) | ||
''' | ||
R = tf.reshape(R, (tf.size(R), 1)) | ||
maxR = tf.reduce_max(R) | ||
maxR = tf.cast(maxR, tf.int32) + 1 | ||
gaussians = gaussian_kernel(maxR, R, shifts=tf.zeros((tf.shape(R)[0], 2))) | ||
gaussiansN = tf.shape(gaussians)[-1] | ||
|
||
def f(img, points, ptR): | ||
img = img[None] | ||
tf.assert_rank(img, 4) | ||
B = tf.shape(points)[0] | ||
tf.assert_equal(tf.shape(points), (B, 2)) | ||
tf.assert_equal(tf.shape(ptR), (B, 1)) | ||
blured = applyBluring(img, gaussians) | ||
tf.assert_equal(tf.shape(blured), (gaussiansN, tf.shape(img)[1], tf.shape(img)[2], 3)) | ||
# extract the blured values | ||
blured = extractInterpolated(blured, points[None]) | ||
tf.assert_equal(tf.shape(blured), (gaussiansN, B, 3)) | ||
|
||
# blured contains the blured values for each point in each gaussian | ||
# we need to select the blured value for each point based on its radius | ||
correspondingG = tf.transpose(ptR == R[..., 0]) | ||
tf.assert_equal(tf.shape(correspondingG), (gaussiansN, B)) | ||
|
||
idx = tf.where(correspondingG) | ||
tf.assert_equal(tf.shape(idx), (B, 2)) | ||
blured = tf.gather_nd(blured, idx) | ||
tf.assert_equal(tf.shape(blured), (B, 3)) | ||
return blured | ||
return f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import pytest | ||
import numpy as np | ||
import tensorflow as tf | ||
from NN.utils_bluring import shiftsPixels, visibleArea, area2indices, \ | ||
gaussian_kernel, applyBluring, extractBlured | ||
from NN.utils import extractInterpolated | ||
|
||
# test shiftsPixels | ||
def test_shiftsPixels(): | ||
H = 3 | ||
d = (1.0 / H) | ||
points = np.array([ | ||
[d * 0.9, d * 0.5], [d * 1.1, d * 0.6], [d * 2.3, d * 0.7], [d * 2.8, d * 10.], | ||
]).astype(np.float32) | ||
correct = np.array([ | ||
[0.4, 0.0], [-0.4, 0.1], [-0.2, 0.2], [0.3, 0.5], | ||
]).astype(np.float32) | ||
|
||
shifts = shiftsPixels(H, points).numpy() / d | ||
diff = np.abs(shifts - correct) | ||
assert diff.max() < 1e-5, '%s != %s' % (shifts, correct) | ||
return | ||
|
||
# test visibleArea | ||
def test_visibleArea(): | ||
HW = 10 | ||
size = 7 | ||
points = np.array([ | ||
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2] | ||
]).astype(np.float32) | ||
left, right = visibleArea(points, HW, size) | ||
left = left.numpy() | ||
right = right.numpy() | ||
correctLeft = np.array([ | ||
[2, 2], [0, 0], [6, 6], [0, 6], [6, 0], [4, 0], | ||
]) | ||
correctRight = np.array([ | ||
[ 8, 8], [ 4, 4], [10, 10], [ 4, 10], [10, 4], [10, 5] | ||
]) | ||
assert np.allclose(left, correctLeft), '%s != %s' % (left, correctLeft) | ||
assert np.allclose(right, correctRight), '%s != %s' % (right, correctRight) | ||
return | ||
|
||
# test area2indices | ||
def test_area2indices(): | ||
HW = 10 | ||
indices = area2indices([[0, 0], [2, 1]], [[2, 3], [4, 5]], HW) | ||
correct = np.array([ | ||
[ 0, 1, | ||
0 + HW * 1, 1 + HW * 1, | ||
0 + HW * 2, 1 + HW * 2, | ||
] + (HW ** 2 - 6) * [-1], | ||
[ 2 + HW * 1, 3 + HW * 1, | ||
2 + HW * 2, 3 + HW * 2, | ||
2 + HW * 3, 3 + HW * 3, | ||
2 + HW * 4, 3 + HW * 4, | ||
] + (HW ** 2 - (4 - 2) * (5 - 1)) * [-1], | ||
]) | ||
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct): | ||
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b) | ||
continue | ||
return | ||
|
||
# test area2indices full | ||
def test_area2indices_full(): | ||
hwOld = 10 | ||
indices = area2indices([[0, 0]], [[hwOld, hwOld]], hwOld) | ||
correct = np.array([range(hwOld ** 2)]) | ||
for i, a, b in zip(range(correct.shape[0]), indices.numpy(), correct): | ||
assert np.allclose(a, b), '%d | %s != %s' % (i, a, b) | ||
continue | ||
return | ||
###################################################################### | ||
# extractBlured same as extractInterpolated after applying gaussians | ||
# TODO: find out why the results are so numerically instable | ||
def test_extractBlured_same(): | ||
H = 10 | ||
img = np.random.rand(H, H, 3).astype(np.float32) | ||
points = np.array([ | ||
[0.5, 0.5], [0.1, 0.1], [0.9, 0.9], [0.1, 0.9], [0.9, 0.1], [0.7, 0.2] | ||
]).astype(np.float32) | ||
R = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32) | ||
imgBlured = applyBluring(img[None], gaussian_kernel(H, R)) | ||
assert imgBlured.shape == (6, H, H, 3) | ||
|
||
blur = extractBlured(R) | ||
blured = blur(img, points, ptR=np.full((6, 1), 1.0)) | ||
|
||
extracted = extractInterpolated(imgBlured[0, None], points[None])[0] | ||
assert extracted.shape == blured.shape | ||
for i, a, b in zip(range(extracted.shape[0]), extracted, blured): | ||
diff = np.abs(a - b).max() | ||
assert diff < 5e-2, '%d | %s != %s' % (i, a, b) | ||
continue | ||
return |