# codeformer.py (forked from XPixelGroup/DiffBIR)
from typing import Sequence, Dict, Union
import math
import time
import numpy as np
import cv2
from PIL import Image
import torch.utils.data as data
from utils.file import load_file_list
from utils.image import center_crop_arr, augment, random_crop_arr
from utils.degradation import (
    random_mixed_kernels, random_add_gaussian_noise, random_add_jpg_compression
)


class CodeformerDataset(data.Dataset):

    def __init__(
        self,
        file_list: str,
        out_size: int,
        crop_type: str,
        use_hflip: bool,
        blur_kernel_size: int,
        kernel_list: Sequence[str],
        kernel_prob: Sequence[float],
        blur_sigma: Sequence[float],
        downsample_range: Sequence[float],
        noise_range: Sequence[float],
        jpeg_range: Sequence[int]
    ) -> None:
        super(CodeformerDataset, self).__init__()
        self.file_list = file_list
        self.paths = load_file_list(file_list)
        self.out_size = out_size
        self.crop_type = crop_type
        assert self.crop_type in ["none", "center", "random"]
        self.use_hflip = use_hflip
        # degradation configurations
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range

    def __getitem__(self, index: int) -> Dict[str, Union[np.ndarray, str]]:
        # Load the GT image.
        # Shape: (h, w, c); channel order: BGR; range: [0, 1]; dtype: float32.
        gt_path = self.paths[index]
        success = False
        # Retry a few times to tolerate transient I/O failures.
        for _ in range(3):
            try:
                pil_img = Image.open(gt_path).convert("RGB")
                success = True
                break
            except Exception:
                time.sleep(1)
        assert success, f"failed to load image {gt_path}"

        if self.crop_type == "center":
            pil_img_gt = center_crop_arr(pil_img, self.out_size)
        elif self.crop_type == "random":
            pil_img_gt = random_crop_arr(pil_img, self.out_size)
        else:
            pil_img_gt = np.array(pil_img)
            assert pil_img_gt.shape[:2] == (self.out_size, self.out_size)
        # RGB -> BGR, uint8 [0, 255] -> float32 [0, 1]
        img_gt = (pil_img_gt[..., ::-1] / 255.0).astype(np.float32)
        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.use_hflip, rotation=False, return_status=False)
        h, w, _ = img_gt.shape

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
        # resize back to the original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # BGR -> RGB, [-1, 1]
        target = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        # BGR -> RGB, [0, 1]
        source = img_lq[..., ::-1].astype(np.float32)
        return dict(jpg=target, txt="", hint=source)

    def __len__(self) -> int:
        return len(self.paths)
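

# --------------------------------------------------------------------------- #
# Minimal usage sketch (not part of the original file). It builds the dataset
# with assumed degradation settings -- plausible placeholder values, not the
# project's official config -- and wraps it in a standard PyTorch DataLoader.
# "data/gt_list.txt" is a hypothetical file-list path; the kernel types are
# assumed to be ones accepted by random_mixed_kernels.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = CodeformerDataset(
        file_list="data/gt_list.txt",   # hypothetical list of GT image paths
        out_size=512,
        crop_type="center",
        use_hflip=True,
        blur_kernel_size=41,            # assumed value
        kernel_list=["iso", "aniso"],   # assumed kernel types
        kernel_prob=[0.5, 0.5],
        blur_sigma=[0.1, 10],
        downsample_range=[1, 8],
        noise_range=[0, 20],
        jpeg_range=[60, 100],
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    batch = next(iter(loader))
    # "jpg" is the GT in RGB [-1, 1]; "hint" is the degraded LQ in RGB [0, 1].
    print(batch["jpg"].shape, batch["hint"].shape, batch["txt"][:2])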