fast base/detail decomposition with bilateral filtering
teboli committed Oct 4, 2022
1 parent c0baa0b commit 85263c0
Showing 5 changed files with 218 additions and 66 deletions.
4 changes: 2 additions & 2 deletions polyblur/blur_estimation.py
@@ -186,7 +186,7 @@ def compute_gaussian_parameters(magnitudes_normal, magnitudes_ortho, c, b):
     sigma = torch.sqrt(sigma)
     ## Compute rho
     rho = cc / (magnitudes_ortho * magnitudes_ortho + 1e-8) - bb
-    rho = torch.clamp(sigma, min=0.09, max=16.0)
+    rho = torch.clamp(rho, min=0.09, max=16.0)
     rho = torch.sqrt(rho)
     return sigma, rho

@@ -228,7 +228,7 @@ def create_gaussian_filter(thetas, sigmas, rhos, ksize):
     INV_SIGMA = INV_SIGMA.view(B, C, 1, 1, 2, 2)  # (B,C,1,1,2,2)

     # Create meshgrid for Gaussian
-    t = torch.arange(ksize, device=sigmas.device)
+    t = torch.arange(ksize, device=sigmas.device) - ((ksize-1) // 2)
     X, Y = torch.meshgrid(t, t, indexing='xy')
     Z = torch.stack([X, Y], dim=-1).unsqueeze(-1).float()  # (k,k,2,1)
     Z_t = Z.transpose(-2, -1)  # (k,k,1,2)
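A quick standalone sketch (not from the repo) of what the centering fix buys: with t running over 0..ksize-1 the Gaussian peak sits in a corner of the kernel, while subtracting (ksize-1)//2 puts it at the center, so the estimated blur no longer shifts the image:

import torch

ksize, sigma = 7, 1.5
t = torch.arange(ksize) - ((ksize - 1) // 2)   # [-3, ..., 3] instead of [0, ..., 6]
X, Y = torch.meshgrid(t, t, indexing='xy')
kernel = torch.exp(-(X * X + Y * Y).float() / (2 * sigma * sigma))
kernel /= kernel.sum()
print(kernel.flatten().argmax().item() == (ksize * ksize) // 2)  # True: peak at the center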
18 changes: 9 additions & 9 deletions polyblur/deblurring.py
@@ -3,7 +3,6 @@
 import torch.fft
 import torch.nn as nn
 import torch.nn.functional as F
-from .filters import fourier_gradients

 from . import edgetaper
 from . import filters
@@ -53,7 +52,7 @@ def polyblur_deblurring(img, n_iter=1, c=0.352, b=0.768, alpha=2, beta=3, sigma_
     ## Init the variables
     start = time()
     impred = img
-    grad_img = fourier_gradients(img)
+    grad_img = filters.fourier_gradients(img)
     thetas = torch.linspace(0, 180, n_angles+1, device=img.device).unsqueeze(0).long()  # (1,n)
     interpolated_thetas = torch.arange(0, 180, 180 / n_interpolated_angles, device=img.device).unsqueeze(0).long()  # (1,N)
     print('-- warming up: %1.5f' % (time() - start))
@@ -95,7 +94,8 @@ def edge_aware_filtering(img, sigma_s, sigma_r):
     :param sigma_s: float, smoothness parameter for domain transform
     :return: img_smoothed, img_noise: torch.tensors of same size as img, the smooth and noise components of img
     """
-    img_smoothed = domain_transform.recursive_filter(img, sigma_r=sigma_r, sigma_s=sigma_s, num_iterations=1)
+    # img_smoothed = domain_transform.recursive_filter(img, sigma_r=sigma_r, sigma_s=sigma_s, num_iterations=1)
+    img_smoothed = filters.bilateral_filter(img)
    img_noise = img - img_smoothed
    return img_smoothed, img_noise
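This hunk is the base/detail decomposition named in the commit title: the bilateral filter now produces the smooth base, and the detail layer is the exact residual. A hypothetical usage sketch (sigma values illustrative; note that after this change sigma_s and sigma_r are accepted but no longer forwarded, since bilateral_filter runs with its defaults):

import torch

img = torch.rand(1, 3, 128, 128)
base, detail = edge_aware_filtering(img, sigma_s=2.0, sigma_r=0.1)
assert torch.allclose(base + detail, img)  # exact by construction: detail = img - base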

@@ -148,8 +148,8 @@ def compute_polynomial_fft(img, kernel, alpha, b, not_symmetric=False):
         Y = C * Y
     ## C implements the pure phase filter described in the Polyblur article,
     ## needed to deblur non-symmetric kernels. For Gaussian (symmetric) kernels it has no effect.
-    a3 = alpha / 2 - b + 2;
-    a2 = 3 * b - alpha - 6;
+    a3 = alpha / 2 - b + 2
+    a2 = 3 * b - alpha - 6
     a1 = 5 - 3 * b + alpha / 2
     X = a3 * Y
     X = K * X + a2 * Y
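The coefficients a3, a2 and a1 define the rank-3 deblurring polynomial, evaluated by Horner's scheme: each step multiplies the running result by the blur operator K and adds the next coefficient times Y. A minimal sketch of that pattern (the hunk is truncated above, so the closing Horner step is an assumption, not the repo's exact code):

import torch

def horner_rank3(K, Y, alpha=2.0, b=4.0):
    # Horner evaluation of the deblurring polynomial in the Fourier domain
    a3 = alpha / 2 - b + 2
    a2 = 3 * b - alpha - 6
    a1 = 5 - 3 * b + alpha / 2
    X = a3 * Y
    X = K * X + a2 * Y  # (a3*K + a2) * Y
    X = K * X + a1 * Y  # assumed final step: (a3*K**2 + a2*K + a1) * Y
    return X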
@@ -161,7 +161,7 @@ def compute_polynomial_fft(img, kernel, alpha, b, not_symmetric=False):

 @torch.jit.script
 def grad_prod_(grad_x, grad_y, gout_x, gout_y):
-    return -grad_x * gout_x + -grad_y * grad_y
+    return (- grad_x * gout_x) + (- grad_y * gout_y)


 @torch.jit.script
@@ -188,10 +188,10 @@ def halo_masking(img, imout, grad_img=None):
     :return torch.tensor of same size as img, the halo corrected image(s)
     """
     if grad_img is None:
-        grad_x, grad_y = fourier_gradients(img)
+        grad_x, grad_y = filters.fourier_gradients(img)
     else:
         grad_x, grad_y = grad_img
-    gout_x, gout_y = fourier_gradients(imout)
+    gout_x, gout_y = filters.fourier_gradients(imout)
     M = grad_prod_(grad_x, grad_y, gout_x, gout_y)
     nM = torch.sum(grad_square_(grad_x, grad_y), dim=(-2, -1), keepdim=True)
     z = grad_div_and_clip_(M, nM)
@@ -222,9 +222,9 @@ def inverse_filtering_rank3(img, kernel, alpha=2, b=4, correlate=False, remove_h
     imout = compute_polynomial(img, kernel, alpha, b, method=method)
     ## Crop
     imout = utils.crop_with_kernel(imout, kernel)
-    img = utils.crop_with_kernel(img, kernel)
     ## Mask deblurring halos
     if remove_halo:
+        img = utils.crop_with_kernel(img, kernel)
         imout = halo_masking(img, imout, grad_img)
     return torch.clamp(imout, 0.0, 1.0)

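A hypothetical end-to-end call (shapes and the box-blur stand-in kernel are illustrative, not from the repo). After this change the input is only cropped when halo masking actually needs it:

import torch
from polyblur import deblurring  # assuming this repo's package layout

img = torch.rand(1, 3, 128, 128)
kernel = torch.ones(1, 1, 7, 7) / 49.0  # box-blur stand-in for a Gaussian kernel
out = deblurring.inverse_filtering_rank3(img, kernel, alpha=2, b=4, remove_halo=True)
print(out.shape)  # cropped by the kernel half-width on each side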
234 changes: 185 additions & 49 deletions polyblur/filters.py
@@ -4,29 +4,200 @@
 import torch.fft
 import scipy.fft

+from . import utils
+# TODO: import separable_gaussian_kernels

-def convolve2d(img, kernel, padding='same', method='direct'):

+#####################################################################
+######################## Convolution 2D #############################
+#####################################################################
+
+def convolve2d(img, kernel, ksize=25, padding='same', method='direct'):
+    # TODO: replace 1d kernels by (sigma,rho,theta)
     """
     A per kernel wrapper for torch.nn.functional.conv2d
     :param img: (B,C,H,W) torch.tensor, the input images
-    :param kernel: (B,C,h,w) or (B,1,h,w) torch.tensor, the blur kernels
+    :param kernel: (B,C,h,w) or
+                   (B,1,h,w) torch.tensor, the 2d blur kernels (valid for both deblurring methods), or
+                   [(B,C,h), (B,C,w)] or
+                   [(B,1,h), (B,1,w)], the separable 1d blur kernels (valid only for spatial deblurring)
     :param padding: string, can be 'valid' or 'same'
     :return imout: (B,C,H,W) torch.tensor, the filtered images
     """
     if method == 'direct':
-        if kernel.shape[1] == img.shape[1]:
-            return F.conv2d(img, kernel, groups=img.shape[1], padding=padding)
-        else:
-            imout = [F.conv2d(img[:,c:c+1], kernel, padding=padding) for c in range(img.shape[1])]
-            return torch.cat(imout, dim=1)
+        if type(kernel) == torch.Tensor:  # if we have 2D kernels, do general 2D convolution
+            return conv2d_(img, kernel, padding)
+        else:  # else, do Gaussian-specific 1D separable convolution
+            return gaussian_separable_conv2d_(img, kernel, ksize, padding)
     elif method == 'fft':
-        ks = kernel.shape[-1]
-        X = torch.fft.fft2(F.pad(img, [ks, ks, ks, ks], mode='circular'))
+        assert(type(kernel) == torch.Tensor)  # for FFT, we only use 2D kernels
+        X = torch.fft.fft2(utils.pad_with_kernel(img, kernel, mode='circular'))
         K = p2o(kernel, X.shape[-2:])
-        return torch.real(torch.fft.ifft2(K * X))[..., ks:-ks, ks:-ks]
+        return utils.crop_with_kernel(torch.real(torch.fft.ifft2(K * X)), kernel)
     else:
-        raise('%s is not implemented' % method)
+        raise('Convolution method %s is not implemented' % method)
+
+
+def conv2d_(img, kernel, padding='same'):
+    """
+    Wrapper for F.conv2d with possibly multi-channel kernels.
+    """
+    if kernel.shape[1] == img.shape[1]:  # the kernel has as many channels as the image
+        return F.conv2d(img, kernel, groups=kernel.shape[1], padding=padding)
+    else:  # single-channel kernel applied to each color channel
+        img = [F.conv2d(img[:, c:c+1], kernel, padding=padding) for c in range(img.shape[1])]
+        return torch.cat(img, dim=1)


+def gaussian_separable_conv2d_(img, kernel, ksize, padding='same', threshold=1e-4):
+    """
+    Convolution with separable 1D Gaussian kernels on possibly non-orthogonal axes.
+    """
+    sigma, rho, theta = kernel
+
+    ## First, process the orthogonal directions
+    mask = (theta % (np.pi / 2)) < threshold  # theta is 0, 90 or 180 degrees: orthogonal directions
+    if mask.any():
+        sigma_x = sigma[mask]
+        sigma_y = rho[mask]
+        img[mask] = gaussian_xy_separable_conv2d_(img[mask], sigma_x, sigma_y, ksize, padding)
+
+    ## Second, process the other directions
+    mask = (theta % (np.pi / 2)) >= threshold  # general case
+    if mask.any():
+        sigma_u = sigma[mask]
+        sigma_v = rho[mask]
+        phi = theta[mask]
+        img[mask] = gaussian_xt_separable_conv2d_(img[mask], sigma_u, sigma_v, phi, ksize)
+
+    return img


+def gaussian_xy_separable_conv2d_(img, sigma_x, sigma_y, ksize, padding='same'):
+    # Create the 1D kernel along x
+    t = torch.arange(-ksize//2 + 1, ksize//2 + 1, device=img.device).view(1, 1, 1, ksize)
+    t = t * t
+    kernel = torch.exp(- t / (2 * sigma_x * sigma_x))  # (1, 1, 1, ksize)
+    kernel /= kernel.sum()
+
+    # Horizontal filter
+    img = F.conv2d(img, kernel, padding='same', groups=img.shape[1])
+    img = img.transpose(-1, -2)
+
+    # Create the 1D kernel along y
+    kernel = torch.exp(- t / (2 * sigma_y * sigma_y))  # (1, 1, 1, ksize)
+    kernel /= kernel.sum()
+
+    # Vertical filter
+    img = F.conv2d(img, kernel, padding='same', groups=img.shape[1])
+    return img.transpose(-1, -2)


+def gaussian_xt_separable_conv2d_(img, sigma, rho, theta, ksize):
+    ## Call CUDA code here
+    return img
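For reference, a standalone sanity check (not part of the commit) of the separability identity these helpers rely on: an isotropic 2D Gaussian blur equals two orthogonal 1D passes:

import torch
import torch.nn.functional as F

def gaussian_1d(sigma, ksize=25):
    t = torch.arange(ksize).float() - (ksize - 1) / 2
    k = torch.exp(-t * t / (2 * sigma * sigma))
    return k / k.sum()

k = gaussian_1d(2.0)  # (25,)
img = torch.rand(1, 1, 64, 64)
out_2d = F.conv2d(img, torch.outer(k, k).view(1, 1, 25, 25), padding='same')
out_xy = F.conv2d(img, k.view(1, 1, 1, 25), padding='same')     # horizontal pass
out_xy = F.conv2d(out_xy, k.view(1, 1, 25, 1), padding='same')  # vertical pass
assert torch.allclose(out_2d, out_xy, atol=1e-5)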



+#####################################################################
+####################### Bilateral filter ############################
+#####################################################################
+
+
+def bilateral_filter(I, ksize=7, sigma_spatial=5.0, sigma_color=0.1):
+    ## Precompute the spatial kernel: each entry of gw is a squared spatial difference
+    t = torch.arange(-ksize//2 + 1, ksize//2 + 1, device=I.device)
+    xx, yy = torch.meshgrid(t, t, indexing='xy')
+    gw = torch.exp(-(xx * xx + yy * yy) / (2 * sigma_spatial * sigma_spatial))  # (ksize, ksize)
+
+    ## Create the padded array for computing the color shifts
+    I_padded = utils.pad_with_kernel(I, ksize=ksize)
+
+    ## Filtering
+    var2_color = 2 * sigma_color * sigma_color
+    return bilateral_filter_loop_(I, I_padded, gw, var2_color)


+def bilateral_filter_loop_(I, I_padded, gw, var2, do_for=True):
+    b, c, h, w = I.shape
+
+    if do_for:  # memory-friendly option (recommended for larger images)
+        J = torch.zeros_like(I)
+        W = torch.zeros_like(I)
+        for z in range(gw.shape[0] * gw.shape[1]):
+            # compute the 2d indices of the shift
+            x = z % gw.shape[0]
+            y = (z - x) // gw.shape[1]
+            yy = y + h
+            xx = x + w
+            # get the shifted image
+            I_shifted = I_padded[..., y:yy, x:xx]
+            # color weight
+            F = I_shifted - I  # (B,C,H,W)
+            F = torch.exp(-F * F / var2)
+            # product with the spatial weight
+            F *= gw[y, x]  # (B,C,H,W)
+            J += F * I_shifted
+            W += F
+    else:  # pytorch-friendly option (faster for smaller images and/or batch sizes)
+        # get all shifted images at once
+        I_shifted = utils.extract_tiles(I_padded, kernel_size=(h, w), stride=1)  # (B,ksize*ksize,C,H,W)
+        F = I_shifted - I.unsqueeze(1)
+        F = torch.exp(-F * F / var2)  # (B,ksize*ksize,C,H,W)
+        # product with the spatial weights
+        F *= gw.view(-1, 1, 1, 1)
+        J = torch.sum(F * I_shifted, dim=1)  # (B,C,H,W)
+        W = torch.sum(F, dim=1)  # (B,C,H,W)
+    return J / (W + 1e-8)
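A quick property check (illustrative, assuming the package is importable in this state): the weights normalize to one, so a constant image passes through unchanged, while a random image is smoothed toward its base layer:

import torch
from polyblur import filters  # assuming this repo's package layout

flat = torch.ones(1, 3, 32, 32)
assert torch.allclose(filters.bilateral_filter(flat), flat, atol=1e-4)

noisy = torch.rand(1, 3, 64, 64)
base = filters.bilateral_filter(noisy, ksize=7, sigma_spatial=5.0, sigma_color=0.1)
print((noisy - base).abs().mean())  # the residual detail layer has small amplitude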




+#####################################################################
+###################### Classical filters ############################
+#####################################################################
+
+
+@torch.jit.script
+def fourier_gradients(images):
+    """
+    Compute the image gradients using Fourier interpolation as in Eq. (21a) and (21b)
+    :param images: (B,C,H,W) torch.tensor
+    :return grad_x, grad_y: tuple of 2 images of same dimensions as images that
+                            are the vertical and horizontal gradients
+    """
+    ## Find fast size for FFT
+    h, w = images.shape[-2:]
+    h_fast, w_fast = images.shape[-2:]
+    # h_fast = scipy.fft.next_fast_len(h)
+    # w_fast = scipy.fft.next_fast_len(w)
+    ## Compute FT
+    U = torch.fft.fft2(images, s=(h_fast, w_fast))
+    U = torch.fft.fftshift(U, dim=(-2, -1))
+    ## Create the frequency components
+    freqh = (torch.arange(0, h_fast, device=images.device) - h_fast // 2)[None, None, :, None] / h_fast
+    freqw = (torch.arange(0, w_fast, device=images.device) - w_fast // 2)[None, None, None, :] / w_fast
+    ## Compute gradients in Fourier domain
+    gxU = 2 * np.pi * freqw * (-torch.imag(U) + 1j * torch.real(U))
+    gxU = torch.fft.ifftshift(gxU, dim=(-2, -1))
+    gxu = torch.real(torch.fft.ifft2(gxU))
+    # gxu = crop(gxu, (h, w))
+    gyU = 2 * np.pi * freqh * (-torch.imag(U) + 1j * torch.real(U))
+    gyU = torch.fft.ifftshift(gyU, dim=(-2, -1))
+    gyu = torch.real(torch.fft.ifft2(gyU))
+    # gyu = crop(gyu, (h, w))
+    return gxu, gyu
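An illustrative correctness check for the spectral derivative (not in the commit; assumes the package imports): for a periodic sinusoid along the width, the Fourier gradient should match the analytic derivative:

import numpy as np
import torch
from polyblur import filters  # assuming this repo's package layout

n = torch.arange(64).float()
img = torch.sin(2 * np.pi * n / 64).view(1, 1, 1, 64).repeat(1, 1, 64, 1)
gx, gy = filters.fourier_gradients(img)
expected = (2 * np.pi / 64) * torch.cos(2 * np.pi * n / 64)
assert torch.allclose(gx[0, 0, 0], expected, atol=1e-4)  # d/dx sin = cos, scaled by 2*pi/N
assert gy.abs().max() < 1e-4  # no variation along the height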


+def crop(image, new_size):
+    size = image.shape[-2:]
+    if size[0] - new_size[0] > 0:
+        image = image[..., :new_size[0], :]
+    if size[1] - new_size[1] > 0:
+        image = image[..., :new_size[1]]
+    return image


 def gaussian_filter(sigma, theta, shift=np.array([0.0, 0.0]), k_size=np.array([15, 15])):
@@ -79,48 +250,13 @@ def dirac(dims):
     return kernel


-def crop(image, new_size):
-    size = image.shape[-2:]
-    if size[0] - new_size[0] > 0:
-        image = image[..., :new_size[0], :]
-    if size[1] - new_size[1] > 0:
-        image = image[..., :new_size[1]]
-    return image


-@torch.jit.script
-def fourier_gradients(images):
-    """
-    Compute the image gradients using Fourier interpolation as in Eq. (21a) and (21b)
-    :param images: (B,C,H,W) torch.tensor
-    :return grad_x, grad_y: tuple of 2 images of same dimensions as images that
-                            are the vertical and horizontal gradients
-    """
-    ## Find fast size for FFT
-    h, w = images.shape[-2:]
-    h_fast, w_fast = images.shape[-2:]
-    # h_fast = scipy.fft.next_fast_len(h)
-    # w_fast = scipy.fft.next_fast_len(w)
-    ## Compute FT
-    U = torch.fft.fft2(images, s=(h_fast, w_fast))
-    U = torch.fft.fftshift(U, dim=(-2, -1))
-    ## Create the frequency components
-    freqh = (torch.arange(0, h_fast, device=images.device) - h_fast // 2)[None, None, :, None] / h_fast
-    freqw = (torch.arange(0, w_fast, device=images.device) - w_fast // 2)[None, None, None, :] / w_fast
-    ## Compute gradients in Fourier domain
-    gxU = 2 * np.pi * freqw * (-torch.imag(U) + 1j * torch.real(U))
-    gxU = torch.fft.ifftshift(gxU, dim=(-2, -1))
-    gxu = torch.real(torch.fft.ifft2(gxU))
-    # gxu = crop(gxu, (h, w))
-    gyU = 2 * np.pi * freqh * (-torch.imag(U) + 1j * torch.real(U))
-    gyU = torch.fft.ifftshift(gyU, dim=(-2, -1))
-    gyu = torch.real(torch.fft.ifft2(gyU))
-    # gyu = crop(gyu, (h, w))
-    return gxu, gyu
+#####################################################################
+####################### Fourier kernel ##############################
+#####################################################################


 ### From here, taken from https://github.com/cszn/USRNet/blob/master/utils/utils_deblur.py

 def p2o(psf, shape):
     '''
     Convert point-spread function to optical transfer function.
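The collapsed p2o helper (taken from the USRNet utilities linked above) converts a point-spread function into an optical transfer function. A minimal sketch of the standard construction, stated as an assumption since the body is cut off here: zero-pad the PSF to the image size, circularly shift it so the kernel center lands at the origin, then take the FFT:

import torch

def p2o_sketch(psf, shape):
    # psf: (..., h, w) kernel, shape: (H, W) target spatial size
    otf = torch.zeros(psf.shape[:-2] + shape, dtype=psf.dtype)
    otf[..., :psf.shape[-2], :psf.shape[-1]] = psf
    otf = torch.roll(otf, shifts=(-(psf.shape[-2] // 2), -(psf.shape[-1] // 2)), dims=(-2, -1))
    return torch.fft.fft2(otf)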
26 changes: 21 additions & 5 deletions polyblur/utils.py
@@ -45,11 +45,27 @@ def to_uint(img):
     return img


-def pad_with_kernel(img, kernel):
-    ks = kernel.shape[-1] // 2
-    return F.pad(img, (ks, ks, ks, ks), mode='replicate')
+def pad_with_kernel(img, kernel=None, ksize=3, mode='replicate'):
+    if kernel is not None:
+        ks = kernel.shape[-1] // 2
+    else:
+        ks = ksize // 2
+    return F.pad(img, (ks, ks, ks, ks), mode=mode)


-def crop_with_kernel(img, kernel):
-    ks = kernel.shape[-1] // 2
+def crop_with_kernel(img, kernel=None, ksize=3):
+    if kernel is not None:
+        ks = kernel.shape[-1] // 2
+    else:
+        ks = ksize // 2
     return img[..., ks:-ks, ks:-ks]


+def extract_tiles(img, kernel_size, stride=1):
+    b, c, _, _ = img.shape
+    h, w = kernel_size
+    tiles = F.unfold(img, kernel_size, stride=stride)  # (B,C*H*W,L)
+    tiles = tiles.permute(0, 2, 1)  # (B,L,C*H*W)
+    tiles = tiles.view(b, -1, c, h, w)
+    return tiles
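An illustrative shape check for the new extract_tiles helper (names as in this commit; this is how bilateral_filter_loop_ materializes all shifted views at once):

import torch
from polyblur.utils import pad_with_kernel, extract_tiles  # assuming this package layout

img = torch.rand(2, 3, 16, 16)
padded = pad_with_kernel(img, ksize=7)               # (2, 3, 22, 22)
tiles = extract_tiles(padded, kernel_size=(16, 16))  # one 16x16 view per kernel offset
print(tiles.shape)  # torch.Size([2, 49, 3, 16, 16])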

2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,5 +5,5 @@ scikit-image
 seaborn
 tqdm
 glob2
-torch==1.11.0
+ninja
+torch==1.11.0
