%load_ext d2lbook.tab
tab.interact_select(['mxnet', 'pytorch', 'tensorflow', 'jax'])
🏷️sec_cnn-design
The previous sections have taken us on a tour of modern network design for computer vision. Common to all the work we covered was that it greatly relied on the intuition of scientists. Many of the architectures are heavily informed by human creativity and to a much lesser extent by systematic exploration of the design space that deep networks offer. Nonetheless, this network engineering approach has been tremendously successful.
Ever since AlexNet (:numref:sec_alexnet)
beat conventional computer vision models on ImageNet,
it has become popular to construct very deep networks
by stacking blocks of convolutions, all designed according to the same pattern.
In particular, $3 \times 3$ convolutions were popularized by VGG networks (:numref:sec_vgg).
NiN (:numref:sec_nin) showed that even $1 \times 1$ convolutions could be beneficial by adding local nonlinearities. Moreover, NiN solved the problem of aggregating information at the head of a network by aggregating across all locations. GoogLeNet (:numref:sec_googlenet
) added multiple branches of different convolution width,
combining the advantages of VGG and NiN in its Inception block.
ResNets (:numref:sec_resnet)
changed the inductive bias towards the identity mapping (from $f(\mathbf{x}) = 0$). This allowed for very deep networks; nearly a decade later, the ResNet design is still popular, a testament to its quality. Lastly, ResNeXt (:numref:subsec_resnext) added grouped convolutions, offering a better trade-off between parameters and computation. A precursor to Transformers for vision, Squeeze-and-Excitation Networks (SENets) allow for efficient information transfer between locations
:cite:Hu.Shen.Sun.2018
. This was accomplished by computing a per-channel global attention function.
Up to now we have omitted networks obtained via neural architecture search (NAS) :cite:zoph2016neural,liu2018darts
. We chose to do so since their cost is usually enormous, relying on brute-force search, genetic algorithms, reinforcement learning, or some other form of hyperparameter optimization. Given a fixed search space,
NAS uses a search strategy to automatically select
an architecture based on the returned performance estimation.
The outcome of NAS
is a single network instance. EfficientNets are a notable outcome of this search :cite:tan2019efficientnet
.
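To make this loop concrete, here is a minimal, framework-agnostic sketch of the simplest search strategy: random search over a fixed search space. It is not taken from any of the cited systems; the search space and the `estimate_performance` callback are hypothetical stand-ins (in practice the estimate would come from, e.g., early-stopped training).

```python
import random

# Hypothetical sketch of the generic NAS loop: a search strategy proposes
# architectures from a fixed search space and keeps the one with the best
# returned performance estimate.
search_space = {
    'depth': [8, 14, 20],
    'width': [16, 32, 64],
    'kernel_size': [3, 5],
}

def sample_architecture(space):
    """Search strategy in its simplest form: sample uniformly at random."""
    return {name: random.choice(choices) for name, choices in space.items()}

def random_search(space, estimate_performance, num_trials=10):
    """Keep the architecture with the best returned performance estimate."""
    best_arch, best_score = None, float('-inf')
    for _ in range(num_trials):
        arch = sample_architecture(space)
        score = estimate_performance(arch)
        if score > best_score:
            best_arch, best_score = arch, score
    return best_arch  # the outcome of NAS is a single network instance

# Dummy performance estimate, for demonstration only.
print(random_search(search_space, lambda arch: arch['width'] - arch['depth']))
```

Actual NAS systems replace the uniform sampling with reinforcement learning, evolutionary algorithms, or differentiable relaxations, but the overall propose-estimate-select loop stays the same.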
In the following we discuss an idea that is quite different to the quest for the single best network. It is computationally relatively inexpensive, it leads to scientific insights on the way, and it is quite effective in terms of the quality of outcomes. Let's review the strategy by :citet:Radosavovic.Kosaraju.Girshick.ea.2020
to design network design spaces. The strategy combines the strengths of manual design and NAS. It accomplishes this by operating on distributions of networks and optimizing the distributions so as to obtain good performance for entire families of networks. The outcome of this process is the RegNet family, specifically RegNetX and RegNetY, plus a range of guiding principles for the design of performant CNNs.
%%tab mxnet
from d2l import mxnet as d2l
from mxnet import np, npx, init
from mxnet.gluon import nn
npx.set_np()
%%tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
%%tab tensorflow
import tensorflow as tf
from d2l import tensorflow as d2l
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
🏷️subsec_the-anynet-design-space
The description below closely follows the reasoning in :citet:Radosavovic.Kosaraju.Girshick.ea.2020
with some abbreviations to make it fit in the scope of the book.
To begin, we need a template for the family of networks to explore. One of the commonalities of the designs in this chapter is that the networks consist of a stem, a body and a head. The stem performs initial image processing, often through convolutions with a larger window size. The body consists of multiple blocks, carrying out the bulk of the transformations needed to go from raw images to object representations. Lastly, the head converts this into the desired outputs, such as via a softmax regressor for multiclass classification.
The body, in turn, consists of multiple stages, operating on the image at decreasing resolutions. In fact, both the stem and each subsequent stage quarter the spatial resolution. Lastly, each stage consists of one or more blocks. This pattern is common to all networks, from VGG to ResNeXt. Indeed, for the design of generic AnyNet networks, :citet:Radosavovic.Kosaraju.Girshick.ea.2020
used the ResNeXt block of :numref:fig_resnext_block
.
Let's review the structure outlined in :numref:fig_anynet_full
in detail. As mentioned, an AnyNet consists of a stem, body, and head. The stem takes as its input RGB images (3 channels), using a $3 \times 3$ convolution with a stride of $2$, followed by a batch norm, to halve the resolution from $r \times r$ to $r/2 \times r/2$. Moreover, it generates $c_0$ channels that serve as input to the body.
Since the network is designed to work well with ImageNet images of shape $224 \times 224 \times 3$, the body serves to reduce this to $7 \times 7 \times c_4$ through 4 stages (recall that $224 / 2^{1+4} = 7$), each with an eventual stride of $2$. Lastly, the head employs an entirely standard design via global average pooling, similar to NiN (:numref:sec_nin), followed by a fully connected layer to emit an $n$-dimensional vector for $n$-class classification.
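As a quick sanity check of this resolution arithmetic (the numbers simply restate the ImageNet example above):

```python
# The stem and each of the 4 stages halve the height and width
# (i.e., quarter the spatial resolution), taking 224 x 224 down to 7 x 7.
resolution = 224
for _ in range(1 + 4):  # stem + 4 stages
    resolution //= 2
print(resolution)  # 7
```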
Most of the relevant design decisions are inherent to the body of the network. It proceeds in stages, where each stage is composed of the same type of ResNeXt blocks as we discussed in :numref:subsec_resnext
. The design there is again entirely generic: we begin with a block that halves the resolution by using a stride of $2$ (the rightmost one in :numref:fig_anynet_full). To match this, the residual branch of the ResNeXt block needs to pass through a $1 \times 1$ convolution. This block is followed by a variable number of additional ResNeXt blocks that leave both resolution and the number of channels unchanged. Note that a common design practice is to add a slight bottleneck in convolutional blocks: with bottleneck ratio $k_i \geq 1$ we afford some number of channels $c_i / k_i$ within each block of stage $i$ (as the experiments later show, this is not really effective and should be skipped). Lastly, since we are dealing with ResNeXt blocks, we also need to pick the number of groups $g_i$ for grouped convolutions at stage $i$.

This seemingly generic design space provides us nonetheless with many parameters: we can set the block width (number of channels) $c_0, \ldots, c_4$, the depth (number of blocks) per stage $d_1, \ldots, d_4$, the bottleneck ratios $k_1, \ldots, k_4$, and the group widths (numbers of groups) $g_1, \ldots, g_4$. In total this adds up to 17 parameters, resulting in an unreasonably large number of configurations that would warrant exploring. We need some tools to reduce this huge design space effectively; this is where the conceptual beauty of design spaces comes in.
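To get a feeling for how quickly this design space grows, here is a small, hypothetical back-of-the-envelope computation. The candidate values below are made up for illustration; the per-stage $(d_i, c_i, g_i, k_i)$ layout matches the `arch` argument used by the implementation that follows.

```python
from itertools import product

# Hypothetical candidate values for each per-stage parameter.
depths = [1, 2, 4, 8]          # d_i: blocks per stage
channels = [16, 32, 64, 128]   # c_i: block width
groups = [1, 2, 4, 8, 16]      # g_i: groups for grouped convolutions
bot_muls = [1, 2, 4]           # k_i: bottleneck ratio

stage_choices = list(product(depths, channels, groups, bot_muls))
num_stages = 4
# (the stem width c_0 would add yet another factor)
print(f'{len(stage_choices)} choices per stage, '
      f'{len(stage_choices)**num_stages:,} possible bodies')
```

Even with only a handful of candidate values per parameter, the number of configurations explodes, which is exactly why we need the design-space machinery discussed below. Before we get there, let's implement the generic design first.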
%%tab mxnet
class AnyNet(d2l.Classifier):
def stem(self, num_channels):
net = nn.Sequential()
net.add(nn.Conv2D(num_channels, kernel_size=3, padding=1, strides=2),
nn.BatchNorm(), nn.Activation('relu'))
return net
%%tab pytorch
class AnyNet(d2l.Classifier):
def stem(self, num_channels):
return nn.Sequential(
nn.LazyConv2d(num_channels, kernel_size=3, stride=2, padding=1),
nn.LazyBatchNorm2d(), nn.ReLU())
%%tab tensorflow
class AnyNet(d2l.Classifier):
def stem(self, num_channels):
return tf.keras.models.Sequential([
tf.keras.layers.Conv2D(num_channels, kernel_size=3, strides=2,
padding='same'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Activation('relu')])
%%tab jax
class AnyNet(d2l.Classifier):
arch: tuple
stem_channels: int
lr: float = 0.1
num_classes: int = 10
training: bool = True
def setup(self):
self.net = self.create_net()
def stem(self, num_channels):
return nn.Sequential([
nn.Conv(num_channels, kernel_size=(3, 3), strides=(2, 2),
padding=(1, 1)),
nn.BatchNorm(not self.training),
nn.relu
])
Each stage consists of `depth` ResNeXt blocks, where `num_channels` specifies the block width. Note that the first block halves the height and width of input images.
%%tab mxnet
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
net = nn.Sequential()
for i in range(depth):
if i == 0:
net.add(d2l.ResNeXtBlock(
num_channels, groups, bot_mul, use_1x1conv=True, strides=2))
else:
net.add(d2l.ResNeXtBlock(
                num_channels, groups, bot_mul))
return net
%%tab pytorch
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
blk = []
for i in range(depth):
if i == 0:
blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
use_1x1conv=True, strides=2))
else:
blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul))
return nn.Sequential(*blk)
%%tab tensorflow
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
net = tf.keras.models.Sequential()
for i in range(depth):
if i == 0:
net.add(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
use_1x1conv=True, strides=2))
else:
net.add(d2l.ResNeXtBlock(num_channels, groups, bot_mul))
return net
%%tab jax
@d2l.add_to_class(AnyNet)
def stage(self, depth, num_channels, groups, bot_mul):
blk = []
for i in range(depth):
if i == 0:
blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
use_1x1conv=True, strides=(2, 2), training=self.training))
else:
blk.append(d2l.ResNeXtBlock(num_channels, groups, bot_mul,
training=self.training))
return nn.Sequential(blk)
Putting the network stem, body, and head together, we complete the implementation of AnyNet.
%%tab pytorch, mxnet, tensorflow
@d2l.add_to_class(AnyNet)
def __init__(self, arch, stem_channels, lr=0.1, num_classes=10):
super(AnyNet, self).__init__()
self.save_hyperparameters()
if tab.selected('mxnet'):
self.net = nn.Sequential()
self.net.add(self.stem(stem_channels))
for i, s in enumerate(arch):
self.net.add(self.stage(*s))
self.net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
self.net.initialize(init.Xavier())
if tab.selected('pytorch'):
self.net = nn.Sequential(self.stem(stem_channels))
for i, s in enumerate(arch):
self.net.add_module(f'stage{i+1}', self.stage(*s))
self.net.add_module('head', nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
nn.LazyLinear(num_classes)))
self.net.apply(d2l.init_cnn)
if tab.selected('tensorflow'):
self.net = tf.keras.models.Sequential(self.stem(stem_channels))
for i, s in enumerate(arch):
self.net.add(self.stage(*s))
self.net.add(tf.keras.models.Sequential([
tf.keras.layers.GlobalAvgPool2D(),
tf.keras.layers.Dense(units=num_classes)]))
%%tab jax
@d2l.add_to_class(AnyNet)
def create_net(self):
net = nn.Sequential([self.stem(self.stem_channels)])
for i, s in enumerate(self.arch):
net.layers.extend([self.stage(*s)])
net.layers.extend([nn.Sequential([
lambda x: nn.avg_pool(x, window_shape=x.shape[1:3],
strides=x.shape[1:3], padding='valid'),
lambda x: x.reshape((x.shape[0], -1)),
nn.Dense(self.num_classes)])])
return net
As just discussed in :numref:subsec_the-anynet-design-space
, parameters of a design space are hyperparameters of networks in that design space.
Consider the problem of identifying good parameters in the AnyNet design space. We could try finding the single best parameter choice for a given amount of computation (e.g., FLOPs and compute time). If we allowed for even only two possible choices for each parameter, we would have to explore $2^{17} = 131072$ combinations to find the best solution. This is clearly infeasible because of its exorbitant cost. Even worse, we would not really learn anything from this exercise about how one should design a network: the next time we add, say, another stage or change a block type, we would have to start the search from scratch. A better strategy is to determine general guidelines for how the choices of parameters should be related. For instance, the bottleneck ratio, the number of channels, blocks, and groups, or their change between stages, should ideally be governed by a collection of simple rules. The approach in :citet:radosavovic2019network
relies on the following four assumptions:
- We assume that general design principles actually exist, so that many networks satisfying these requirements should offer good performance. Consequently, identifying a distribution over networks can be a sensible strategy. In other words, we assume that there are many good needles in the haystack.
- We need not train networks to convergence before we can assess whether a network is good. Instead, it is sufficient to use the intermediate results as reliable guidance for final accuracy. Using (approximate) proxies to optimize an objective is referred to as multi-fidelity optimization :cite:forrester2007multi. Consequently, design optimization is carried out based on the accuracy achieved after only a few passes through the dataset, reducing the cost significantly.
- Results obtained at a smaller scale (for smaller networks) generalize to larger ones. Consequently, optimization is carried out for networks that are structurally similar, but with a smaller number of blocks, fewer channels, etc. Only in the end will we need to verify that the networks found in this way also offer good performance at scale.
- Aspects of the design can be approximately factorized so that it is possible to infer their effect on the quality of the outcome somewhat independently. In other words, the optimization problem is moderately easy.
These assumptions allow us to test many networks cheaply. In particular, we can sample uniformly from the space of configurations and evaluate their performance. Subsequently, we can assess the quality of a choice of parameters by reviewing the distribution of error/accuracy that can be achieved with the resulting networks. Denote by $F(e, p)$ the cumulative distribution function (CDF) for errors committed by networks of a given design space, drawn using probability distribution $p$. That is,

$$F(e, p) \stackrel{\textrm{def}}{=} P_{\textrm{net} \sim p} \{e(\textrm{net}) \leq e\}.$$
Our goal is now to find a distribution $p$ over networks such that most networks have a very low error rate and where the support of $p$ is concise. Of course, this is computationally infeasible to perform exactly. We resort to a sample of networks $\mathcal{Z} \stackrel{\textrm{def}}{=} \{\textrm{net}_1, \ldots, \textrm{net}_n\}$ (with errors $e_1, \ldots, e_n$, respectively) drawn from $p$ and use the empirical CDF $\hat{F}(e, \mathcal{Z})$ instead:

$$\hat{F}(e, \mathcal{Z}) = \frac{1}{n}\sum_{i=1}^n \mathbf{1}(e_i \leq e).$$
Whenever the CDF for one set of choices majorizes (or matches) another CDF it follows that its choice of parameters is superior (or indifferent). Accordingly
:citet:Radosavovic.Kosaraju.Girshick.ea.2020
experimented with a shared network bottleneck ratio $k_i = k$ for all stages $i$ of the network. This does away with three of the four parameters governing the bottleneck ratio. To assess whether this (negatively) affects the performance, one can draw networks from the constrained and from the unconstrained distribution and compare the corresponding CDFs. It turns out that this constraint does not affect the accuracy of the distribution of networks at all, as can be seen in the first panel of :numref:fig_regnet-fig.
Likewise, we could choose to pick the same group width $g_i = g$ for all stages of the network. Again, this does not affect performance, as can be seen in the second panel of :numref:fig_regnet-fig.
Both steps combined reduce the number of free parameters by six.
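As a small illustration of how such CDF comparisons work in practice, the following sketch computes empirical CDFs $\hat{F}(e, \mathcal{Z})$ for two hypothetical sets of sampled networks with NumPy. The error values are made up; in reality each entry would come from a proxy training run as described above.

```python
import numpy as np

# Hypothetical validation errors for networks sampled from two design spaces
# (e.g., an unconstrained space and a constrained one). The numbers are made up.
errors_a = np.array([0.32, 0.28, 0.41, 0.30, 0.35, 0.27])
errors_b = np.array([0.29, 0.26, 0.33, 0.27, 0.31, 0.25])

def empirical_cdf(errors, e):
    """Fraction of sampled networks with error <= e, i.e., an estimate of F(e, p)."""
    return np.mean(errors[:, None] <= e, axis=0)

grid = np.linspace(0.2, 0.5, 7)
print(empirical_cdf(errors_a, grid))
print(empirical_cdf(errors_b, grid))
# If one empirical CDF lies above (majorizes) the other everywhere on the grid,
# the corresponding set of design choices is at least as good.
```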
Next we look for ways to reduce the multitude of potential choices for width and depth of the stages. It is a reasonable assumption that, as we go deeper, the number of channels should increase, i.e., $c_i \geq c_{i-1}$ (:numref:fig_regnet-fig), yielding $\textrm{AnyNetX}_D$. Likewise, it is equally reasonable to assume that as the stages progress, they should become deeper, i.e., $d_i \geq d_{i-1}$, yielding $\textrm{AnyNetX}_E$. This can be verified experimentally in the third and fourth panel of :numref:fig_regnet-fig, respectively.
The resulting $\textrm{AnyNetX}_E$ design space consists of simple networks following easy-to-interpret design principles:

- Share the bottleneck ratio $k_i = k$ for all stages $i$;
- Share the group width $g_i = g$ for all stages $i$;
- Increase network width across stages: $c_{i} \leq c_{i+1}$;
- Increase network depth across stages: $d_{i} \leq d_{i+1}$.

A quick programmatic check of these constraints is sketched right after this list.
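The check below is hypothetical (not part of the d2l library); it takes a candidate configuration written as per-stage $(d_i, c_i, g_i, k_i)$ tuples, the same format as the `arch` argument of our AnyNet implementation, and reports whether the configuration lies in $\textrm{AnyNetX}_E$:

```python
def is_anynetx_e(arch):
    """Check the AnyNetX_E design principles for per-stage (d, c, g, k) tuples."""
    depths   = [d for d, c, g, k in arch]
    channels = [c for d, c, g, k in arch]
    groups   = [g for d, c, g, k in arch]
    bot_muls = [k for d, c, g, k in arch]
    shared   = len(set(groups)) == 1 and len(set(bot_muls)) == 1
    monotone = (all(d1 <= d2 for d1, d2 in zip(depths, depths[1:])) and
                all(c1 <= c2 for c1, c2 in zip(channels, channels[1:])))
    return shared and monotone

print(is_anynetx_e(((4, 32, 16, 1), (6, 80, 16, 1))))  # True: the RegNetX used below
print(is_anynetx_e(((6, 80, 16, 1), (4, 32, 32, 2))))  # False: violates the principles
```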
This leaves us with a final set of choices: how to pick the specific values for the above parameters of the eventual $\textrm{AnyNetX}_E$ design space. By studying the best-performing networks from the distribution in $\textrm{AnyNetX}_E$ one can observe the following: the width of the network ideally increases linearly with the block index across the network, i.e., $c_j \approx c_0 + c_a j$, where $j$ is the block index and slope $c_a > 0$. Given that we get to choose a different block width only per stage, we arrive at a piecewise constant function engineered to match this dependence. Furthermore, experiments also show that a bottleneck ratio of $k = 1$ performs best, i.e., we are advised not to use bottlenecks at all.
We recommend that interested readers peruse :citet:Radosavovic.Kosaraju.Girshick.ea.2020 for further details on the design of specific networks for different amounts of computation. For instance, an effective 32-layer RegNetX variant is given by $k = 1$ (no bottleneck), $g = 16$ (group width of 16), and $c_1 = 32$ and $c_2 = 80$ channels for the first and second stage, respectively, chosen to be $d_1 = 4$ and $d_2 = 6$ blocks deep. The astonishing insight from the design is that it still applies when investigating networks at a larger scale. Even better, it also holds for Squeeze-and-Excitation (SE) network designs (RegNetY) that have a global channel activation :cite:Hu.Shen.Sun.2018.
%%tab pytorch, mxnet, tensorflow
class RegNetX32(AnyNet):
def __init__(self, lr=0.1, num_classes=10):
stem_channels, groups, bot_mul = 32, 16, 1
depths, channels = (4, 6), (32, 80)
super().__init__(
((depths[0], channels[0], groups, bot_mul),
(depths[1], channels[1], groups, bot_mul)),
stem_channels, lr, num_classes)
%%tab jax
class RegNetX32(AnyNet):
lr: float = 0.1
num_classes: int = 10
stem_channels: int = 32
arch: tuple = ((4, 32, 16, 1), (6, 80, 16, 1))
We can see that each RegNetX stage progressively reduces resolution and increases output channels.
%%tab mxnet, pytorch
RegNetX32().layer_summary((1, 1, 96, 96))
%%tab tensorflow
RegNetX32().layer_summary((1, 96, 96, 1))
%%tab jax
RegNetX32(training=False).layer_summary((1, 96, 96, 1))
Training the 32-layer RegNetX on the Fashion-MNIST dataset is just like before.
%%tab mxnet, pytorch, jax
model = RegNetX32(lr=0.05)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
trainer.fit(model, data)
%%tab tensorflow
trainer = d2l.Trainer(max_epochs=10)
data = d2l.FashionMNIST(batch_size=128, resize=(96, 96))
with d2l.try_gpu():
model = RegNetX32(lr=0.01)
trainer.fit(model, data)
With desirable inductive biases (assumptions or preferences) like locality and translation invariance (:numref:sec_why-conv
)
for vision, CNNs have been the dominant architectures in this area. This remained the case from LeNet up until Transformers (:numref:sec_transformer
) :cite:Dosovitskiy.Beyer.Kolesnikov.ea.2021,touvron2021training
started surpassing CNNs in terms of accuracy. While much of the recent progress in terms of vision Transformers can be backported into CNNs :cite:liu2022convnet
, it is only possible at a higher computational cost. Just as importantly, recent hardware optimizations (NVIDIA Ampere and Hopper) have only widened the gap in favor of Transformers.
It is worth noting that Transformers have a significantly lower degree of inductive bias towards locality and translation invariance than CNNs. That learned structures prevailed is due, not least, to the availability of large image collections, such as LAION-400m and LAION-5B :cite:schuhmann2022laion
with up to 5 billion images. Quite surprisingly, some of the more relevant work in this context even includes MLPs :cite:tolstikhin2021mlp
.
In sum, vision Transformers (:numref:sec_vision-transformer
) by now lead in terms of
state-of-the-art performance in large-scale image classification,
showing that scalability trumps inductive biases :cite:Dosovitskiy.Beyer.Kolesnikov.ea.2021
.
This includes pretraining large-scale Transformers (:numref:sec_large-pretraining-transformers
) with multi-head self-attention (:numref:sec_multihead-attention
). We invite the readers to dive into these chapters for a much more detailed discussion.
- Increase the number of stages to four. Can you design a deeper RegNetX that performs better?
- De-ResNeXt-ify RegNets by replacing the ResNeXt block with the ResNet block. How does your new model perform?
- Implement multiple instances of a "VioNet" family by violating the design principles of RegNetX. How do they perform? Which of ($d_i$, $c_i$, $g_i$, $b_i$) is the most important factor?
- Your goal is to design the "perfect" MLP. Can you use the design principles introduced above to find good architectures? Is it possible to extrapolate from small to large networks?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab:
:begin_tab:jax
Discussions
:end_tab: