-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils_bluring.py
100 lines (87 loc) · 3.37 KB
/
utils_bluring.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import tensorflow as tf
from NN.utils import extractInterpolated
def create1DGaussian(size, stds, shifts):
    '''
    Sample a batch of normalized 1D Gaussians on a unit-spaced pixel grid.

    size: number of taps (Python int or scalar int32 tensor).
    stds: standard deviations in pixels; broadcast against a (B, size) grid,
          so expected shape (B, 1) as passed by gaussian_kernel.
    shifts: (B,) offsets added to the grid coordinates of each row (asserted).
    Returns a float32 (B, size) tensor; each row is normalized with a softmax,
    so it sums to 1.

    NOTE(review): a zero std makes the centre tap 0/0 = NaN, which poisons the
    whole softmax row; callers must pass strictly positive stds.
    '''
    B = tf.shape(stds)[0]
    tf.assert_equal(tf.shape(shifts), (B, ))
    # Integer tap offsets -((size - 1) // 2) .. size // 2. tf.range excludes its
    # limit, so exactly `size` unit-spaced values are produced and offset 0
    # falls on tap (size - 1) // 2 (the centre tap for odd sizes). Do not use
    # tf.linspace here: it includes the stop value, which stretches the spacing
    # to size / (size - 1) pixels and moves the peak off the centre tap.
    x = tf.range(-size // 2 + 1, size // 2 + 1)
    x = tf.cast(x, tf.float32)
    # cast so int / float64 inputs can be combined with the float32 grid
    stds = tf.cast(stds, tf.float32)
    shifts = tf.cast(shifts, tf.float32)
    x = tf.tile(x[None], [B, 1]) + shifts[..., None]
    # softmax(-x^2 / (2 std^2)) == exp(-x^2 / (2 std^2)) normalized over taps
    x = tf.nn.softmax(-(x ** 2) / (2.0 * (stds ** 2)), axis=-1)
    x = tf.reshape(x, [B, size])
    return x
def gaussian_kernel(size, stdsPx, shifts=None):
    '''
    Build B separable 2D Gaussian kernels laid out as a depthwise filter bank.

    size: side length of every kernel (Python int or scalar int32 tensor).
    stdsPx: per-kernel standard deviations in pixels; cast to float32 and
            reshaped to (B, 1), so presumably shape (B,) or (B, 1).
    shifts: optional (B, 2) centre offsets; column 0 is applied along the first
            kernel axis, column 1 along the second. Defaults to zeros.
    Returns a float32 tensor of shape (size, size, 3, B): kernel b is the outer
    product of two normalized 1D Gaussians, replicated over 3 input channels.
    '''
    if shifts is None:
        shifts = tf.zeros((tf.shape(stdsPx)[0], 2))
    stds = tf.cast(stdsPx, tf.float32)
    nKernels = tf.shape(stds)[0]
    stds = tf.reshape(stds, [nKernels, 1])
    # (B, size, 1) @ (B, 1, size) -> per-kernel outer product (B, size, size)
    alongFirst = create1DGaussian(size, stds, shifts[:, 0])[:, :, None]
    alongSecond = create1DGaussian(size, stds, shifts[:, 1])[:, None, :]
    kernels2d = tf.matmul(alongFirst, alongSecond)
    # add a channel axis and copy it to the 3 colour channels
    kernels = tf.tile(kernels2d[..., None], [1, 1, 1, 3])
    tf.assert_equal(tf.shape(kernels), (nKernels, size, size, 3))
    # depthwise_conv2d filters are (height, width, in_channels, multiplier):
    # (B, size, size, 3) -> (size, size, 3, B)
    kernels = tf.transpose(kernels, [1, 2, 3, 0])
    tf.assert_equal(tf.shape(kernels), (size, size, 3, nKernels))
    return kernels
def applyBluring(img, kernel):
    '''
    Blur one image with every kernel of a depthwise filter bank.

    img: rank-4 batch whose leading dimension is asserted to be 1; the result
         is reshaped with 3 channels, so presumably (1, H, W, 3).
    kernel: rank-4 depthwise filter, presumably (kh, kw, 3, B) as produced by
            gaussian_kernel; its last axis (B) is the channel multiplier.
    Returns a (B, H, W, 3) tensor whose slice b is img convolved (stride 1,
    'SAME' padding) with kernel[..., b].
    '''
    tf.assert_rank(img, 4)
    tf.assert_rank(kernel, 4)
    tf.assert_equal(tf.shape(img)[0], 1)
    nKernels = tf.shape(kernel)[-1]
    convolved = tf.nn.depthwise_conv2d(
        img, kernel, strides=[1, 1, 1, 1], padding='SAME'
    )
    convolved = convolved[0]
    height = tf.shape(convolved)[0]
    width = tf.shape(convolved)[1]
    # depthwise output channel index is channel * multiplier + kernelIdx,
    # so splitting the last axis as (3, -1) gives (..., channel, kernelIdx)
    perChannel = tf.reshape(convolved, [height, width, 3, -1])
    stacked = tf.transpose(perChannel, (3, 0, 1, 2))
    return tf.reshape(stacked, (nKernels, height, width, 3))
def extractBlured(R):
    '''
    Build a sampler that reads Gaussian-blurred colours at arbitrary points.

    R: blur radii; flattened to (N, 1), one Gaussian kernel per element with
       std (in pixels) equal to that radius. All kernels share the side length
       int(max(R)) + 1.
       NOTE(review): the kernel side is about R + 1 taps while R is the std, so
       large blurs are strongly truncated (softmax renormalizes the window).
       Confirm this is intended.

    Returns f(img, points, ptR):
      img: a single image, batched internally with img[None].
      points: (B, 2) sample locations, forwarded to extractInterpolated.
      ptR: (B, 1) radius per point; each value must equal some element of R.
      returns (B, 3): row i is points[i] sampled from the image blurred with
      a kernel whose radius equals ptR[i].
    '''
    R = tf.reshape(R, (tf.size(R), 1))
    maxR = tf.reduce_max(R)
    maxR = tf.cast(maxR, tf.int32) + 1
    # the kernels depend only on R, so build them once for every call of f
    gaussians = gaussian_kernel(maxR, R, shifts=tf.zeros((tf.shape(R)[0], 2)))
    gaussiansN = tf.shape(gaussians)[-1]

    def f(img, points, ptR):
        img = img[None]
        tf.assert_rank(img, 4)
        B = tf.shape(points)[0]
        tf.assert_equal(tf.shape(points), (B, 2))
        tf.assert_equal(tf.shape(ptR), (B, 1))
        blured = applyBluring(img, gaussians)
        tf.assert_equal(tf.shape(blured), (gaussiansN, tf.shape(img)[1], tf.shape(img)[2], 3))
        # extract the blured values
        blured = extractInterpolated(blured, points[None])
        tf.assert_equal(tf.shape(blured), (gaussiansN, B, 3))
        # blured[g, i] is point i under kernel g; for each point pick a kernel
        # whose radius equals ptR[i]. matches is (B, gaussiansN) by broadcasting.
        matches = tf.equal(ptR, R[..., 0])
        tf.assert_equal(tf.shape(matches), (B, gaussiansN))
        matchCount = tf.reduce_sum(tf.cast(matches, tf.int32), axis=-1)
        # every point must have at least one kernel with its radius
        tf.debugging.assert_greater(matchCount, 0)
        # Any matching kernel will do: equal radii (with the zero shifts used
        # above) build identical kernels. Do not gather via tf.where on the
        # kernel-major mask: it returns coordinates in row-major order, which
        # would order the output rows by kernel instead of by point.
        kernelIdx = tf.argmax(tf.cast(matches, tf.int32), axis=-1, output_type=tf.int32)
        idx = tf.stack([kernelIdx, tf.range(B)], axis=-1)
        blured = tf.gather_nd(blured, idx)
        tf.assert_equal(tf.shape(blured), (B, 3))
        return blured

    return f
def extractBluredOne(maxR=None):
    '''
    Build a sampler that blurs with a single Gaussian chosen per call.

    maxR: kernel side length used for every call. When None, the side length
          is recomputed on each call as tf.cast(max(R), tf.int32) + 1.

    Returns f(img, points, R):
      img: a single image, batched internally with img[None].
      points: (B, 2) sample locations, forwarded to extractInterpolated.
      R: Gaussian std in pixels; reshaped to (1,), so exactly one value.
      returns (B, 3) values sampled from the blurred image.
    '''
    def kernelSize(radius):
        # a fixed size wins; otherwise derive it from the requested radius
        if maxR is not None:
            return maxR
        return tf.cast(tf.reduce_max(radius), tf.int32) + 1

    def f(img, points, R):
        batched = img[None]
        tf.assert_rank(batched, 4)
        nPoints = tf.shape(points)[0]
        tf.assert_equal(tf.shape(points), (nPoints, 2))
        R = tf.reshape(R, (1, ))
        kernel = gaussian_kernel(kernelSize(R), R)
        blurred = applyBluring(batched, kernel)
        nKernels = tf.shape(kernel)[-1]
        expectedShape = (nKernels, tf.shape(batched)[1], tf.shape(batched)[2], 3)
        tf.assert_equal(tf.shape(blurred), expectedShape)
        # a single kernel, so drop its leading axis after sampling
        sampled = extractInterpolated(blurred, points[None])
        sampled = sampled[0]
        tf.assert_equal(tf.shape(sampled), (nPoints, 3))
        return sampled

    return f