%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
# Weight Decay
:label:`sec_weight_decay`
Now that we have characterized the problem of overfitting, we can introduce our first regularization technique. Recall that we can always mitigate overfitting by collecting more training data. However, that can be costly, time-consuming, or entirely out of our control, making it impossible in the short run. For now, we can assume that we already have as much high-quality data as our resources permit and focus on the tools at our disposal when the dataset is taken as a given.
Recall that in our polynomial regression example
(:numref:`subsec_polynomial-curve-fitting`)
we could limit our model's capacity
by tweaking the degree
of the fitted polynomial.
Indeed, limiting the number of features
is a popular technique for mitigating overfitting.
However, simply tossing aside features
can be too blunt an instrument.
Sticking with the polynomial regression
example, consider what might happen
with high-dimensional input.
The natural extensions of polynomials
to multivariate data are called monomials,
which are simply products of powers of variables.
The degree of a monomial is the sum of the powers.
For example, $x_1^2 x_2$ and $x_3 x_5^2$ are both monomials of degree 3.

Note that the number of terms with degree $d$ blows up rapidly as $d$ grows larger. Given $k$ variables, the number of monomials of degree $d$ is $\binom{k - 1 + d}{k - 1}$. Even small changes in degree, say from $2$ to $3$, dramatically increase the complexity of our model. Thus we often need a more fine-grained tool for adjusting function complexity.
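To get a feel for how fast this blowup is, here is a quick sanity check using Python's standard library (the choice of $k = 20$ variables is our own illustration, not part of the original example):

```python
from math import comb

k = 20  # number of variables, chosen for illustration
for d in (1, 2, 3, 10):
    # Number of monomials of degree d in k variables: C(k - 1 + d, k - 1)
    print(f'degree {d}: {comb(k - 1 + d, k - 1)} monomials')
```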
%%tab mxnet
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
npx.set_np()
%%tab pytorch
%matplotlib inline
from d2l import torch as d2l
import torch
from torch import nn
%%tab tensorflow
%matplotlib inline
from d2l import tensorflow as d2l
import tensorflow as tf
%%tab jax
%matplotlib inline
from d2l import jax as d2l
import jax
from jax import numpy as jnp
import optax
Rather than directly manipulating the number of parameters, *weight decay* operates by restricting the values that the parameters can take. More commonly called $\ell_2$ regularization outside of deep learning circles when optimized by minibatch stochastic gradient descent, weight decay might be the most widely used technique for regularizing parametric machine learning models. The technique is motivated by the basic intuition that among all functions $f$, the function $f = 0$ (assigning the value $0$ to all inputs) is in some sense the *simplest*, and that we can measure the complexity of a function by the distance of its parameters from zero. But how precisely should we measure the distance between a function and zero? There is no single right answer.
One simple interpretation might be to measure the complexity of a linear function $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$ by some norm of its weight vector, e.g., $\| \mathbf{w} \|^2$. Recall that we introduced the $\ell_2$ norm and $\ell_1$ norm in :numref:`subsec_lin-algebra-norms`.
The most common method for ensuring a small weight vector
is to add its norm as a penalty term
to the problem of minimizing the loss.
Thus we replace our original objective,
minimizing the prediction loss on the training labels,
with a new objective,
minimizing the sum of the prediction loss and the penalty term.
Now, if our weight vector grows too large, our learning algorithm might focus on minimizing the weight norm $\| \mathbf{w} \|^2$ rather than minimizing the training error. That is exactly what we want. To illustrate things in code, we revive our previous example from :numref:`sec_linear_regression` for linear regression. There, our loss was given by

$$L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.$$

Recall that $\mathbf{x}^{(i)}$ are the features, $y^{(i)}$ is the label for any data example $i$, and $(\mathbf{w}, b)$ are the weight and bias parameters, respectively. To penalize the size of the weight vector, we must somehow add $\| \mathbf{w} \|^2$ to the loss function, but how should the model trade off the standard loss for this new additive penalty? In practice, we characterize this trade-off via the *regularization constant* $\lambda$, a nonnegative hyperparameter that we fit using validation data:

$$L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.$$

For $\lambda = 0$, we recover our original loss function. For $\lambda > 0$, we restrict the size of $\| \mathbf{w} \|$. We divide by $2$ by convention: when we take the derivative of a quadratic function, the $2$ and $1/2$ cancel out, ensuring that the expression for the update looks nice and simple. The astute reader might wonder why we work with the squared norm and not the standard norm (i.e., the Euclidean distance). We do this for computational convenience. By squaring the $\ell_2$ norm, we remove the square root, leaving the sum of squares of each component of the weight vector. This makes the derivative of the penalty easy to compute: the sum of derivatives equals the derivative of the sum.
Moreover, you might ask why we work with the $\ell_2$ norm in the first place and not, say, the $\ell_1$ norm. One reason for working with the $\ell_2$ norm is that it places an outsize penalty on large components of the weight vector, biasing our learning algorithm towards models that distribute weight evenly across a larger number of features. By contrast, $\ell_1$ penalties lead to models that concentrate weights on a small set of features by clearing the other weights to zero, which gives us an effective method for *feature selection*.
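To make the contrast concrete, here is a small sketch (our own illustration, not part of the original text) comparing the gradients of the two penalties on a toy weight vector:

```python
import numpy as np

w = np.array([1.0, 0.1, -0.01])
# The gradient of the squared l2 penalty (1/2)||w||^2 is w itself:
# large components are shrunk proportionally harder than small ones.
print('l2 shrinkage direction:', w)
# The (sub)gradient of the l1 penalty sum_i |w_i| is sign(w):
# every nonzero component is pushed towards zero at the same rate,
# which is what eventually clears small weights to exactly zero.
print('l1 shrinkage direction:', np.sign(w))
```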
Using the same notation in :eqref:`eq_linreg_batch_update`, the minibatch stochastic gradient descent updates for $\ell_2$-regularized regression are:

$$\mathbf{w} \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right).$$

As before, we update $\mathbf{w}$ based on the amount by which our estimate differs from the observation. However, we also shrink the size of $\mathbf{w}$ towards zero. That is why the method is sometimes called "weight decay": given the penalty term alone, our optimization algorithm *decays* the weight at each step of training. In contrast to feature selection, weight decay offers a mechanism for continuously adjusting the complexity of a function. Smaller values of $\lambda$ correspond to less constrained $\mathbf{w}$, whereas larger values of $\lambda$ constrain $\mathbf{w}$ more considerably. Whether we include a corresponding bias penalty $b^2$ can vary across implementations and across layers of a neural network; often, we do not regularize the bias term.
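In code, the update amounts to multiplicatively shrinking the weights while applying the usual gradient step. A minimal NumPy sketch of one minibatch update follows; the toy data is our own, while `eta` and `lambd` echo the values used later in this section:

```python
import numpy as np

eta, lambd = 0.01, 3.0       # learning rate and regularization constant
X = np.random.randn(5, 2)    # minibatch of 5 examples, 2 features
y = np.random.randn(5, 1)
w, b = np.zeros((2, 1)), 0.0

err = X @ w + b - y          # prediction error on the minibatch
grad_w = X.T @ err / len(X)  # gradient of the unpenalized loss
w = (1 - eta * lambd) * w - eta * grad_w  # decay, then gradient step
```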
We can illustrate the benefits of weight decay through a simple synthetic example.
First, we generate some data as before:

$$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \textrm{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$$
In this synthetic dataset, our label is given
by an underlying linear function of our inputs,
corrupted by Gaussian noise
with zero mean and standard deviation 0.01.
For illustrative purposes, we can make the effects of overfitting pronounced by increasing the dimensionality of our problem to $d = 200$ and working with a small training set containing only 20 examples.
%%tab all
class Data(d2l.DataModule):
    def __init__(self, num_train, num_val, num_inputs, batch_size):
        self.save_hyperparameters()
        n = num_train + num_val
        if tab.selected('mxnet') or tab.selected('pytorch'):
            self.X = d2l.randn(n, num_inputs)
            noise = d2l.randn(n, 1) * 0.01
        if tab.selected('tensorflow'):
            self.X = d2l.normal((n, num_inputs))
            noise = d2l.normal((n, 1)) * 0.01
        if tab.selected('jax'):
            # Split the key so features and noise use independent draws
            key1, key2 = jax.random.split(jax.random.PRNGKey(0))
            self.X = jax.random.normal(key1, (n, num_inputs))
            noise = jax.random.normal(key2, (n, 1)) * 0.01
        # True parameters: all weights 0.01, bias 0.05
        w, b = d2l.ones((num_inputs, 1)) * 0.01, 0.05
        self.y = d2l.matmul(self.X, w) + b + noise

    def get_dataloader(self, train):
        # First num_train examples for training, the rest for validation
        i = slice(0, self.num_train) if train else slice(self.num_train, None)
        return self.get_tensorloader([self.X, self.y], train, i)
Now, let's try implementing weight decay from scratch.
Since minibatch stochastic gradient descent
is our optimizer,
we just need to add the squared $\ell_2$ penalty to the original loss function. Perhaps the most convenient way of implementing this penalty is to square all terms in place and sum them.
%%tab all
def l2_penalty(w):
    return d2l.reduce_sum(w**2) / 2
The linear regression model and the squared loss have not changed since :numref:`sec_linear_scratch`, so we will just define a subclass of `d2l.LinearRegressionScratch`. The only change here is that our loss now includes the penalty term.
%%tab pytorch, mxnet, tensorflow
class WeightDecayScratch(d2l.LinearRegressionScratch):
    def __init__(self, num_inputs, lambd, lr, sigma=0.01):
        super().__init__(num_inputs, lr, sigma)
        self.save_hyperparameters()

    def loss(self, y_hat, y):
        return (super().loss(y_hat, y) +
                self.lambd * l2_penalty(self.w))
%%tab jax
class WeightDecayScratch(d2l.LinearRegressionScratch):
    lambd: int = 0

    def loss(self, params, X, y, state):
        return (super().loss(params, X, y, state) +
                self.lambd * l2_penalty(params['w']))
The following code fits our model on the training set with 20 examples and evaluates it on the validation set with 100 examples.
%%tab all
data = Data(num_train=20, num_val=100, num_inputs=200, batch_size=5)
trainer = d2l.Trainer(max_epochs=10)

def train_scratch(lambd):
    model = WeightDecayScratch(num_inputs=200, lambd=lambd, lr=0.01)
    model.board.yscale='log'
    trainer.fit(model, data)
    if tab.selected('pytorch', 'mxnet', 'tensorflow'):
        print('L2 norm of w:', float(l2_penalty(model.w)))
    if tab.selected('jax'):
        print('L2 norm of w:',
              float(l2_penalty(trainer.state.params['w'])))
We now run this code with `lambd = 0`, disabling weight decay. Note that we overfit badly, decreasing the training error but not the validation error, a textbook case of overfitting.
%%tab all
train_scratch(0)
Below, we run with substantial weight decay. Note that the training error increases but the validation error decreases. This is precisely the effect we expect from regularization.
%%tab all
train_scratch(3)
Because weight decay is ubiquitous in neural network optimization, deep learning frameworks make it especially convenient, integrating weight decay into the optimization algorithm itself for easy use in combination with any loss function. Moreover, this integration offers a computational benefit: implementation tricks can add weight decay to the algorithm without any additional computational overhead. Since the weight decay portion of the update depends only on the current value of each parameter, the optimizer must touch each parameter once anyway.
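Schematically, a fused update might look like the following sketch (an illustration of the idea, not any particular framework's implementation): since the loop already visits every parameter to apply its gradient, folding in the decay term costs essentially nothing extra.

```python
import numpy as np

def sgd_step(params, grads, lr, wd):
    """One SGD sweep with weight decay fused into the same parameter visit."""
    for p, g in zip(params, grads):
        # The decay term wd * p reuses the parameter value already in hand,
        # so no extra pass over the parameters and no extra gradient
        # computation are needed.
        p -= lr * (g + wd * p)

# Toy usage with hypothetical parameter and gradient lists
params = [np.ones(3), np.ones(2)]
grads = [np.zeros(3), np.zeros(2)]
sgd_step(params, grads, lr=0.01, wd=3.0)
```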
:begin_tab:mxnet
Below, we specify the weight decay hyperparameter directly through `wd` when instantiating our `Trainer`. By default, Gluon decays both weights and biases simultaneously. Note that the hyperparameter `wd` will be multiplied by `wd_mult` when updating model parameters. Thus, if we set `wd_mult` to zero, the bias parameter $b$ will not decay.
:end_tab:
:begin_tab:pytorch
Below, we specify the weight decay hyperparameter directly through `weight_decay` when instantiating our optimizer. By default, PyTorch decays both weights and biases simultaneously, but we can configure the optimizer to handle different parameters according to different policies. Here, we only set `weight_decay` for the weights (the `net.weight` parameters), hence the bias (the `net.bias` parameter) will not decay.
:end_tab:
:begin_tab:tensorflow
Below, we create an $\ell_2$ regularizer with the weight decay hyperparameter `wd` and apply it to the layer's weights through the `kernel_regularizer` argument.
:end_tab:
%%tab mxnet
class WeightDecay(d2l.LinearRegression):
    def __init__(self, wd, lr):
        super().__init__(lr)
        self.save_hyperparameters()
        self.wd = wd

    def configure_optimizers(self):
        self.collect_params('.*bias').setattr('wd_mult', 0)
        return gluon.Trainer(self.collect_params(),
                             'sgd',
                             {'learning_rate': self.lr, 'wd': self.wd})
%%tab pytorch
class WeightDecay(d2l.LinearRegression):
    def __init__(self, wd, lr):
        super().__init__(lr)
        self.save_hyperparameters()
        self.wd = wd

    def configure_optimizers(self):
        return torch.optim.SGD([
            {'params': self.net.weight, 'weight_decay': self.wd},
            {'params': self.net.bias}], lr=self.lr)
%%tab tensorflow
class WeightDecay(d2l.LinearRegression):
    def __init__(self, wd, lr):
        super().__init__(lr)
        self.save_hyperparameters()
        self.net = tf.keras.layers.Dense(
            1, kernel_regularizer=tf.keras.regularizers.l2(wd),
            kernel_initializer=tf.keras.initializers.RandomNormal(0, 0.01)
        )

    def loss(self, y_hat, y):
        return super().loss(y_hat, y) + self.net.losses
%%tab jax
class WeightDecay(d2l.LinearRegression):
    wd: int = 0

    def configure_optimizers(self):
        # Weight decay is not available directly within optax.sgd, but
        # optax allows chaining several transformations together
        return optax.chain(optax.additive_weight_decay(self.wd),
                           optax.sgd(self.lr))
The plot looks similar to that when we implemented weight decay from scratch. However, this version runs faster and is easier to implement, benefits that will become more pronounced as you address larger problems and this work becomes more routine.
%%tab all
model = WeightDecay(wd=3, lr=0.01)
model.board.yscale='log'
trainer.fit(model, data)

if tab.selected('jax'):
    print('L2 norm of w:', float(l2_penalty(model.get_w_b(trainer.state)[0])))
if tab.selected('pytorch', 'mxnet', 'tensorflow'):
    print('L2 norm of w:', float(l2_penalty(model.get_w_b()[0])))
So far, we have touched upon one notion of what constitutes a simple linear function. However, even for simple nonlinear functions, the situation can be much more complex. To see this, note that the concept of reproducing kernel Hilbert space (RKHS) allows one to apply tools introduced for linear functions in a nonlinear context. Unfortunately, RKHS-based algorithms tend to scale poorly to large, high-dimensional data. In this book we will often adopt the common heuristic whereby weight decay is applied to all layers of a deep network.
## Summary

Regularization is a common method for dealing with overfitting. Classical regularization techniques add a penalty term to the loss function (when training) to reduce the complexity of the learned model. One particular choice for keeping the model simple is using an $\ell_2$ penalty. This leads to weight decay in the update steps of the minibatch stochastic gradient descent algorithm. In practice, the weight decay functionality is provided in optimizers from deep learning frameworks.
## Exercises

- Experiment with the value of $\lambda$ in the estimation problem in this section. Plot training and validation accuracy as a function of $\lambda$. What do you observe?
- Use a validation set to find the optimal value of $\lambda$. Is it really the optimal value? Does this matter?
- What would the update equations look like if instead of $\|\mathbf{w}\|^2$ we used $\sum_i |w_i|$ as our penalty of choice ($\ell_1$ regularization)?
- We know that $\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}$. Can you find a similar equation for matrices (see the Frobenius norm in :numref:`subsec_lin-algebra-norms`)?
- Review the relationship between training error and generalization error. In addition to weight decay, increased training, and the use of a model of suitable complexity, what other ways might help us deal with overfitting?
- In Bayesian statistics we use the product of prior and likelihood to arrive at a posterior via $P(w \mid x) \propto P(x \mid w) P(w)$. How can you identify $P(w)$ with regularization?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab: