Skip to content

Commit

Permalink
Fix applications.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Feb 25, 2017
1 parent e1a283e commit f516cc6
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 124 deletions.
46 changes: 15 additions & 31 deletions keras/applications/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@
from .imagenet_utils import _obtain_input_shape


TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_th_dim_ordering_th_kernels.h5'
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_th_dim_ordering_th_kernels_notop.h5'
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'


def conv2d_bn(x, filters, num_row, num_col,
Expand Down Expand Up @@ -309,8 +307,7 @@ def InceptionV3(include_top=True, weights='imagenet',

if include_top:
# Classification block
x = AveragePooling2D((8, 8), strides=(8, 8), name='avg_pool')(x)
x = Flatten(name='flatten')(x)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(classes, activation='softmax', name='predictions')(x)
else:
if pooling == 'avg':
Expand All @@ -330,17 +327,6 @@ def InceptionV3(include_top=True, weights='imagenet',
# load weights
if weights == 'imagenet':
if K.image_data_format() == 'channels_first':
if include_top:
weights_path = get_file('inception_v3_weights_th_dim_ordering_th_kernels.h5',
TH_WEIGHTS_PATH,
cache_subdir='models',
md5_hash='b3baf3070cc4bf476d43a2ea61b0ca5f')
else:
weights_path = get_file('inception_v3_weights_th_dim_ordering_th_kernels_notop.h5',
TH_WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='79aaa90ab4372b4593ba3df64e142f05')
model.load_weights(weights_path)
if K.backend() == 'tensorflow':
warnings.warn('You are using the TensorFlow backend, yet you '
'are using the Theano '
Expand All @@ -350,21 +336,19 @@ def InceptionV3(include_top=True, weights='imagenet',
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')
convert_all_kernels_in_model(model)
if include_top:
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='fe114b3ff2ea4bf891e9353d1bbfb32f')
else:
if include_top:
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels.h5',
TF_WEIGHTS_PATH,
cache_subdir='models',
md5_hash='fe114b3ff2ea4bf891e9353d1bbfb32f')
else:
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5',
TF_WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='2f3609166de1d967d1a481094754f691')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='2f3609166de1d967d1a481094754f691')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)
return model


Expand Down
58 changes: 25 additions & 33 deletions keras/applications/resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@
from ..models import Model
from .. import backend as K
from ..engine.topology import get_source_inputs
from ..utils.layer_utils import convert_all_kernels_in_model
from ..utils import layer_utils
from ..utils.data_utils import get_file
from .imagenet_utils import decode_predictions
from .imagenet_utils import preprocess_input
from .imagenet_utils import _obtain_input_shape


TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_th_dim_ordering_th_kernels.h5'
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_th_dim_ordering_th_kernels_notop.h5'
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'


def identity_block(input_tensor, kernel_size, filters, stage, block):
Expand Down Expand Up @@ -65,8 +63,8 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x)

x = Conv2D(filters2, kernel_size, kernel_size,
pading='same', name=conv_name_base + '2b')(x)
x = Conv2D(filters2, kernel_size,
padding='same', name=conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)

Expand Down Expand Up @@ -107,7 +105,7 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x)

x = Conv2D(filters2, (kernel_size, kernel_size), pading='same',
x = Conv2D(filters2, kernel_size, padding='same',
name=conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)
Expand Down Expand Up @@ -253,18 +251,27 @@ def ResNet50(include_top=True, weights='imagenet',

# load weights
if weights == 'imagenet':
if include_top:
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='a7b3fe01876f51b976af0dea6bc144eb')
else:
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='a268eb855778b3df3c7506639542a6af')
model.load_weights(weights_path)
if K.backend() == 'theano':
layer_utils.convert_all_kernels_in_model(model)

if K.image_data_format() == 'channels_first':
if include_top:
weights_path = get_file('resnet50_weights_th_dim_ordering_th_kernels.h5',
TH_WEIGHTS_PATH,
cache_subdir='models',
md5_hash='1c1f8f5b0c8ee28fe9d950625a230e1c')
else:
weights_path = get_file('resnet50_weights_th_dim_ordering_th_kernels_notop.h5',
TH_WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='f64f049c92468c9affcd44b0976cdafe')
model.load_weights(weights_path)
maxpool = model.get_layer(name='avg_pool')
shape = maxpool.output_shape[1:]
dense = model.get_layer(name='fc1000')
layer_utils.convert_dense_weights_data_format(dense, shape, 'channels_first')

if K.backend() == 'tensorflow':
warnings.warn('You are using the TensorFlow backend, yet you '
'are using the Theano '
Expand All @@ -274,19 +281,4 @@ def ResNet50(include_top=True, weights='imagenet',
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')
convert_all_kernels_in_model(model)
else:
if include_top:
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels.h5',
TF_WEIGHTS_PATH,
cache_subdir='models',
md5_hash='a7b3fe01876f51b976af0dea6bc144eb')
else:
weights_path = get_file('resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
TF_WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='a268eb855778b3df3c7506639542a6af')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)
return model
46 changes: 20 additions & 26 deletions keras/applications/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,16 @@
from ..layers import GlobalAveragePooling2D
from ..layers import GlobalMaxPooling2D
from ..engine.topology import get_source_inputs
from ..utils.layer_utils import convert_all_kernels_in_model
from ..utils import layer_utils
from ..utils.data_utils import get_file
from .. import backend as K
from .imagenet_utils import decode_predictions
from .imagenet_utils import preprocess_input
from .imagenet_utils import _obtain_input_shape


TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels.h5'
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_th_dim_ordering_th_kernels_notop.h5'
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5'


def VGG16(include_top=True, weights='imagenet',
Expand Down Expand Up @@ -160,16 +158,25 @@ def VGG16(include_top=True, weights='imagenet',

# load weights
if weights == 'imagenet':
if include_top:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
if K.backend() == 'theano':
layer_utils.convert_all_kernels_in_model(model)

if K.image_data_format() == 'channels_first':
if include_top:
weights_path = get_file('vgg16_weights_th_dim_ordering_th_kernels.h5',
TH_WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg16_weights_th_dim_ordering_th_kernels_notop.h5',
TH_WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
maxpool = model.get_layer(name='block5_pool')
shape = maxpool.output_shape[1:]
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape, 'channels_first')

if K.backend() == 'tensorflow':
warnings.warn('You are using the TensorFlow backend, yet you '
'are using the Theano '
Expand All @@ -179,17 +186,4 @@ def VGG16(include_top=True, weights='imagenet',
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')
convert_all_kernels_in_model(model)
else:
if include_top:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
TF_WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
TF_WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)
return model
46 changes: 20 additions & 26 deletions keras/applications/vgg19.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,16 @@
from ..layers import GlobalAveragePooling2D
from ..layers import GlobalMaxPooling2D
from ..engine.topology import get_source_inputs
from ..utils.layer_utils import convert_all_kernels_in_model
from ..utils import layer_utils
from ..utils.data_utils import get_file
from .. import backend as K
from .imagenet_utils import decode_predictions
from .imagenet_utils import preprocess_input
from .imagenet_utils import _obtain_input_shape


TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_th_dim_ordering_th_kernels.h5'
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5'
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_th_dim_ordering_th_kernels_notop.h5'
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5'
WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5'


def VGG19(include_top=True, weights='imagenet',
Expand Down Expand Up @@ -163,16 +161,25 @@ def VGG19(include_top=True, weights='imagenet',

# load weights
if weights == 'imagenet':
if include_top:
weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels.h5',
WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
if K.backend() == 'theano':
layer_utils.convert_all_kernels_in_model(model)

if K.image_data_format() == 'channels_first':
if include_top:
weights_path = get_file('vgg19_weights_th_dim_ordering_th_kernels.h5',
TH_WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg19_weights_th_dim_ordering_th_kernels_notop.h5',
TH_WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
maxpool = model.get_layer(name='block5_pool')
shape = maxpool.output_shape[1:]
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape, 'channels_first')

if K.backend() == 'tensorflow':
warnings.warn('You are using the TensorFlow backend, yet you '
'are using the Theano '
Expand All @@ -182,17 +189,4 @@ def VGG19(include_top=True, weights='imagenet',
'`image_data_format="channels_last"` in '
'your Keras config '
'at ~/.keras/keras.json.')
convert_all_kernels_in_model(model)
else:
if include_top:
weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels.h5',
TF_WEIGHTS_PATH,
cache_subdir='models')
else:
weights_path = get_file('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5',
TF_WEIGHTS_PATH_NO_TOP,
cache_subdir='models')
model.load_weights(weights_path)
if K.backend() == 'theano':
convert_all_kernels_in_model(model)
return model
16 changes: 8 additions & 8 deletions keras/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ def build(self, input_shape):
axes={self.axis: dim})
shape = (dim,)

if self.center:
self.beta = self.add_weight(shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
if self.scale:
self.gamma = self.add_weight(shape,
name='gamma',
Expand All @@ -109,6 +101,14 @@ def build(self, input_shape):
constraint=self.gamma_constraint)
else:
self.gamma = None
if self.center:
self.beta = self.add_weight(shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
else:
self.beta = None
self.moving_mean = self.add_weight(
shape,
name='moving_mean',
Expand Down
Loading

0 comments on commit f516cc6

Please sign in to comment.