# model.py (forked from yu4u/noise2noise)
# Written for standalone Keras with a TensorFlow 1.x backend (tf.log was removed in TF 2.x).
from keras.models import Model
from keras.layers import (Input, Conv2D, Add, PReLU, Conv2DTranspose, Concatenate,
                          MaxPooling2D, UpSampling2D, Dropout, BatchNormalization)
from keras import backend as K
import tensorflow as tf


def tf_log10(x):
    # log base 10 built from natural logs (tf.log is the TF 1.x API)
    numerator = tf.log(x)
    denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator


def PSNR(y_true, y_pred):
    # peak signal-to-noise ratio metric, assuming pixel values in the [0, 255] range
    max_pixel = 255.0
    y_pred = K.clip(y_pred, 0.0, 255.0)
    return 10.0 * tf_log10((max_pixel ** 2) / (K.mean(K.square(y_pred - y_true))))
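

# Quick sanity check (illustrative, not in the original file): a uniform error of 10
# over the 0-255 range gives an MSE of 100, i.e. 10 * log10(255**2 / 100) ≈ 28.13 dB.
#
#   import numpy as np
#   y_true = K.constant(np.full((1, 4, 4, 3), 128.0))
#   y_pred = K.constant(np.full((1, 4, 4, 3), 138.0))
#   K.eval(PSNR(y_true, y_pred))   # ≈ 28.13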


def get_model(model_name="srresnet"):
    if model_name == "srresnet":
        return get_srresnet_model()
    elif model_name == "unet":
        return get_unet_model(out_ch=3)
    else:
        raise ValueError("model_name should be 'srresnet' or 'unet'")


# SRResNet
def get_srresnet_model(input_channel_num=3, feature_dim=64, resunit_num=16):
    def _residual_block(inputs):
        x = Conv2D(feature_dim, (3, 3), padding="same", kernel_initializer="he_normal")(inputs)
        x = BatchNormalization()(x)
        x = PReLU(shared_axes=[1, 2])(x)
        x = Conv2D(feature_dim, (3, 3), padding="same", kernel_initializer="he_normal")(x)
        x = BatchNormalization()(x)
        m = Add()([x, inputs])

        return m

    inputs = Input(shape=(None, None, input_channel_num))
    x = Conv2D(feature_dim, (3, 3), padding="same", kernel_initializer="he_normal")(inputs)
    x = PReLU(shared_axes=[1, 2])(x)
    x0 = x

    for i in range(resunit_num):
        x = _residual_block(x)

    x = Conv2D(feature_dim, (3, 3), padding="same", kernel_initializer="he_normal")(x)
    x = BatchNormalization()(x)
    x = Add()([x, x0])  # global skip connection around the residual blocks
    x = Conv2D(input_channel_num, (3, 3), padding="same", kernel_initializer="he_normal")(x)
    model = Model(inputs=inputs, outputs=x)

    return model
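

# Illustrative check (not part of the original file): the input shape is
# (None, None, input_channel_num) and every layer uses "same" padding, so the
# SRResNet variant accepts images of arbitrary height and width at inference time.
#
#   import numpy as np
#   m = get_srresnet_model()
#   denoised = m.predict(np.zeros((1, 120, 200, 3), dtype=np.float32))
#   # denoised.shape == (1, 120, 200, 3)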


# UNet: code from https://github.com/pietz/unet-keras
def get_unet_model(input_channel_num=3, out_ch=3, start_ch=64, depth=4, inc_rate=2., activation='relu',
                   dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    def _conv_block(m, dim, acti, bn, res, do=0):
        n = Conv2D(dim, 3, activation=acti, padding='same')(m)
        n = BatchNormalization()(n) if bn else n
        n = Dropout(do)(n) if do else n
        n = Conv2D(dim, 3, activation=acti, padding='same')(n)
        n = BatchNormalization()(n) if bn else n

        return Concatenate()([m, n]) if res else n

    def _level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
        if depth > 0:
            n = _conv_block(m, dim, acti, bn, res)
            m = MaxPooling2D()(n) if mp else Conv2D(dim, 3, strides=2, padding='same')(n)
            m = _level_block(m, int(inc * dim), depth - 1, inc, acti, do, bn, mp, up, res)

            if up:
                m = UpSampling2D()(m)
                m = Conv2D(dim, 2, activation=acti, padding='same')(m)
            else:
                m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)

            n = Concatenate()([n, m])
            m = _conv_block(n, dim, acti, bn, res)
        else:
            m = _conv_block(m, dim, acti, bn, res, do)

        return m

    i = Input(shape=(None, None, input_channel_num))
    o = _level_block(i, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual)
    o = Conv2D(out_ch, 1)(o)
    model = Model(inputs=i, outputs=o)

    return model
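

# Illustrative note (assumption derived from the architecture, not stated in the
# original code): with depth=4 and 2x pooling/upsampling per level, input height
# and width should be multiples of 2**4 = 16, otherwise the upsampled feature maps
# will not match the skip connections at the Concatenate layers.
#
#   import numpy as np
#   m = get_unet_model()
#   m.predict(np.zeros((1, 64, 64, 3), dtype=np.float32))    # OK: 64 is a multiple of 16
#   # m.predict(np.zeros((1, 70, 70, 3), dtype=np.float32))  # fails: shape mismatch at Concatenate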


def main():
    # model = get_model()
    model = get_model("unet")
    model.summary()


if __name__ == '__main__':
    main()
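

# Minimal training sketch (illustrative only; the repo's actual training script,
# data pipeline and hyperparameters live elsewhere and are not reproduced here).
# Dummy random arrays stand in for noisy input / noisy target image pairs.
#
#   import numpy as np
#   model = get_model("srresnet")
#   model.compile(optimizer="adam", loss="mse", metrics=[PSNR])
#   x = np.random.uniform(0, 255, (8, 64, 64, 3)).astype(np.float32)  # noisy inputs
#   y = np.random.uniform(0, 255, (8, 64, 64, 3)).astype(np.float32)  # noisy targets
#   model.fit(x, y, batch_size=4, epochs=1)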