forked from yu4u/noise2noise
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generator.py
63 lines (50 loc) · 2.16 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from pathlib import Path
import random
import numpy as np
import cv2
from keras.utils import Sequence
class NoisyImageGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of (noisy input, noisy target) patch pairs.

    Every batch is assembled from random square crops of randomly chosen
    ``*.jpg`` images found directly in ``image_dir``. The same clean crop is
    passed through ``source_noise_model`` to build the input ``x`` and through
    ``target_noise_model`` to build the target ``y``.
    """

    def __init__(self, image_dir, source_noise_model, target_noise_model, batch_size=32, image_size=64):
        """Collect the training image paths and store the sampling settings.

        Args:
            image_dir: directory searched (non-recursively) for ``*.jpg`` files.
            source_noise_model: callable applied to a clean patch to produce the input.
            target_noise_model: callable applied to a clean patch to produce the target.
            batch_size: number of patch pairs per batch.
            image_size: side length, in pixels, of the square patches.
        """
        self.image_paths = list(Path(image_dir).glob("*.jpg"))
        self.source_noise_model = source_noise_model
        self.target_noise_model = target_noise_model
        self.image_num = len(self.image_paths)
        self.batch_size = batch_size
        self.image_size = image_size

    def __len__(self):
        """Return the number of batches per epoch: full batches only, any remainder is dropped."""
        return self.image_num // self.batch_size

    def __getitem__(self, idx):
        """Return one batch ``(x, y)`` of uint8 arrays shaped (batch_size, image_size, image_size, 3).

        ``idx`` is ignored. Images and crop positions are sampled at random on
        every call. Unreadable files and images smaller than ``image_size`` in
        either dimension are skipped. Noise-model outputs are stored into uint8
        buffers, so they are cast on assignment.
        """
        batch_size = self.batch_size
        image_size = self.image_size
        x = np.zeros((batch_size, image_size, image_size, 3), dtype=np.uint8)
        y = np.zeros((batch_size, image_size, image_size, 3), dtype=np.uint8)
        sample_id = 0

        # NOTE(review): if no file in image_paths is both readable and at least
        # image_size x image_size, this loop never terminates.
        while sample_id < batch_size:
            image_path = random.choice(self.image_paths)
            image = cv2.imread(str(image_path))
            if image is None:
                # cv2.imread signals unreadable/corrupt files by returning None
                # rather than raising; skip them instead of crashing on .shape.
                continue

            h, w = image.shape[:2]
            if h < image_size or w < image_size:
                # Too small to yield a full patch.
                continue

            # "+ 1" so a crop flush with the bottom/right edge is possible.
            i = np.random.randint(h - image_size + 1)
            j = np.random.randint(w - image_size + 1)
            clean_patch = image[i:i + image_size, j:j + image_size]
            x[sample_id] = self.source_noise_model(clean_patch)
            y[sample_id] = self.target_noise_model(clean_patch)
            sample_id += 1

        return x, y
class ValGenerator(Sequence):
    """Keras ``Sequence`` serving full-size validation images one at a time.

    All images are loaded and corrupted once, in ``__init__``. Item ``idx`` is
    a ``[x, y]`` list of single-image batches (a leading axis of size 1 is
    added). ``y`` is the clean, cropped image and ``x`` is
    ``val_noise_model(y)``.
    """

    def __init__(self, image_dir, val_noise_model, crop_multiple=16):
        """Load, crop and corrupt every readable image in ``image_dir``.

        Args:
            image_dir: directory searched (non-recursively) for files matching ``*.*``.
            val_noise_model: callable mapping a clean image to its noisy version.
            crop_multiple: height and width are cropped down to a multiple of this
                value (the original comment notes a maximum network stride of 16).

        Raises:
            ValueError: if ``crop_multiple`` is smaller than 1.
        """
        if crop_multiple < 1:
            raise ValueError("crop_multiple must be a positive integer")

        image_paths = list(Path(image_dir).glob("*.*"))
        self.data = []
        for image_path in image_paths:
            y = cv2.imread(str(image_path))
            if y is None:
                # "*.*" also matches non-image files (and dotted directory names);
                # cv2.imread returns None for those instead of raising.
                continue

            h, w = y.shape[:2]
            h = (h // crop_multiple) * crop_multiple
            w = (w // crop_multiple) * crop_multiple
            if h == 0 or w == 0:
                # Smaller than crop_multiple on some side: cropping would leave
                # an empty image.
                continue

            y = y[:h, :w]  # for stride (maximum crop_multiple)
            x = val_noise_model(y)
            self.data.append([np.expand_dims(x, axis=0), np.expand_dims(y, axis=0)])

        # Count only images actually loaded, so __len__ agrees with self.data.
        self.image_num = len(self.data)

    def __len__(self):
        """Return the number of successfully loaded validation images."""
        return self.image_num

    def __getitem__(self, idx):
        """Return the precomputed ``[noisy_batch, clean_batch]`` pair for image ``idx``."""
        return self.data[idx]