Skip to content

Commit

Permalink
add grayscale judgement (sczhou#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
sczhou committed Sep 14, 2022
1 parent 4d598f8 commit bddee53
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
1 change: 1 addition & 0 deletions basicsr/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ def crop_border(imgs, crop_border):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]

8 changes: 7 additions & 1 deletion facelib/utils/face_restoration_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from facelib.detection import init_detection_model
from facelib.parsing import init_parsing_model
from facelib.utils.misc import img2tensor, imwrite
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray


def get_largest_face(det_faces, h, w):
Expand Down Expand Up @@ -125,6 +125,9 @@ def read_image(self, img):
img = img[:, :, 0:3]

self.input_img = img
self.is_gray = is_gray(img, threshold=5)
if self.is_gray:
print('Grayscale input: True')

if min(self.input_img.shape[:2])<512:
f = 512.0/min(self.input_img.shape[:2])
Expand Down Expand Up @@ -416,6 +419,9 @@ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box
fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int')
inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask)

if self.is_gray:
pasted_face = bgr2gray(pasted_face) # convert img into grayscale

if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
alpha = upsample_img[:, :, 3:]
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
Expand Down
33 changes: 33 additions & 0 deletions facelib/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import cv2
import os
import os.path as osp
import numpy as np
from PIL import Image
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
Expand Down Expand Up @@ -139,3 +141,34 @@ def _scandir(dir_path, suffix, recursive):
continue

return _scandir(dir_path, suffix=suffix, recursive=recursive)


def is_gray(img, threshold=10):
img = Image.fromarray(img)
if len(img.getbands()) == 1:
return True
img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
diff1 = (img1 - img2).var()
diff2 = (img2 - img3).var()
diff3 = (img3 - img1).var()
diff_sum = (diff1 + diff2 + diff3) / 3.0
if diff_sum <= threshold:
return True
else:
return False

def rgb2gray(img, out_channel=3):
r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
if out_channel == 3:
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
return gray

def bgr2gray(img, out_channel=3):
b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
if out_channel == 3:
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
return gray

0 comments on commit bddee53

Please sign in to comment.