-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
335 lines (316 loc) · 14.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import numpy
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
Dense,
Dropout,
Flatten,
Input,
BatchNormalization, #Added
)
from tensorflow.keras.applications import (
MobileNet,
MobileNetV2,
InceptionResNetV2,
InceptionV3,
ResNet50,
ResNet50V2,
ResNet101V2,
NASNetLarge,
NASNetMobile,
Xception
)
from layers import (
#BatchNormalization,
ArcMarginPenaltyLogists,
AddMarginPenaltyLogists,
MulMarginPenaltyLogists,
CurMarginPenaltyLogists,
CadMarginPenaltyLogists,
AdaMarginPenaltyLogists
)
from backbone.efficientnet_lite import (
EfficientNetLite0,
EfficientNetLite1,
EfficientNetLite2,
EfficientNetLite3,
EfficientNetLite4,
EfficientNetLite5,
EfficientNetLite6
)
from backbone.efficientnet import (
EfficientNetB0,
EfficientNetB1,
EfficientNetB2,
EfficientNetB3,
EfficientNetB4,
EfficientNetB5,
EfficientNetB6,
EfficientNetB7
)
from backbone.mobilenet_v3 import (
MobileNetV3Small,
MobileNetV3Large
)
from backbone.mnasnet import (
MnasNetModel
)
# Directory (relative to the current working directory) holding the local
# pretrained checkpoints loaded by Backbone() for backbones that are not
# downloaded by tf.keras.applications. The trailing slash matters: file names
# are appended with plain string concatenation.
WEIGHTS_DIR = './weights/'
def _regularizer(weights_decay=5e-4):
    """Return a Keras L2 kernel regularizer with coefficient ``weights_decay``."""
    make_l2 = tf.keras.regularizers.l2
    return make_l2(weights_decay)
def Backbone(backbone_type='ResNet50V2', use_pretrain=True):
    """Backbone Model: build a headless feature extractor selected by name.

    Args:
        backbone_type: Name of the architecture to build. Supported values are
            the keys of the lookup tables below plus the MnasNet variants
            'MnasNetA1', 'MnasNetB1' and 'MnasNetSmall'.
        use_pretrain: If True, load pretrained weights. For the stock
            ``tf.keras.applications`` models this passes ``weights='imagenet'``
            to Keras. For the NASNet / MobileNetV3 / EfficientNet-Lite /
            EfficientNet backbones it loads a local file from ``WEIGHTS_DIR``.
            MnasNet variants are always built with ``weights=None``.

    Returns:
        A function ``backbone(x_in)`` that instantiates the chosen model with
        ``include_top=False`` and ``input_shape=x_in.shape[1:]`` and applies it
        to ``x_in``.

    Raises:
        TypeError: raised by the returned function (not by ``Backbone`` itself)
            when ``backbone_type`` is not recognised.
    """
    weights = 'imagenet' if use_pretrain else None

    # Stock Keras applications: Keras resolves/downloads ``weights`` itself.
    keras_backbones = {
        'ResNet50': ResNet50,
        'ResNet50V2': ResNet50V2,
        'ResNet101V2': ResNet101V2,
        'InceptionResNetV2': InceptionResNetV2,
        'InceptionV3': InceptionV3,
        'MobileNet': MobileNet,
        'MobileNetV2': MobileNetV2,
        'Xception': Xception,
    }

    # Backbones built without weights, then optionally restored from a local
    # checkpoint file under WEIGHTS_DIR.
    local_weight_backbones = {
        'NASNetLarge': (NASNetLarge, "nasnet_large_no_top.h5"),
        'NASNetMobile': (NASNetMobile, "nasnet_mobile_no_top.h5"),
        'MobileNetV3Small': (MobileNetV3Small, "mobilenet_v3_small_notop.ckpt"),
        'MobileNetV3Large': (MobileNetV3Large, "mobilenet_v3_large_notop.ckpt"),
        'EfficientNetLite0': (EfficientNetLite0, "efficientnet_lite0_notop.ckpt"),
        'EfficientNetLite1': (EfficientNetLite1, "efficientnet_lite1_notop.ckpt"),
        'EfficientNetLite2': (EfficientNetLite2, "efficientnet_lite2_notop.ckpt"),
        'EfficientNetLite3': (EfficientNetLite3, "efficientnet_lite3_notop.ckpt"),
        'EfficientNetLite4': (EfficientNetLite4, "efficientnet_lite4_notop.ckpt"),
        'EfficientNetLite5': (EfficientNetLite5, "efficientnet_lite5_notop.ckpt"),
        'EfficientNetLite6': (EfficientNetLite6, "efficientnet_lite6_notop.ckpt"),
        'EfficientNetB0': (EfficientNetB0, "efficientnetb0_notop.ckpt"),
        'EfficientNetB1': (EfficientNetB1, "efficientnetb1_notop.ckpt"),
        'EfficientNetB2': (EfficientNetB2, "efficientnetb2_notop.ckpt"),
        'EfficientNetB3': (EfficientNetB3, "efficientnetb3_notop.ckpt"),
        'EfficientNetB4': (EfficientNetB4, "efficientnetb4_notop.ckpt"),
        'EfficientNetB5': (EfficientNetB5, "efficientnetb5_notop.ckpt"),
        'EfficientNetB6': (EfficientNetB6, "efficientnetb6_notop.ckpt"),
        'EfficientNetB7': (EfficientNetB7, "efficientnetb7_notop.ckpt"),
    }

    # MnasNet variants share one constructor and differ only by ``name``.
    mnasnet_variants = ('MnasNetA1', 'MnasNetB1', 'MnasNetSmall')

    def backbone(x_in):
        input_shape = x_in.shape[1:]

        if backbone_type in keras_backbones:
            model_fn = keras_backbones[backbone_type]
            return model_fn(input_shape=input_shape, include_top=False,
                            weights=weights)(x_in)

        if backbone_type in local_weight_backbones:
            model_fn, weights_file = local_weight_backbones[backbone_type]
            model = model_fn(input_shape=input_shape, include_top=False,
                             weights=None)
            # Honour use_pretrain for every locally-weighted backbone (only
            # EfficientNetB6 did before), so use_pretrain=False does not
            # require the checkpoint file to exist on disk.
            if use_pretrain:
                model.load_weights(WEIGHTS_DIR + weights_file)
            return model(x_in)

        if backbone_type in mnasnet_variants:
            return MnasNetModel(input_shape=input_shape, include_top=False,
                                weights=None, name=backbone_type)(x_in)

        raise TypeError('backbone_type error!')

    return backbone
def OutputLayer(embd_shape, w_decay=5e-4, trainable=False, name='OutputLayer',
                dropout_rate=0.5):
    """Output Layer: map backbone features to a fixed-size embedding vector.

    Pipeline: BatchNormalization -> Dropout -> Flatten -> Dense(embd_shape)
    -> BatchNormalization, wrapped in a named Keras sub-model.

    Args:
        embd_shape: Number of units of the Dense layer (the embedding size).
        w_decay: L2 coefficient for the Dense kernel regularizer.
        trainable: Passed as ``trainable`` to both BatchNormalization layers.
        name: Name of the wrapping Keras Model.
        dropout_rate: Dropout rate applied before flattening. Defaults to 0.5,
            the value that was previously hard-coded.

    Returns:
        A function ``output_layer(x_in)`` that builds the sub-model for
        ``x_in.shape[1:]`` and applies it to ``x_in``.
    """
    def output_layer(x_in):
        x = inputs = Input(x_in.shape[1:])
        x = BatchNormalization(trainable=trainable, name='output_batch_norm_1')(x)
        x = Dropout(rate=dropout_rate)(x)
        x = Flatten()(x)
        x = Dense(embd_shape, kernel_regularizer=_regularizer(w_decay))(x)
        x = BatchNormalization(trainable=trainable, name='output_batch_norm_2')(x)
        return Model(inputs, x, name=name)(x_in)
    return output_layer
def ArcHead(num_classes, margin=0.5, logist_scale=64, projection_head=False,
            name='ArcHead', projection_dim=32):
    """Arc Head: margin logits computed by ``ArcMarginPenaltyLogists``.

    Args:
        num_classes: Number of identity classes (logit width).
        margin: Margin forwarded to ``ArcMarginPenaltyLogists``.
        logist_scale: Scale forwarded to ``ArcMarginPenaltyLogists``.
        projection_head: If True, pass the embeddings through a ReLU Dense
            projection before the margin layer.
        name: Name of the wrapping Keras Model.
        projection_dim: Width of the optional projection layer. Defaults to
            32, the value that was previously hard-coded.

    Returns:
        A function ``arc_head(x_in, y_in)`` that builds the sub-model taking
        ``(embeddings, labels)`` and applies it to ``(x_in, y_in)``.
    """
    def arc_head(x_in, y_in):
        x = inputs1 = Input(x_in.shape[1:])
        # Labels are integer class ids. Declare them int32, like the
        # Sphere/Cur/Cad/Ada heads, instead of Keras' default float32.
        y = Input(y_in.shape[1:], dtype=tf.int32)
        # Optional nonlinear projection head.
        if projection_head:
            x = Dense(projection_dim, activation='relu')(x)
        x = ArcMarginPenaltyLogists(num_classes=num_classes,
                                    margin=margin,
                                    logist_scale=logist_scale)(x, y)
        return Model((inputs1, y), x, name=name)((x_in, y_in))
    return arc_head
def CosHead(num_classes, margin=0.35, logist_scale=64, name='CosHead'):
    """Cos Head: margin logits computed by ``AddMarginPenaltyLogists``.

    Args:
        num_classes: Number of identity classes (logit width).
        margin: Margin forwarded to ``AddMarginPenaltyLogists``.
        logist_scale: Scale forwarded to ``AddMarginPenaltyLogists``.
        name: Name of the wrapping Keras Model.

    Returns:
        A function ``cos_head(x_in, y_in)`` that builds the sub-model taking
        ``(embeddings, labels)`` and applies it to ``(x_in, y_in)``.
    """
    def cos_head(x_in, y_in):
        x = inputs1 = Input(x_in.shape[1:])
        # Labels are integer class ids. Declare them int32, like the
        # Sphere/Cur/Cad/Ada heads, instead of Keras' default float32.
        y = Input(y_in.shape[1:], dtype=tf.int32)
        x = AddMarginPenaltyLogists(num_classes=num_classes,
                                    margin=margin,
                                    logist_scale=logist_scale)(x, y)
        return Model((inputs1, y), x, name=name)((x_in, y_in))
    return cos_head
def SphereHead(num_classes, margin=1.35, logist_scale=30, name='SphereHead'):
    """Sphere Head: margin logits computed by ``MulMarginPenaltyLogists``.

    The returned closure wraps the margin layer in a named Keras sub-model
    whose inputs are the embeddings and int32 labels, then applies that
    sub-model to ``(x_in, y_in)``.
    """
    def sphere_head(x_in, y_in):
        embeddings = Input(x_in.shape[1:])
        labels = Input(y_in.shape[1:], dtype=tf.int32)
        margin_layer = MulMarginPenaltyLogists(
            num_classes=num_classes,
            margin=margin,
            logist_scale=logist_scale,
        )
        logits = margin_layer(embeddings, labels)
        head = Model((embeddings, labels), logits, name=name)
        return head((x_in, y_in))

    return sphere_head
def NormHead(num_classes, w_decay=5e-4, name='NormHead'):
    """Norm Head: a plain L2-regularized Dense classifier (no margin, no labels).

    The returned closure builds a named sub-model for ``x_in.shape[1:]`` and
    applies it to ``x_in``.
    """
    def norm_head(x_in):
        features = Input(x_in.shape[1:])
        classifier = Dense(num_classes, kernel_regularizer=_regularizer(w_decay))
        head = Model(features, classifier(features), name=name)
        return head(x_in)

    return norm_head
def CurHead(num_classes, margin=0.35, logist_scale=64, name='CurHead'):
    """Cur Head: margin logits computed by ``CurMarginPenaltyLogists``.

    The returned closure wraps the margin layer in a named sub-model taking
    ``(embeddings, int32 labels)`` and applies it to ``(x_in, y_in)``.
    """
    def cur_head(x_in, y_in):
        embd_in = Input(x_in.shape[1:])
        label_in = Input(y_in.shape[1:], dtype=tf.int32)
        logits = CurMarginPenaltyLogists(
            num_classes=num_classes,
            margin=margin,
            logist_scale=logist_scale,
        )(embd_in, label_in)
        return Model((embd_in, label_in), logits, name=name)((x_in, y_in))

    return cur_head
def CadHead(num_classes, margin=0.35, logist_scale=64, name='CadHead'):
    """Cad Head: margin logits computed by ``CadMarginPenaltyLogists``.

    The returned closure wraps the margin layer in a named sub-model taking
    ``(embeddings, int32 labels)`` and applies it to ``(x_in, y_in)``.
    """
    def cad_head(x_in, y_in):
        feat = Input(x_in.shape[1:])
        target = Input(y_in.shape[1:], dtype=tf.int32)
        layer_kwargs = dict(num_classes=num_classes, margin=margin,
                            logist_scale=logist_scale)
        out = CadMarginPenaltyLogists(**layer_kwargs)(feat, target)
        sub_model = Model((feat, target), out, name=name)
        return sub_model((x_in, y_in))

    return cad_head
def AdaHead(num_classes, margin=0.35, logist_scale=64, name='AdaHead'):
    """Ada Head: margin logits computed by ``AdaMarginPenaltyLogists``.

    The returned closure wraps the margin layer in a named sub-model taking
    ``(embeddings, int32 labels)`` and applies it to ``(x_in, y_in)``.
    """
    def ada_head(x_in, y_in):
        emb = Input(x_in.shape[1:])
        lbl = Input(y_in.shape[1:], dtype=tf.int32)
        penalty = AdaMarginPenaltyLogists(num_classes=num_classes,
                                          margin=margin,
                                          logist_scale=logist_scale)
        wrapped = Model((emb, lbl), penalty(emb, lbl), name=name)
        return wrapped((x_in, y_in))

    return ada_head
def ArcFaceModel(size=None, channels=3, num_classes=None, name='arcface_model',
                 margin=0.5, logist_scale=64, embd_shape=512,
                 head_type='ArcHead', backbone_type='ResNet50',
                 w_decay=5e-4, use_pretrain=True, training=False,
                 projection_head=False):
    """Arc Face Model: backbone + embedding layer, plus a logit head in training.

    Args:
        size: Spatial size of the square input image (None allows any size,
            if the chosen backbone and the Flatten/Dense layers accept it).
        channels: Number of input image channels.
        num_classes: Number of identity classes. Required when ``training``.
        name: Name of the returned Keras Model.
        margin: Margin forwarded to the margin-based heads.
        logist_scale: Logit scale forwarded to the margin-based heads.
        embd_shape: Embedding size produced by ``OutputLayer``.
        head_type: One of 'ArcHead', 'CosHead', 'SphereHead', 'CurHead',
            'CadHead', 'AdaHead'. Any other value falls back to ``NormHead``.
        backbone_type: Backbone name, see ``Backbone``.
        w_decay: L2 coefficient for the Dense kernels.
        use_pretrain: Whether the backbone loads pretrained weights.
        training: If True, build the training model with a label input and a
            logit head. Otherwise build the embedding-only model. Also passed
            as ``trainable`` to the OutputLayer BatchNormalization layers.
        projection_head: Forwarded to ``ArcHead`` only.

    Returns:
        ``Model(input_image, embeddings)`` when ``training`` is False, else
        ``Model((input_image, label), logits)``.

    Raises:
        ValueError: if ``training`` is True and ``num_classes`` is None.
    """
    # Validate before building the (expensive) backbone. An explicit raise,
    # unlike assert, is not stripped when Python runs with -O.
    if training and num_classes is None:
        raise ValueError('num_classes must be provided when training=True')

    inputs = Input([size, size, channels], name='input_image')
    features = Backbone(backbone_type=backbone_type,
                        use_pretrain=use_pretrain)(inputs)
    embds = OutputLayer(embd_shape, w_decay=w_decay, trainable=training)(features)

    if not training:
        return Model(inputs, embds, name=name)

    labels = Input([], name='label', dtype=tf.int32)

    # Margin heads sharing the (num_classes, margin, logist_scale) signature.
    margin_heads = {
        'CosHead': CosHead,
        'SphereHead': SphereHead,
        'CurHead': CurHead,
        'CadHead': CadHead,
        'AdaHead': AdaHead,
    }

    if head_type == 'ArcHead':
        logist = ArcHead(num_classes=num_classes, margin=margin,
                         logist_scale=logist_scale,
                         projection_head=projection_head)(embds, labels)
    elif head_type in margin_heads:
        logist = margin_heads[head_type](num_classes=num_classes, margin=margin,
                                         logist_scale=logist_scale)(embds, labels)
    else:
        # Unknown head types deliberately fall back to the plain Dense head,
        # which ignores the labels.
        logist = NormHead(num_classes=num_classes, w_decay=w_decay)(embds)
    return Model((inputs, labels), logist, name=name)