[Feature] Enable fast conv bn eval (open-mmlab#1202)
youkaichao authored Jul 14, 2023
1 parent 278f7f5 commit 40e49ff
Showing 5 changed files with 219 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/en/common_usage/save_gpu_memory.md
@@ -2,6 +2,10 @@

Memory capacity is critical in deep learning training and inference and determines whether the model can run successfully. Common memory saving approaches include:

- Enable Fast Conv BN Eval Feature (Experimental)

We've recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Fast Conv BN Eval, based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624). This feature is designed to reduce the memory footprint during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks whose normalization layers are kept in `eval` mode during training (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can reduce memory consumption by more than 20%. To enable it, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for sub-modules: ...` in the output log, the feature has been successfully enabled. As this is currently in an experimental phase, we are eager to hear about your experience with it. Please share your usage reports, observations, and suggestions in [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for deciding whether this feature should be integrated into the stable release.
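In a config file, this corresponds to a top-level option read by the runner (a minimal sketch; the list value mirrors the `--cfg-options` example above):

```python
# in your training config: enable the feature for the `backbone` sub-module
fast_conv_bn_eval = ['backbone']
```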

- Gradient Accumulation

Gradient accumulation accumulates gradients for a configured number of steps instead of updating the parameters at every step; once that count is reached, the network parameters are updated and the gradients are cleared. With this delayed parameter update, the result is similar to using a large batch size, while the memory for activations is saved. Note, however, that if the model contains a batch normalization layer, using gradient accumulation may impact performance. A configuration sketch follows.
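In MMEngine, gradient accumulation is configured through the optimizer wrapper (the optimizer settings below are illustrative):

```python
# accumulate gradients over 4 steps before each parameter update,
# approximating a 4x larger batch size
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9),
    accumulative_counts=4)
```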
4 changes: 4 additions & 0 deletions docs/zh_cn/common_usage/save_gpu_memory.md
@@ -2,6 +2,10 @@

Memory capacity is critical during deep learning training and inference, as it determines whether a model can run successfully. Common memory-saving approaches include:

- Enable Fast Conv BN Eval Feature (Experimental)

Based on the concepts discussed in [this paper](https://arxiv.org/abs/2305.11624), we've recently [introduced](https://github.com/open-mmlab/mmcv/pull/2807) an experimental feature in MMCV: Fast Conv BN Eval. This feature is designed to reduce memory usage during network training without hurting performance. If your network architecture contains a series of consecutive Conv+BN blocks whose BN layers are kept in `eval` mode during training (a common occurrence when training object detectors with [MMDetection](https://github.com/open-mmlab/mmdetection)), this feature can reduce memory consumption by more than 20%. To enable it, simply add the following command-line argument: `--cfg-options fast_conv_bn_eval="[backbone]"`. When you see `Enabling the "fast_conv_bn_eval" feature for sub-modules: ...` in the output log, the feature has been successfully enabled. As this is still experimental, we look forward to hearing about your experience with it. Please share your usage reports, observations, and suggestions in [this discussion thread](https://github.com/open-mmlab/mmcv/discussions/2841). Your feedback is crucial for further development and for deciding whether this feature should be integrated into the stable release.

- Gradient Accumulation

Gradient accumulation means that after the gradients of a batch are computed, they are not cleared but accumulated; after a configured number of accumulations, the network parameters are updated and the gradients are cleared. With this delayed parameter update, the effect is similar to using a large batch size, thereby saving memory. Note, however, that if the model contains a batch normalization layer, using gradient accumulation will affect performance to some extent.
145 changes: 145 additions & 0 deletions mmengine/model/fast_conv_bn_eval.py
@@ -0,0 +1,145 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def fast_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
conv: nn.modules.conv._ConvNd, x: torch.Tensor):
"""Code borrowed from mmcv 2.0.1, so that this feature can be used for old
mmcv versions.
Implementation based on https://arxiv.org/abs/2305.11624
"Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
It leverages the associative law between convolution and affine transform,
i.e., normalize (weight conv feature) = (normalize weight) conv feature.
It works for Eval mode of ConvBN blocks during validation, and can be used
for training as well. It reduces memory and computation cost.
Args:
bn (_BatchNorm): a BatchNorm module.
conv (nn._ConvNd): a conv module
x (torch.Tensor): Input feature map.
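
    Example (illustrative sanity check, not part of the original module;
    assumes a fresh ``BatchNorm2d`` with its default running statistics):
        >>> conv = nn.Conv2d(3, 8, 3)
        >>> bn = nn.BatchNorm2d(8).eval()
        >>> x = torch.randn(2, 3, 16, 16)
        >>> out = fast_conv_bn_eval_forward(bn, conv, x)
        >>> torch.allclose(out, bn(conv(x)), atol=1e-4)
        True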
"""
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias
weight_on_the_fly = conv.weight
if conv.bias is not None:
bias_on_the_fly = conv.bias
else:
bias_on_the_fly = torch.zeros_like(bn.running_var)

if bn.weight is not None:
bn_weight = bn.weight
else:
bn_weight = torch.ones_like(bn.running_var)

if bn.bias is not None:
bn_bias = bn.bias
else:
bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(
        [-1] + [1] * (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * \
        (bias_on_the_fly - bn.running_mean)

return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)


def bn_once_identity_forward(bn: nn.modules.batchnorm._BatchNorm,
x: torch.Tensor):
"""The forward function is an identity function.
The magic is that after one call, the `bn.forward` will be restored to what
it used to be.
"""
bn.__dict__.pop('forward')
return x
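
# Mechanism note (illustrative sketch, not executed by this module):
# assigning `bn.forward = partial(bn_once_identity_forward, bn)` stores an
# instance attribute that shadows the class-level `forward`; popping it from
# `bn.__dict__` re-exposes the default method. For example:
#
#     bn = nn.BatchNorm2d(4).eval()
#     bn.forward = partial(bn_once_identity_forward, bn)
#     y = bn(x)  # identity mapping; this call also restores `bn.forward`
#     z = bn(x)  # normal BatchNorm forward again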


def fast_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
conv: nn.modules.conv._ConvNd, x: torch.Tensor):
"""This function controls whether to use `fast_conv_bn_eval_forward`.
If the following `bn` is in `eval` mode, then we turn on the special
`fast_conv_bn_eval_forward` and let the following call of `bn.forward` to
be identity. Note that this `bn.forward` modification only works for one
call. After the call, `bn.forward` will be restored to the default
function. This is to deal with the case where one `bn` module is used in
multiple places.
"""
if not bn.training:
# bn in eval mode
output = fast_conv_bn_eval_forward(bn, conv, x)
bn.forward = partial(bn_once_identity_forward, bn)
return output
else:
return conv._conv_forward(x, conv.weight, conv.bias)


def turn_on_fast_conv_bn_eval_for_single_model(model: torch.nn.Module):
# optimize consecutive conv+bn by modifying forward function
# Symbolically trace the input model to create an FX GraphModule
import torch.fx as fx
fx_model: fx.GraphModule = fx.symbolic_trace(model)
modules = dict(fx_model.named_modules())

patterns = [(torch.nn.modules.conv._ConvNd,
torch.nn.modules.batchnorm._BatchNorm)]

# Iterate through nodes in the graph to find ConvBN blocks
for node in fx_model.graph.nodes:
# If our current node isn't calling a Module then we can ignore it.
if node.op != 'call_module':
continue
target_module = modules[node.target]
found_pair = False
        for conv_class, bn_class in patterns:
            # the input of bn must itself be the output of a module call;
            # otherwise it cannot be looked up in `modules`
            if isinstance(target_module, bn_class) and node.args and \
                    isinstance(node.args[0], fx.Node) and \
                    node.args[0].op == 'call_module':
                source_module = modules[node.args[0].target]
                if isinstance(source_module, conv_class):
                    found_pair = True
# Not a conv-BN pattern or output of conv is used by other nodes
if not found_pair or len(node.args[0].users) > 1:
continue

# check if the conv modules are used in multiple nodes
conv_name = node.args[0].target
bn_name = node.target

conv_usage_count = 0
for _node in fx_model.graph.nodes:
if _node.op != 'call_module':
continue
if _node.target == conv_name:
conv_usage_count += 1

if conv_usage_count > 1:
continue

# Find a pair of conv and bn to optimize
conv_module = modules[conv_name]
bn_module = modules[bn_name]

conv_module.forward = partial(fast_conv_bn_eval_control, bn_module,
conv_module)


def turn_on_fast_conv_bn_eval(model: torch.nn.Module,
                              modules: Union[List[str], str]):
    """Turn on the fast conv-bn-eval feature for the given sub-modules."""
if isinstance(modules, str):
modules = [modules]
for module_name in modules:
module = attrgetter(module_name)(model)
turn_on_fast_conv_bn_eval_for_single_model(module)
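
# Usage sketch (illustrative; assumes a model whose `backbone` attribute
# contains consecutive conv+bn blocks kept in eval mode):
#
#     model = MyDetector()  # hypothetical model
#     model.backbone.eval()
#     turn_on_fast_conv_bn_eval(model, 'backbone')
#
# Subsequent forward passes through `backbone` then fold each BN into its
# preceding conv on the fly, skipping the separate BN computation and the
# activation it would otherwise store for backward.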
9 changes: 9 additions & 0 deletions mmengine/runner/runner.py
@@ -29,6 +29,7 @@
from mmengine.logging import MessageHub, MMLogger, print_log
from mmengine.model import (MMDistributedDataParallel, convert_sync_batchnorm,
is_model_wrapper, revert_sync_batchnorm)
from mmengine.model.fast_conv_bn_eval import turn_on_fast_conv_bn_eval
from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, FUNCTIONS,
@@ -1719,6 +1720,14 @@ def train(self) -> nn.Module:

# initialize the model weights
self._init_model_weights()

# try to enable fast_conv_bn_eval feature
modules = self.cfg.get('fast_conv_bn_eval', None)
if modules is not None:
self.logger.info(f'Enabling the "fast_conv_bn_eval" feature'
f' for sub-modules: {modules}')
turn_on_fast_conv_bn_eval(ori_model, modules)

# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()

57 changes: 57 additions & 0 deletions tests/test_model/test_fast_conv_bn_eval.py
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from torch import nn

from mmengine.model.fast_conv_bn_eval import \
turn_on_fast_conv_bn_eval_for_single_model
from mmengine.testing import assert_allclose
from mmengine.utils import is_installed

mmcv_is_installed = is_installed('mmcv')


class BackboneModel(nn.Module):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if mmcv_is_installed:
from mmcv.cnn import ConvModule
conv0 = nn.Conv2d(6, 6, 6)
bn0 = nn.BatchNorm2d(6)
self.mod1 = ConvModule.create_from_conv_bn(conv0, bn0)
self.conv1 = nn.Conv2d(6, 6, 6)
self.bn1 = nn.BatchNorm2d(6)
self.conv2 = nn.Conv2d(6, 6, 6)
self.bn2 = nn.BatchNorm2d(6)
self.conv3 = nn.Conv2d(6, 6, 6)
self.bn3 = nn.BatchNorm2d(6)

def forward(self, x):
if mmcv_is_installed:
# this ConvModule can use fast_conv_bn_eval feature
x = self.mod1(x)
# this conv-bn pair can use fast_conv_bn_eval feature
x = self.bn1(self.conv1(x))
# this conv-bn pair cannot use fast_conv_bn_eval feature
# because `self.conv2` is used twice
x = self.bn2(self.conv2(self.conv2(x)))
# this conv-bn pair can use fast_conv_bn_eval feature
# just for the first forward of the `self.bn3`
x = self.bn3(self.bn3(self.conv3(x)))
return x


class TestFastConvBNEval(TestCase):
"""Test the turn_on_fast_conv_bn_eval function."""

def test_fast_conv_bn_eval(self):
model = BackboneModel()
model.eval()
        inputs = torch.randn(64, 6, 32, 32)
        output = model(inputs)
        turn_on_fast_conv_bn_eval_for_single_model(model)
        output2 = model(inputs)
print((output - output2).abs().max().item())
assert_allclose(output, output2)
