support tf image_dim_ordering
This shuffles axes around if the global `image_dim_ordering` flag
is set to `tf`, meaning the network should expect input of the form
`(samples, rows, cols, channels)` rather than
`(samples, channels, rows, cols)`.
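
For reference, a minimal sketch of what the two layouts look like for a batch of 32x32 RGB images. This assumes Keras 1.x, where the flag is read from `image_dim_ordering` in `~/.keras/keras.json` or set with `K.set_image_dim_ordering`; the shapes below are illustrative, not taken from this commit:

import numpy as np
from keras import backend as K

# 'tf' ordering puts the channel axis last; the Theano-style
# ordering puts it right after the sample axis.
if K.image_dim_ordering() == 'tf':
    batch = np.zeros((8, 32, 32, 3))   # (samples, rows, cols, channels)
else:
    batch = np.zeros((8, 3, 32, 32))   # (samples, channels, rows, cols)
print(batch.shape)
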
paulfitz committed Dec 12, 2016
1 parent 223dcc4 · commit f3ccb03
Showing 1 changed file with 24 additions and 7 deletions.
resnet.py (31 changes: 24 additions & 7 deletions)
@@ -12,14 +12,24 @@
     AveragePooling2D
 )
 from keras.layers.normalization import BatchNormalization
+from keras import backend as K
+
+if K.image_dim_ordering() == 'tf':
+    ROW_AXIS = 1
+    COL_AXIS = 2
+    CHANNEL_AXIS = 3
+else:
+    CHANNEL_AXIS = 1
+    ROW_AXIS = 2
+    COL_AXIS = 3
 
 
 # Helper to build a conv -> BN -> relu block
 def _conv_bn_relu(nb_filter, nb_row, nb_col, subsample=(1, 1)):
     def f(input):
         conv = Convolution2D(nb_filter=nb_filter, nb_row=nb_row, nb_col=nb_col, subsample=subsample,
                              init="he_normal", border_mode="same")(input)
-        norm = BatchNormalization(mode=0, axis=1)(conv)
+        norm = BatchNormalization(mode=0, axis=CHANNEL_AXIS)(conv)
         return Activation("relu")(norm)
 
     return f
@@ -29,7 +39,7 @@ def f(input):
 # This is an improved scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf
 def _bn_relu_conv(nb_filter, nb_row, nb_col, subsample=(1, 1)):
     def f(input):
-        norm = BatchNormalization(mode=0, axis=1)(input)
+        norm = BatchNormalization(mode=0, axis=CHANNEL_AXIS)(input)
         activation = Activation("relu")(norm)
         return Convolution2D(nb_filter=nb_filter, nb_row=nb_row, nb_col=nb_col, subsample=subsample,
                              init="he_normal", border_mode="same")(activation)
@@ -42,14 +52,15 @@ def _shortcut(input, residual):
     # Expand channels of shortcut to match residual.
     # Stride appropriately to match residual (width, height)
     # Should be int if network architecture is correctly configured.
-    stride_width = input._keras_shape[2] // residual._keras_shape[2]
-    stride_height = input._keras_shape[3] // residual._keras_shape[3]
-    equal_channels = residual._keras_shape[1] == input._keras_shape[1]
+    stride_width = input._keras_shape[ROW_AXIS] // residual._keras_shape[ROW_AXIS]
+    stride_height = input._keras_shape[COL_AXIS] // residual._keras_shape[COL_AXIS]
+    equal_channels = residual._keras_shape[CHANNEL_AXIS] == input._keras_shape[CHANNEL_AXIS]
 
     shortcut = input
     # 1 X 1 conv if shape is different. Else identity.
     if stride_width > 1 or stride_height > 1 or not equal_channels:
-        shortcut = Convolution2D(nb_filter=residual._keras_shape[1], nb_row=1, nb_col=1,
+        shortcut = Convolution2D(nb_filter=residual._keras_shape[CHANNEL_AXIS],
+                                 nb_row=1, nb_col=1,
                                  subsample=(stride_width, stride_height),
                                  init="he_normal", border_mode="valid")(input)
 
@@ -114,6 +125,10 @@ def build(input_shape, num_outputs, block_fn, repetitions):
     if len(input_shape) != 3:
         raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")
 
+    # Permute dimension order if necessary
+    if K.image_dim_ordering() == 'tf':
+        input_shape = (input_shape[1], input_shape[2], input_shape[0])
+
     input = Input(shape=input_shape)
     conv1 = _conv_bn_relu(nb_filter=64, nb_row=7, nb_col=7, subsample=(2, 2))(input)
     pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), border_mode="same")(conv1)
@@ -125,7 +140,9 @@ def build(input_shape, num_outputs, block_fn, repetitions):
         nb_filters *= 2
 
     # Classifier block
-    pool2 = AveragePooling2D(pool_size=(block._keras_shape[2], block._keras_shape[3]), strides=(1, 1))(block)
+    pool2 = AveragePooling2D(pool_size=(block._keras_shape[ROW_AXIS],
+                                        block._keras_shape[COL_AXIS]),
+                             strides=(1, 1))(block)
     flatten1 = Flatten()(pool2)
     dense = Dense(output_dim=num_outputs, init="he_normal", activation="softmax")(flatten1)
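
Note that even with `tf` ordering, `build` still takes its `input_shape` argument in the Theano-style `(nb_channels, nb_rows, nb_cols)` form and permutes it internally. A small standalone sketch of that permutation, using an illustrative 224x224 RGB shape that is not taken from the diff:

# Shape as passed to build(): (nb_channels, nb_rows, nb_cols)
input_shape = (3, 224, 224)

# Same permutation build() applies when K.image_dim_ordering() == 'tf'
input_shape = (input_shape[1], input_shape[2], input_shape[0])
print(input_shape)  # (224, 224, 3), i.e. (rows, cols, channels)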
