JAX: Add section 6.2 & 6.3 Param Management and Param Initialization (d2l-ai#2303)

* JAX: Add parameters.md

* JAX: Add init-param.md

* address init params review

* Address parameters.md review

* Update init-param.md
AnirudhDagar authored Sep 17, 2022
1 parent 4cd8f42 commit 7dbc08a
Showing 2 changed files with 179 additions and 21 deletions.
128 changes: 108 additions & 20 deletions chapter_builders-guide/init-param.md
:begin_tab:`tensorflow`
By default, Keras initializes weight matrices uniformly by drawing from a range that is computed according to the input and output dimension, and the bias parameters are all set to zero.
TensorFlow provides a variety of initialization methods both in the root module and the `keras.initializers` module.
:end_tab:

:begin_tab:`jax`
By default, Flax initializes weights using `jax.nn.initializers.lecun_normal`,
i.e., by drawing samples from a truncated normal distribution centered on 0 with
the standard deviation set as the square root of $1 / \text{fan}_{\text{in}}$,
where `fan_in` is the number of input units in the weight tensor. The bias
parameters are all set to zero.
The `jax.nn.initializers` module provides a variety
of preset initialization methods.
:end_tab:

```{.python .input}
%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
```

```{.python .input}
%%tab mxnet
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()

net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
```

```{.python .input}
%%tab pytorch
import torch
from torch import nn

net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
```

```{.python .input}
%%tab tensorflow
import tensorflow as tf

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
```

```{.python .input}
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp

net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(jax.random.PRNGKey(d2l.get_seed()), (2, 4))
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
net.apply(params, X).shape
```
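As a quick sanity check of this default (a hedged sketch, not part of the original commit; the seed and the standalone `(4, 8)` kernel shape are arbitrary choices), one can compare the empirical standard deviation of `lecun_normal` samples against $\sqrt{1 / \text{fan}_{\text{in}}}$:

```python
# Sketch: the default lecun_normal initializer draws from a truncated
# normal whose std is sqrt(1 / fan_in); for fan_in = 4 that is 0.5.
import jax
from jax import numpy as jnp

init_fn = jax.nn.initializers.lecun_normal()
w = init_fn(jax.random.PRNGKey(0), (4, 8))  # a Dense(8) kernel with 4 inputs
print(w.std(), jnp.sqrt(1 / 4))  # empirical std should be roughly 0.5
```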

## [**Built-in Initialization**]

Let's begin by calling on built-in initializers.
The code below initializes all weight parameters
as Gaussian random variables
with standard deviation 0.01, while clearing bias parameters to zero.

```{.python .input}
%%tab mxnet
# Here `force_reinit` ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
```

```{.python .input}
%%tab pytorch
def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
```

```{.python .input}
%%tab tensorflow
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
```

```{.python .input}
%%tab jax
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
```

We can also initialize all the parameters
to a given constant value (say, 1).

```{.python .input}
%%tab mxnet
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
```

```{.python .input}
%%tab pytorch
def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
```

```{.python .input}
%%tab tensorflow
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
```

```{.python .input}
%%tab jax
weight_init = nn.initializers.constant(1)
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
```

[**We can also apply different initializers for certain blocks.**]
For example, below we initialize the first layer
with the Xavier initializer
and initialize the second layer
to a constant value of 42.

```{.python .input}
%%tab mxnet
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
```

```{.python .input}
%%tab pytorch
def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
```

```{.python .input}
%%tab tensorflow
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42))])

net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
```

```{.python .input}
%%tab jax
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
```

### [**Custom Initialization**]

Sometimes, the initialization methods we need
are not provided by the deep learning framework.
In the example below, we define an initializer
for the following strange distribution:

$$
\begin{aligned}
    w \sim \begin{cases}
        U(5, 10) & \text{ with probability } \frac{1}{4} \\
        0 & \text{ with probability } \frac{1}{2} \\
        U(-10, -5) & \text{ with probability } \frac{1}{4}
    \end{cases}
\end{aligned}
$$

:begin_tab:`tensorflow`
Here we define a subclass of `Initializer` and implement the `__call__`
function that returns a desired tensor given the shape and data type.
:end_tab:

:begin_tab:`jax`
JAX initialization functions take as arguments the `PRNGKey`, `shape`, and
`dtype`. Here we implement the function `my_init` that returns a desired
tensor given the shape and data type.
:end_tab:

```{.python .input}
%%tab mxnet
class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
```

```{.python .input}
%%tab pytorch
def my_init(module):
    if type(module) == nn.Linear:
        print("Init", *[(name, param.shape)
                        for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        module.weight.data *= module.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]
```

```{.python .input}
%%tab tensorflow
class MyInit(tf.keras.initializers.Initializer):
    def __call__(self, shape, dtype=None):
        data = tf.random.uniform(shape, -10, 10, dtype=dtype)
        factor = (tf.abs(data) >= 5)
        factor = tf.cast(factor, tf.float32)
        return data * factor

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu, kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1)])

net(X)
print(net.layers[1].weights[0])
```

```{.python .input}
%%tab jax
def my_init(key, shape, dtype=jnp.float_):
    data = jax.random.uniform(key, shape, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
print(params['params']['layers_0']['kernel'][:, :2])
```

:begin_tab:`mxnet, pytorch, tensorflow`
Note that we always have the option
of setting parameters directly.
:end_tab:

:begin_tab:`jax`
When initializing parameters in JAX and Flax, the dictionary of parameters
returned has a `flax.core.frozen_dict.FrozenDict` type. It is not advisable in
the JAX ecosystem to directly alter the values of an array; hence the datatypes
are generally immutable. One might use `params.unfreeze()` to make changes, as
the sketch after the code blocks below illustrates.
:end_tab:

```{.python .input}
%%tab mxnet
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
```

```{.python .input}
%%tab pytorch
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
```

```{.python .input}
%%tab tensorflow
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
```
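The jax tab above notes that Flax returns parameters as an immutable `FrozenDict`, so direct assignment as in the other frameworks does not work. As a minimal sketch (not part of the original commit; it assumes the `net`, `params`, and `X` defined in the earlier jax blocks), the same edits can be expressed by unfreezing, applying `jnp`'s functional `.at[...].set(...)` update, and refreezing:

```python
# Sketch: mirror the += 1 and [0, 0] = 42 edits on a Flax FrozenDict.
from flax.core import freeze

new_params = params.unfreeze()              # FrozenDict -> mutable dict
kernel = new_params['params']['layers_0']['kernel']
kernel = (kernel + 1).at[0, 0].set(42)      # functional update of the array
new_params['params']['layers_0']['kernel'] = kernel
new_params = freeze(new_params)             # restore immutability
new_params['params']['layers_0']['kernel'][:, 0]
```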