Skip to content

Commit

Permalink
Fix network topology (keras-team#8981)
Browse files Browse the repository at this point in the history
  • Loading branch information
ozabluda authored and fchollet committed Jan 6, 2018
1 parent eac78b8 commit f097f69
Showing 1 changed file with 28 additions and 18 deletions.
46 changes: 28 additions & 18 deletions examples/cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
# ResNet20 | 3 (2)| 92.16 | 91.25 | ----- | ----- | 35 (---)
# ResNet32 | 5(NA)| 92.46 | 92.49 | NA | NA | 50 ( NA)
# ResNet44 | 7(NA)| 92.50 | 92.83 | NA | NA | 70 ( NA)
# ResNet56 | 9 (6)| 92.71 | 93.03 | 92.60 | NA | 90 (100)
# ResNet110 |18(12)| 92.65 | 93.39+-.16| 93.03 | 93.63 | 165(180)
# ResNet56 | 9 (6)| 92.71 | 93.03 | 93.01 | NA | 90 (100)
# ResNet110 |18(12)| 92.65 | 93.39+-.16| 93.15 | 93.63 | 165(180)
# ResNet164 |27(18)| ----- | 94.07 | ----- | 94.54 | ---(---)
# ResNet1001| (111)| ----- | 92.39 | ----- | 95.08+-.14| ---(---)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -118,6 +118,7 @@ def resnet_block(inputs,
kernel_size=3,
strides=1,
activation='relu',
batch_normalization=True,
conv_first=True):
"""2D Convolution-Batch Normalization-Activation stack builder
Expand All @@ -127,24 +128,28 @@ def resnet_block(inputs,
kernel_size (int): Conv2D square kernel dimensions
strides (int): Conv2D square stride dimensions
activation (string): activation name
batch_normalization (bool): whether to include batch normalization
conv_first (bool): conv-bn-activation (True) or
activation-bn-conv (False)
# Returns
x (tensor): tensor as input to the next layer
"""
x = inputs
if conv_first:
x = Conv2D(num_filters,
kernel_size=kernel_size,
strides=strides,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(inputs)
x = BatchNormalization()(x)
kernel_regularizer=l2(1e-4))(x)
if batch_normalization:
x = BatchNormalization()(x)
if activation:
x = Activation(activation)(x)
return x
x = BatchNormalization()(inputs)
if batch_normalization:
x = BatchNormalization()(x)
if activation:
x = Activation('relu')(x)
x = Conv2D(num_filters,
Expand Down Expand Up @@ -204,7 +209,8 @@ def resnet_v1(input_shape, depth, num_classes=10):
num_filters=num_filters,
kernel_size=1,
strides=strides,
activation=None)
activation=None,
batch_normalization=False)
x = keras.layers.add([x, y])
x = Activation('relu')(x)
num_filters = 2 * num_filters
Expand Down Expand Up @@ -248,14 +254,14 @@ def resnet_v2(input_shape, depth, num_classes=10):
filter_multiplier = 4
num_sub_blocks = int((depth - 2) / 9)

# v2 performs Conv2D on input w/o BN-ReLU
x = Conv2D(num_filters_in,
kernel_size=3,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(inputs)
# v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
x = resnet_block(inputs=inputs,
num_filters=num_filters_in,
conv_first=True)

# Instantiate convolutional base (stack of blocks).
activation = None
batch_normalization = False
for i in range(3):
if i > 0:
filter_multiplier = 2
Expand All @@ -270,7 +276,11 @@ def resnet_v2(input_shape, depth, num_classes=10):
num_filters=num_filters_in,
kernel_size=1,
strides=strides,
activation=activation,
batch_normalization=batch_normalization,
conv_first=False)
activation = 'relu'
batch_normalization = True
y = resnet_block(inputs=y,
num_filters=num_filters_in,
conv_first=False)
Expand All @@ -279,12 +289,12 @@ def resnet_v2(input_shape, depth, num_classes=10):
kernel_size=1,
conv_first=False)
if j == 0:
x = Conv2D(num_filters_out,
kernel_size=1,
strides=strides,
padding='same',
kernel_initializer='he_normal',
kernel_regularizer=l2(1e-4))(x)
x = resnet_block(inputs=x,
num_filters=num_filters_out,
kernel_size=1,
strides=strides,
activation=None,
batch_normalization=False)
x = keras.layers.add([x, y])

num_filters_in = num_filters_out
Expand Down

0 comments on commit f097f69

Please sign in to comment.