-
-
Notifications
You must be signed in to change notification settings - Fork 422
/
Copy pathvggface.py
112 lines (94 loc) · 4.57 KB
/
vggface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
'''VGGFace models for Keras.
# Reference:
- [Deep Face Recognition](http://www.robots.ox.ac.uk/~vgg/publications/2015/Parkhi15/parkhi15.pdf)
- [VGGFace2: A dataset for recognising faces across pose and age](http://www.robots.ox.ac.uk/~vgg/data/vgg_face2/vggface2.pdf)
'''
from __future__ import print_function
from keras_vggface.models import RESNET50, VGG16, SENET50
def VGGFace(include_top=True, model='vgg16', weights='vggface',
            input_tensor=None, input_shape=None,
            pooling=None,
            classes=None):
    """Instantiates one of the VGGFace architectures.

    Optionally loads weights pre-trained on the VGGFace datasets. Note that
    when using TensorFlow, for best performance you should set
    `image_data_format="channels_last"` in your Keras config
    at ~/.keras/keras.json.

    The model and the weights are compatible with both TensorFlow and
    Theano. The data format convention used by the model is the one
    specified in your Keras config file.

    # Arguments
        include_top: whether to include the fully-connected
            layers at the top of the network.
        model: which architecture to build; one of `'vgg16'` (default),
            `'resnet50'` or `'senet50'`.
        weights: one of `None` (random initialization)
            or `'vggface'` (pre-training on the VGGFace datasets).
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)` (with `channels_last` data format)
            or `(3, 224, 224)` (with `channels_first` data format).
            It should have exactly 3 input channels,
            and width and height should be no smaller than 48.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images into.
            Defaults to 2622 for `'vgg16'` and 8631 for `'resnet50'` and
            `'senet50'`. When `weights='vggface'` and `include_top` is
            True, it must equal that default.

    # Returns
        A Keras model instance, as returned by the selected architecture
        builder (`VGG16`, `RESNET50` or `SENET50`).

    # Raises
        ValueError: if `weights` is neither `None` nor `'vggface'`, if
            `model` is not one of the supported architecture names, or if
            `classes` differs from the required count when
            `weights='vggface'` and `include_top` is True.
    """
    if weights not in {'vggface', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `vggface` '
                         '(pre-training on VGGFace Datasets).')

    # Class count required with pre-trained top layers; it is also the
    # default value for `classes` for each architecture.
    default_classes = {'vgg16': 2622, 'resnet50': 8631, 'senet50': 8631}

    # Fail loudly instead of silently returning None for an unknown name.
    if model not in default_classes:
        raise ValueError('The `model` argument should be one of '
                         '`vgg16`, `resnet50` or `senet50`; got {!r}.'
                         .format(model))

    expected_classes = default_classes[model]
    if classes is None:
        classes = expected_classes
    if weights == 'vggface' and include_top and classes != expected_classes:
        raise ValueError(
            'If using `weights` as vggface original with `include_top`'
            ' as true, `classes` should be {}'.format(expected_classes))

    # Resolve the architecture builder only once all arguments are valid.
    if model == 'vgg16':
        builder = VGG16
    elif model == 'resnet50':
        builder = RESNET50
    else:
        builder = SENET50

    return builder(include_top=include_top, input_tensor=input_tensor,
                   input_shape=input_shape, pooling=pooling,
                   weights=weights,
                   classes=classes)