Skip to content

Commit

Permalink
Internal.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 471489067
  • Loading branch information
gortizji authored and copybara-github committed Sep 1, 2022
1 parent 55a203d commit f4f8398
Showing 1 changed file with 78 additions and 62 deletions.
140 changes: 78 additions & 62 deletions uncertainty_baselines/models/wide_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,8 @@ def Conv2D(filters, seed=None, **kwargs): # pylint: disable=invalid-name
return tf.keras.layers.Conv2D(filters, **default_kwargs)


def basic_block(
inputs: tf.Tensor,
filters: int,
strides: int,
conv_l2: float,
bn_l2: float,
seed: int,
version: int) -> tf.Tensor:
def basic_block(inputs: tf.Tensor, filters: int, strides: int, conv_l2: float,
bn_l2: float, seed: int, version: int) -> tf.Tensor:
"""Basic residual block of two 3x3 convs.
Args:
Expand All @@ -75,30 +69,42 @@ def basic_block(
x = inputs
y = inputs
if version == 2:
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
y = BatchNormalization(
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
y)
y = tf.keras.layers.Activation('relu')(y)
seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0]
y = Conv2D(filters,
strides=strides,
seed=seeds[0],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
y = Conv2D(
filters,
strides=strides,
seed=seeds[0],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
y)
y = BatchNormalization(
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
y)
y = tf.keras.layers.Activation('relu')(y)
y = Conv2D(filters,
strides=1,
seed=seeds[1],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y)
y = Conv2D(
filters,
strides=1,
seed=seeds[1],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
y)
if version == 1:
y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y)
y = BatchNormalization(
beta_regularizer=tf.keras.regularizers.l2(bn_l2),
gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(
y)
if not x.shape.is_compatible_with(y.shape):
x = Conv2D(filters,
kernel_size=1,
strides=strides,
seed=seeds[2],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(x)
x = Conv2D(
filters,
kernel_size=1,
strides=strides,
seed=seeds[2],
kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(
x)
x = tf.keras.layers.add([x, y])
if version == 1:
x = tf.keras.layers.Activation('relu')(x)
Expand All @@ -107,8 +113,8 @@ def basic_block(

def group(inputs, filters, strides, num_blocks, conv_l2, bn_l2, version, seed):
"""Group of residual blocks."""
seeds = tf.random.experimental.stateless_split(
[seed, seed + 1], num_blocks)[:, 0]
seeds = tf.random.experimental.stateless_split([seed, seed + 1],
num_blocks)[:, 0]
x = basic_block(
inputs,
filters=filters,
Expand Down Expand Up @@ -187,49 +193,59 @@ def wide_resnet(
raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
num_blocks = (depth - 4) // 6
inputs = tf.keras.layers.Input(shape=input_shape)
x = Conv2D(16,
strides=1,
seed=seeds[0],
kernel_regularizer=l2_reg(hps['input_conv_l2']))(inputs)
x = Conv2D(
16,
strides=1,
seed=seeds[0],
kernel_regularizer=l2_reg(hps['input_conv_l2']))(
inputs)
if version == 1:
x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
gamma_regularizer=l2_reg(hps['bn_l2']))(x)
x = BatchNormalization(
beta_regularizer=l2_reg(hps['bn_l2']),
gamma_regularizer=l2_reg(hps['bn_l2']))(
x)
x = tf.keras.layers.Activation('relu')(x)
x = group(x,
filters=16 * width_multiplier,
strides=1,
num_blocks=num_blocks,
conv_l2=hps['group_1_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[1])
x = group(x,
filters=32 * width_multiplier,
strides=2,
num_blocks=num_blocks,
conv_l2=hps['group_2_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[2])
x = group(x,
filters=64 * width_multiplier,
strides=2,
num_blocks=num_blocks,
conv_l2=hps['group_3_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[3])
x = group(
x,
filters=16 * width_multiplier,
strides=1,
num_blocks=num_blocks,
conv_l2=hps['group_1_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[1])
x = group(
x,
filters=32 * width_multiplier,
strides=2,
num_blocks=num_blocks,
conv_l2=hps['group_2_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[2])
x = group(
x,
filters=64 * width_multiplier,
strides=2,
num_blocks=num_blocks,
conv_l2=hps['group_3_conv_l2'],
bn_l2=hps['bn_l2'],
version=version,
seed=seeds[3])
if version == 2:
x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']),
gamma_regularizer=l2_reg(hps['bn_l2']))(x)
x = BatchNormalization(
beta_regularizer=l2_reg(hps['bn_l2']),
gamma_regularizer=l2_reg(hps['bn_l2']))(
x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(
num_classes,
kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]),
kernel_regularizer=l2_reg(hps['dense_kernel_l2']),
bias_regularizer=l2_reg(hps['dense_bias_l2']))(x)
bias_regularizer=l2_reg(hps['dense_bias_l2']))(
x)
return tf.keras.Model(
inputs=inputs,
outputs=x,
Expand Down

0 comments on commit f4f8398

Please sign in to comment.