
Commit

Added regularization option to BatchNormalization layer (keras-team#3671)

* Added regularization option to BatchNormalization layer

* Update normalization.py

* Added regularization to BN test

* Fixed indentation

* Removed trailing whitespace and re-fixed indentation
lolemacs authored and fchollet committed Sep 2, 2016
1 parent 870d7f7 commit f90cbcd
Showing 2 changed files with 24 additions and 3 deletions.
22 changes: 20 additions & 2 deletions keras/layers/normalization.py
@@ -1,5 +1,5 @@
from ..engine import Layer, InputSpec
from .. import initializations
from .. import initializations, regularizers
from .. import backend as K


@@ -44,6 +44,10 @@ class BatchNormalization(Layer):
[initializations](../initializations.md)), or alternatively,
Theano/TensorFlow function to use for weights initialization.
This parameter is only relevant if you don't pass a `weights` argument.
gamma_regularizer: instance of [WeightRegularizer](../regularizers.md)
(eg. L1 or L2 regularization), applied to the gamma vector.
beta_regularizer: instance of [WeightRegularizer](../regularizers.md),
applied to the beta vector.
# Input shape
Arbitrary. Use the keyword argument `input_shape`
@@ -57,14 +61,17 @@ class BatchNormalization(Layer):
- [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](http://jmlr.org/proceedings/papers/v37/ioffe15.html)
'''
def __init__(self, epsilon=1e-5, mode=0, axis=-1, momentum=0.99,
weights=None, beta_init='zero', gamma_init='one', **kwargs):
weights=None, beta_init='zero', gamma_init='one',
gamma_regularizer=None, beta_regularizer=None, **kwargs):
self.supports_masking = True
self.beta_init = initializations.get(beta_init)
self.gamma_init = initializations.get(gamma_init)
self.epsilon = epsilon
self.mode = mode
self.axis = axis
self.momentum = momentum
self.gamma_regularizer = regularizers.get(gamma_regularizer)
self.beta_regularizer = regularizers.get(beta_regularizer)
self.initial_weights = weights
if self.mode == 0:
self.uses_learning_phase = True
@@ -78,6 +85,15 @@ def build(self, input_shape):
self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
self.trainable_weights = [self.gamma, self.beta]

self.regularizers = []
if self.gamma_regularizer:
self.gamma_regularizer.set_param(self.gamma)
self.regularizers.append(self.gamma_regularizer)

if self.beta_regularizer:
self.beta_regularizer.set_param(self.beta)
self.regularizers.append(self.beta_regularizer)

self.running_mean = K.zeros(shape,
name='{}_running_mean'.format(self.name))
self.running_std = K.ones(shape,
@@ -155,6 +171,8 @@ def get_config(self):
config = {"epsilon": self.epsilon,
"mode": self.mode,
"axis": self.axis,
"gamma_regularizer": self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
"beta_regularizer": self.beta_regularizer.get_config() if self.beta_regularizer else None,
"momentum": self.momentum}
base_config = super(BatchNormalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
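
As a usage illustration (not part of the diff), here is a minimal sketch of how the new arguments could be passed once this change is in, assuming the Keras 1.x Sequential API of the time; the model below is hypothetical:

from keras.models import Sequential
from keras.layers import Dense
from keras.layers.normalization import BatchNormalization
from keras import regularizers

# Hypothetical model: L2 penalties on the BN scale (gamma) and shift (beta) vectors.
model = Sequential()
model.add(Dense(64, input_dim=20))
model.add(BatchNormalization(gamma_regularizer=regularizers.l2(0.01),
                             beta_regularizer=regularizers.l2(0.01)))
model.compile(optimizer='sgd', loss='mse')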
5 changes: 4 additions & 1 deletion tests/keras/layers/test_normalization.py
@@ -16,8 +16,11 @@

@keras_test
def basic_batchnorm_test():
from keras import regularizers
layer_test(normalization.BatchNormalization,
kwargs={'mode': 1},
kwargs={'mode': 1,
'gamma_regularizer': regularizers.l2(0.01),
'beta_regularizer': regularizers.l2(0.01)},
input_shape=(3, 4, 2))
layer_test(normalization.BatchNormalization,
kwargs={'mode': 0},
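
For context, the set_param / append pattern in build() follows the Keras 1.x regularizer protocol: a WeightRegularizer is bound to a weight via set_param() and, when later called on a loss tensor, returns that loss plus its penalty, which is how the entries collected in self.regularizers end up in the total loss. A minimal sketch of that protocol (illustrative, not part of this diff):

from keras import backend as K
from keras import regularizers

# Sketch of the Keras 1.x regularizer protocol assumed by this commit:
# bind the weight, then call the regularizer on a loss tensor to add the penalty.
gamma = K.variable([1.0, 1.0, 1.0])
reg = regularizers.l2(0.01)
reg.set_param(gamma)
loss = K.variable(0.0)
loss_with_penalty = reg(loss)  # loss + 0.01 * sum(gamma ** 2)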
