diff --git a/uncertainty_baselines/models/wide_resnet.py b/uncertainty_baselines/models/wide_resnet.py index eea6e9cb0..200f9e037 100644 --- a/uncertainty_baselines/models/wide_resnet.py +++ b/uncertainty_baselines/models/wide_resnet.py @@ -49,14 +49,8 @@ def Conv2D(filters, seed=None, **kwargs): # pylint: disable=invalid-name return tf.keras.layers.Conv2D(filters, **default_kwargs) -def basic_block( - inputs: tf.Tensor, - filters: int, - strides: int, - conv_l2: float, - bn_l2: float, - seed: int, - version: int) -> tf.Tensor: +def basic_block(inputs: tf.Tensor, filters: int, strides: int, conv_l2: float, + bn_l2: float, seed: int, version: int) -> tf.Tensor: """Basic residual block of two 3x3 convs. Args: @@ -75,30 +69,42 @@ def basic_block( x = inputs y = inputs if version == 2: - y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2), - gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y) + y = BatchNormalization( + beta_regularizer=tf.keras.regularizers.l2(bn_l2), + gamma_regularizer=tf.keras.regularizers.l2(bn_l2))( + y) y = tf.keras.layers.Activation('relu')(y) seeds = tf.random.experimental.stateless_split([seed, seed + 1], 3)[:, 0] - y = Conv2D(filters, - strides=strides, - seed=seeds[0], - kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y) - y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2), - gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y) + y = Conv2D( + filters, + strides=strides, + seed=seeds[0], + kernel_regularizer=tf.keras.regularizers.l2(conv_l2))( + y) + y = BatchNormalization( + beta_regularizer=tf.keras.regularizers.l2(bn_l2), + gamma_regularizer=tf.keras.regularizers.l2(bn_l2))( + y) y = tf.keras.layers.Activation('relu')(y) - y = Conv2D(filters, - strides=1, - seed=seeds[1], - kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(y) + y = Conv2D( + filters, + strides=1, + seed=seeds[1], + kernel_regularizer=tf.keras.regularizers.l2(conv_l2))( + y) if version == 1: - y = BatchNormalization(beta_regularizer=tf.keras.regularizers.l2(bn_l2), - gamma_regularizer=tf.keras.regularizers.l2(bn_l2))(y) + y = BatchNormalization( + beta_regularizer=tf.keras.regularizers.l2(bn_l2), + gamma_regularizer=tf.keras.regularizers.l2(bn_l2))( + y) if not x.shape.is_compatible_with(y.shape): - x = Conv2D(filters, - kernel_size=1, - strides=strides, - seed=seeds[2], - kernel_regularizer=tf.keras.regularizers.l2(conv_l2))(x) + x = Conv2D( + filters, + kernel_size=1, + strides=strides, + seed=seeds[2], + kernel_regularizer=tf.keras.regularizers.l2(conv_l2))( + x) x = tf.keras.layers.add([x, y]) if version == 1: x = tf.keras.layers.Activation('relu')(x) @@ -107,8 +113,8 @@ def basic_block( def group(inputs, filters, strides, num_blocks, conv_l2, bn_l2, version, seed): """Group of residual blocks.""" - seeds = tf.random.experimental.stateless_split( - [seed, seed + 1], num_blocks)[:, 0] + seeds = tf.random.experimental.stateless_split([seed, seed + 1], + num_blocks)[:, 0] x = basic_block( inputs, filters=filters, @@ -187,41 +193,50 @@ def wide_resnet( raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).') num_blocks = (depth - 4) // 6 inputs = tf.keras.layers.Input(shape=input_shape) - x = Conv2D(16, - strides=1, - seed=seeds[0], - kernel_regularizer=l2_reg(hps['input_conv_l2']))(inputs) + x = Conv2D( + 16, + strides=1, + seed=seeds[0], + kernel_regularizer=l2_reg(hps['input_conv_l2']))( + inputs) if version == 1: - x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']), - gamma_regularizer=l2_reg(hps['bn_l2']))(x) + x = BatchNormalization( + beta_regularizer=l2_reg(hps['bn_l2']), + gamma_regularizer=l2_reg(hps['bn_l2']))( + x) x = tf.keras.layers.Activation('relu')(x) - x = group(x, - filters=16 * width_multiplier, - strides=1, - num_blocks=num_blocks, - conv_l2=hps['group_1_conv_l2'], - bn_l2=hps['bn_l2'], - version=version, - seed=seeds[1]) - x = group(x, - filters=32 * width_multiplier, - strides=2, - num_blocks=num_blocks, - conv_l2=hps['group_2_conv_l2'], - bn_l2=hps['bn_l2'], - version=version, - seed=seeds[2]) - x = group(x, - filters=64 * width_multiplier, - strides=2, - num_blocks=num_blocks, - conv_l2=hps['group_3_conv_l2'], - bn_l2=hps['bn_l2'], - version=version, - seed=seeds[3]) + x = group( + x, + filters=16 * width_multiplier, + strides=1, + num_blocks=num_blocks, + conv_l2=hps['group_1_conv_l2'], + bn_l2=hps['bn_l2'], + version=version, + seed=seeds[1]) + x = group( + x, + filters=32 * width_multiplier, + strides=2, + num_blocks=num_blocks, + conv_l2=hps['group_2_conv_l2'], + bn_l2=hps['bn_l2'], + version=version, + seed=seeds[2]) + x = group( + x, + filters=64 * width_multiplier, + strides=2, + num_blocks=num_blocks, + conv_l2=hps['group_3_conv_l2'], + bn_l2=hps['bn_l2'], + version=version, + seed=seeds[3]) if version == 2: - x = BatchNormalization(beta_regularizer=l2_reg(hps['bn_l2']), - gamma_regularizer=l2_reg(hps['bn_l2']))(x) + x = BatchNormalization( + beta_regularizer=l2_reg(hps['bn_l2']), + gamma_regularizer=l2_reg(hps['bn_l2']))( + x) x = tf.keras.layers.Activation('relu')(x) x = tf.keras.layers.AveragePooling2D(pool_size=8)(x) x = tf.keras.layers.Flatten()(x) @@ -229,7 +244,8 @@ def wide_resnet( num_classes, kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]), kernel_regularizer=l2_reg(hps['dense_kernel_l2']), - bias_regularizer=l2_reg(hps['dense_bias_l2']))(x) + bias_regularizer=l2_reg(hps['dense_bias_l2']))( + x) return tf.keras.Model( inputs=inputs, outputs=x,