Commit 4c83bd7
Internal change
PiperOrigin-RevId: 321666729
Yu-hui Chen authored and TF Object Detection Team committed Jul 17, 2020
1 parent 2e6ccf0 commit 4c83bd7
Showing 2 changed files with 194 additions and 8 deletions.
158 changes: 151 additions & 7 deletions research/object_detection/models/keras_models/resnet_v1.py
@@ -21,6 +21,7 @@

import tensorflow.compat.v1 as tf

from tensorflow.python.keras.applications import resnet
from object_detection.core import freezable_batch_norm
from object_detection.models.keras_models import model_utils

@@ -95,11 +96,11 @@ def __init__(self,
self.regularizer = tf.keras.regularizers.l2(weight_decay)
self.initializer = tf.variance_scaling_initializer()

def _FixedPaddingLayer(self, kernel_size, rate=1):
def _FixedPaddingLayer(self, kernel_size, rate=1): # pylint: disable=invalid-name
return tf.keras.layers.Lambda(
lambda x: _fixed_padding(x, kernel_size, rate))

def Conv2D(self, filters, kernel_size, **kwargs):
def Conv2D(self, filters, kernel_size, **kwargs): # pylint: disable=invalid-name
"""Builds a Conv2D layer according to the current Object Detection config.
Overrides the Keras Resnet application's convolutions with ones that
@@ -141,7 +142,7 @@ def padded_conv(features):  # pylint: disable=invalid-name
else:
return tf.keras.layers.Conv2D(filters, kernel_size, **kwargs)

def Activation(self, *args, **kwargs): # pylint: disable=unused-argument
def Activation(self, *args, **kwargs): # pylint: disable=unused-argument,invalid-name
"""Builds an activation layer.
Overrides the Keras application Activation layer specified by the
@@ -163,7 +164,7 @@ def Activation(self, *args, **kwargs):  # pylint: disable=unused-argument
else:
return tf.keras.layers.Lambda(tf.nn.relu, name=name)

def BatchNormalization(self, **kwargs):
def BatchNormalization(self, **kwargs): # pylint: disable=invalid-name
"""Builds a normalization layer.
Overrides the Keras application batch norm with the norm specified by the
@@ -191,7 +192,7 @@ def BatchNormalization(self, **kwargs):
momentum=self._default_batchnorm_momentum,
**kwargs)

def Input(self, shape):
def Input(self, shape): # pylint: disable=invalid-name
"""Builds an Input layer.
Overrides the Keras application Input layer with one that uses a
@@ -219,7 +220,7 @@ def Input(self, shape):
input=input_tensor, shape=[None] + shape)
return model_utils.input_layer(shape, placeholder_with_default)

def MaxPooling2D(self, pool_size, **kwargs):
def MaxPooling2D(self, pool_size, **kwargs): # pylint: disable=invalid-name
"""Builds a MaxPooling2D layer with default padding as 'SAME'.
This is specified by the default resnet arg_scope in slim.
@@ -237,7 +238,7 @@ def MaxPooling2D(self, pool_size, **kwargs):
# Add alias as Keras also has it.
MaxPool2D = MaxPooling2D # pylint: disable=invalid-name

def ZeroPadding2D(self, padding, **kwargs): # pylint: disable=unused-argument
def ZeroPadding2D(self, padding, **kwargs): # pylint: disable=unused-argument,invalid-name
"""Replaces explicit padding in the Keras application with a no-op.
Args:
@@ -395,3 +396,146 @@ def resnet_v1_152(batchnorm_training,
return tf.keras.applications.resnet.ResNet152(
layers=layers_override, **kwargs)
# pylint: enable=invalid-name


# The following code is based on the existing Keras ResNet model pattern:
# google3/third_party/tensorflow/python/keras/applications/resnet.py
def block_basic(x,
                filters,
                kernel_size=3,
                stride=1,
                conv_shortcut=False,
                name=None):
  """A residual block for ResNet18/34.

  Arguments:
    x: input tensor.
    filters: integer, number of filters in the block's convolutions.
    kernel_size: default 3, kernel size of the block's convolutions.
    stride: default 1, stride used to downsample within the block (applied in
      the second convolution and in the shortcut).
    conv_shortcut: default False, use convolution shortcut if True, otherwise
      identity shortcut.
    name: string, block label.

  Returns:
    Output tensor for the residual block.
  """
  layers = tf.keras.layers
  bn_axis = 3 if tf.keras.backend.image_data_format() == 'channels_last' else 1

  # Pre-activation: batch norm and ReLU are applied before the convolutions.
  preact = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name=name + '_preact_bn')(x)
  preact = layers.Activation('relu', name=name + '_preact_relu')(preact)

  # Shortcut branch: a 1x1 projection when requested, otherwise the identity
  # (downsampled with a 1x1 max pool when the block is strided).
  if conv_shortcut:
    shortcut = layers.Conv2D(
        filters, 1, strides=1, name=name + '_0_conv')(preact)
  else:
    shortcut = layers.MaxPooling2D(1, strides=stride)(x) if stride > 1 else x

  x = layers.ZeroPadding2D(
      padding=((1, 1), (1, 1)), name=name + '_1_pad')(preact)
  x = layers.Conv2D(
      filters, kernel_size, strides=1, use_bias=False,
      name=name + '_1_conv')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)
  x = layers.Activation('relu', name=name + '_1_relu')(x)

  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name + '_2_pad')(x)
  x = layers.Conv2D(
      filters, kernel_size, strides=stride, use_bias=False,
      name=name + '_2_conv')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)
  x = layers.Activation('relu', name=name + '_2_relu')(x)
  x = layers.Add(name=name + '_out')([shortcut, x])
  return x
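For intuition, here is a minimal shape check for a single basic block. This is an illustrative sketch, not part of the commit; it assumes a TF2 environment in which object_detection.models.keras_models.resnet_v1 imports cleanly, and the block name 'demo' is arbitrary.

# Illustrative sketch (not part of this commit): a strided basic block halves
# the spatial resolution while keeping the channel count.
import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models.keras_models import resnet_v1

inputs = tf.keras.Input(shape=(8, 8, 64))
outputs = resnet_v1.block_basic(inputs, filters=64, stride=2, name='demo')
block = tf.keras.Model(inputs, outputs)
print(block(np.zeros((1, 8, 8, 64), dtype=np.float32)).shape)  # (1, 4, 4, 64)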


def stack_basic(x, filters, blocks, stride1=2, name=None):
  """A set of stacked residual blocks for ResNet18/34.

  Arguments:
    x: input tensor.
    filters: integer, number of filters in each block's convolutions.
    blocks: integer, number of blocks in the stack.
    stride1: default 2, stride used for downsampling in the last block of the
      stack.
    name: string, stack label.

  Returns:
    Output tensor for the stacked blocks.
  """
  x = block_basic(x, filters, conv_shortcut=True, name=name + '_block1')
  for i in range(2, blocks):
    x = block_basic(x, filters, name=name + '_block' + str(i))
  x = block_basic(
      x, filters, stride=stride1, name=name + '_block' + str(blocks))
  return x
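As a further sketch (illustrative only, not part of the commit; same TF2 assumption as above), the block count of a stack determines the '_out' layer names that the shape test below looks up, and the downsampling happens in the final block:

# Illustrative sketch (not part of this commit): a 'conv3' stack with four
# basic blocks exposes conv3_block1_out ... conv3_block4_out and halves the
# spatial resolution in its last block.
import tensorflow.compat.v1 as tf

from object_detection.models.keras_models import resnet_v1

inputs = tf.keras.Input(shape=(16, 16, 64))
outputs = resnet_v1.stack_basic(inputs, filters=128, blocks=4, name='conv3')
stack = tf.keras.Model(inputs, outputs)
print([layer.name for layer in stack.layers if layer.name.endswith('_out')])
# ['conv3_block1_out', 'conv3_block2_out', 'conv3_block3_out', 'conv3_block4_out']
print(stack.output_shape)  # (None, 8, 8, 128)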


def resnet_v1_18(include_top=True,
                 weights='imagenet',
                 input_tensor=None,
                 input_shape=None,
                 pooling=None,
                 classes=1000,
                 classifier_activation='softmax'):
  """Instantiates the ResNet18 architecture."""

  def stack_fn(x):
    x = stack_basic(x, 64, 2, stride1=1, name='conv2')
    x = stack_basic(x, 128, 2, name='conv3')
    x = stack_basic(x, 256, 2, name='conv4')
    return stack_basic(x, 512, 2, name='conv5')

  return resnet.ResNet(
      stack_fn,
      True,  # preact
      True,  # use_bias
      'resnet18',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)


def resnet_v1_34(include_top=True,
                 weights='imagenet',
                 input_tensor=None,
                 input_shape=None,
                 pooling=None,
                 classes=1000,
                 classifier_activation='softmax'):
  """Instantiates the ResNet34 architecture."""

  def stack_fn(x):
    x = stack_basic(x, 64, 3, stride1=1, name='conv2')
    x = stack_basic(x, 128, 4, name='conv3')
    x = stack_basic(x, 256, 6, name='conv4')
    return stack_basic(x, 512, 3, name='conv5')

  return resnet.ResNet(
      stack_fn,
      True,  # preact
      True,  # use_bias
      'resnet34',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)
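For context, here is a minimal usage sketch showing how these builders can serve as multi-scale feature extractors; it mirrors the shape test added below. Illustrative only, not part of the commit; it assumes a TF2 environment and randomly initialized weights (weights=None).

# Illustrative sketch (not part of this commit): pull the per-stage outputs of
# ResNet18 for a batch of 64x64 inputs.
import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models.keras_models import resnet_v1

backbone = resnet_v1.resnet_v1_18(weights=None)
feature_names = ['conv2_block2_out', 'conv3_block2_out',
                 'conv4_block2_out', 'conv5_block2_out']
extractor = tf.keras.models.Model(
    inputs=backbone.input,
    outputs=[backbone.get_layer(name).output for name in feature_names])
features = extractor(np.zeros((2, 64, 64, 3), dtype=np.float32))
print([tuple(f.shape) for f in features])
# [(2, 16, 16, 64), (2, 8, 8, 128), (2, 4, 4, 256), (2, 2, 2, 512)]
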
44 changes: 43 additions & 1 deletion (second changed file: the resnet_v1 Keras model test)
@@ -20,12 +20,13 @@
2. Number of global variables.
"""
import unittest

from absl.testing import parameterized
import numpy as np
from six.moves import zip
import tensorflow.compat.v1 as tf

from google.protobuf import text_format

from object_detection.builders import hyperparams_builder
from object_detection.models.keras_models import resnet_v1
from object_detection.protos import hyperparams_pb2
@@ -180,5 +181,46 @@ def test_variable_count(self):
self.assertEqual(len(variables), var_num)


class ResnetShapeTest(test_case.TestCase, parameterized.TestCase):

  @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
  @parameterized.parameters(
      {
          'resnet_type': 'resnet_v1_34',
          'output_layer_names': [
              'conv2_block3_out', 'conv3_block4_out', 'conv4_block6_out',
              'conv5_block3_out'
          ]
      }, {
          'resnet_type': 'resnet_v1_18',
          'output_layer_names': [
              'conv2_block2_out', 'conv3_block2_out', 'conv4_block2_out',
              'conv5_block2_out'
          ]
      })
  def test_output_shapes(self, resnet_type, output_layer_names):
    if resnet_type == 'resnet_v1_34':
      model = resnet_v1.resnet_v1_34(weights=None)
    else:
      model = resnet_v1.resnet_v1_18(weights=None)
    outputs = [
        model.get_layer(output_layer_name).output
        for output_layer_name in output_layer_names
    ]
    resnet_model = tf.keras.models.Model(inputs=model.input, outputs=outputs)
    outputs = resnet_model(np.zeros((2, 64, 64, 3), dtype=np.float32))

    # Check the shape of the conv2 stage output:
    self.assertEqual(outputs[0].shape, [2, 16, 16, 64])
    # Check the shape of the conv3 stage output:
    self.assertEqual(outputs[1].shape, [2, 8, 8, 128])
    # Check the shape of the conv4 stage output:
    self.assertEqual(outputs[2].shape, [2, 4, 4, 256])
    # Check the shape of the conv5 stage output:
    self.assertEqual(outputs[3].shape, [2, 2, 2, 512])


if __name__ == '__main__':
  tf.test.main()
