Commit

JAX: Add padding-and-strides.md (d2l-ai#2318)
* JAX: Add padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

* Update chapter_convolutional-neural-networks/padding-and-strides.md

Co-authored-by: Aston Zhang <[email protected]>
AnirudhDagar and astonzhang authored Oct 3, 2022
1 parent c0fdbfa commit 6997700
Showing 1 changed file with 56 additions and 13 deletions.
69 changes: 56 additions & 13 deletions chapter_convolutional-neural-networks/padding-and-strides.md
@@ -1,6 +1,6 @@
```{.python .input}
%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow'])
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
```

# Padding and Stride
@@ -135,9 +135,9 @@ comp_conv2d(conv2d, X).shape
import torch
from torch import nn
# We define a helper function to calculate convolutions. It initializes
# the convolutional layer weights and performs corresponding dimensionality
# elevations and reductions on the input and output.
# We define a helper function to calculate convolutions. It initializes
# the convolutional layer weights and performs corresponding dimensionality
# elevations and reductions on the input and output
def comp_conv2d(conv2d, X):
    # (1, 1) indicates that batch size and the number of channels are both 1
    X = X.reshape((1, 1) + X.shape)
@@ -154,9 +154,9 @@ comp_conv2d(conv2d, X).shape
%%tab tensorflow
import tensorflow as tf
# We define a helper function to calculate convolutions. It initializes
# the convolutional layer weights and performs corresponding dimensionality
# elevations and reductions on the input and output.
# We define a helper function to calculate convolutions. It initializes
# the convolutional layer weights and performs corresponding dimensionality
# elevations and reductions on the input and output
def comp_conv2d(conv2d, X):
    # (1, 1) indicates that batch size and the number of channels are both 1
    X = tf.reshape(X, (1, ) + X.shape + (1, ))
@@ -169,34 +169,65 @@ X = tf.random.uniform(shape=(8, 8))
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
# We define a helper function to calculate convolutions. It initializes
# the convolutional layer weights and performs corresponding dimensionality
# elevations and reductions on the input and output
def comp_conv2d(conv2d, X):
    # (1, X.shape, 1) indicates that batch size and the number of channels are both 1
    key = jax.random.PRNGKey(d2l.get_seed())
    X = X.reshape((1,) + X.shape + (1,))
    Y, _ = conv2d.init_with_output(key, X)
    # Strip the dimensions: examples and channels
    return Y.reshape(Y.shape[1:3])
# 1 row and column is padded on either side, so a total of 2 rows or columns are added
conv2d = nn.Conv(1, kernel_size=(3, 3), padding='SAME')
X = jax.random.uniform(jax.random.PRNGKey(d2l.get_seed()), shape=(8, 8))
comp_conv2d(conv2d, X).shape
```

When the height and width of the convolution kernel are different,
we can make the output and input have the same height and width
by [**setting different padding numbers for height and width.**]

```{.python .input}
%%tab mxnet
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively.
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively
conv2d = nn.Conv2D(1, kernel_size=(5, 3), padding=(2, 1))
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab pytorch
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively.
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively
conv2d = nn.LazyConv2d(1, kernel_size=(5, 3), padding=(2, 1))
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab tensorflow
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively.
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively
conv2d = tf.keras.layers.Conv2D(1, kernel_size=(5, 3), padding='same')
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab jax
# We use a convolution kernel with height 5 and width 3. The padding on
# either side of the height and width are 2 and 1, respectively
conv2d = nn.Conv(1, kernel_size=(5, 3), padding=(2, 1))
comp_conv2d(conv2d, X).shape
```
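
As a rough illustration of how the padding values in the blocks above can be chosen, the sketch below computes the per-side padding that keeps the output the same size as the input at stride 1. The helper name `same_padding` is hypothetical and not part of d2l or any framework, and it assumes every kernel dimension is odd.

```python
def same_padding(kernel_size):
    """Per-side padding that keeps height and width unchanged at stride 1.

    Hypothetical helper for illustration; assumes every kernel dimension is odd.
    """
    if any(k % 2 == 0 for k in kernel_size):
        raise ValueError("even kernel sizes need asymmetric padding")
    # Padding (k - 1) / 2 on each side adds k - 1 rows or columns in total,
    # which offsets the k - 1 that the kernel removes.
    return tuple((k - 1) // 2 for k in kernel_size)


print(same_padding((5, 3)))  # (2, 1)
print(same_padding((3, 3)))  # (1, 1)
```

For the 5 x 3 kernel this reproduces the `padding=(2, 1)` argument used above.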

## Stride

When computing the cross-correlation,
@@ -260,6 +291,12 @@ conv2d = tf.keras.layers.Conv2D(1, kernel_size=3, padding='same', strides=2)
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab jax
conv2d = nn.Conv(1, kernel_size=(3, 3), padding=1, strides=2)
comp_conv2d(conv2d, X).shape
```
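
To see where the halved output size comes from, the sketch below lists the positions at which the kernel window is applied along one dimension. `window_starts` is a hypothetical helper, and `padding` here is assumed to mean the number of rows or columns added on each side, matching the framework arguments above.

```python
def window_starts(n, k, padding, stride):
    """Start indices of the kernel window along one padded dimension.

    Hypothetical helper for illustration; `padding` is the per-side amount.
    """
    padded = n + 2 * padding
    # The window must fit entirely inside the padded input.
    return list(range(0, padded - k + 1, stride))


starts = window_starts(n=8, k=3, padding=1, stride=2)
print(starts)       # [0, 2, 4, 6]
print(len(starts))  # 4
```

With 8 rows and 8 columns, this gives 4 window positions per dimension, which is consistent with an output of shape (4, 4).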

Let's look at (**a slightly more complicated example**).

```{.python .input}
@@ -281,6 +318,12 @@ conv2d = tf.keras.layers.Conv2D(1, kernel_size=(3,5), padding='valid',
comp_conv2d(conv2d, X).shape
```

```{.python .input}
%%tab jax
conv2d = nn.Conv(1, kernel_size=(3, 5), padding=(0, 1), strides=(3, 4))
comp_conv2d(conv2d, X).shape
```
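
The output shape of this example can also be checked by hand. The sketch below uses the closed-form expression $\lfloor (n - k + 2p + s) / s \rfloor$ for one dimension, where $p$ is assumed to be the per-side padding (so $2p$ is the total number of rows or columns added). `conv_output_size` is a hypothetical helper, not a d2l or Flax function.

```python
def conv_output_size(n, k, padding, stride):
    """Output length along one dimension: floor((n - k + 2 * padding + stride) / stride).

    Hypothetical helper; `padding` is per side, so 2 * padding is the total added.
    """
    return (n - k + 2 * padding + stride) // stride


# 8 x 8 input, kernel (3, 5), per-side padding (0, 1), strides (3, 4)
height = conv_output_size(8, 3, 0, 3)
width = conv_output_size(8, 5, 1, 4)
print((height, width))  # (2, 2)
```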

## Summary and Discussion

Padding can increase the height and width of the output. This is often used to give the output the same height and width as the input to avoid undesirable shrinkage of the output. Moreover, it lets pixels near the border contribute to more output elements, so that they are used more nearly as often as interior pixels. Typically we pick symmetric padding on both sides of the input height and width. In this case we refer to $(p_h, p_w)$ padding. Most commonly we set $p_h = p_w$, in which case we simply state that we choose padding $p$.