Skip to content

Commit 05e4bc1

Browse files
authored Nov 14, 2022
[Feature] Support Activation Checkpointing for ConvNeXt. (open-mmlab#1152)
* Support Activation Checkpointing for ConvNeXt
* Add test case
* Lint
* Add docstring
1 parent 8c63bb5 commit 05e4bc1

File tree

2 files changed

+42
-16
lines changed

2 files changed

+42
-16
lines changed
 

‎mmcls/models/backbones/convnext.py

+32 -16
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import torch
77
import torch.nn as nn
88
import torch.nn.functional as F
9+
import torch.utils.checkpoint as cp
910
from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
1011
build_norm_layer)
1112
from mmcv.runner import BaseModule
@@ -77,8 +78,11 @@ def __init__(self,
7778
mlp_ratio=4.,
7879
linear_pw_conv=True,
7980
drop_path_rate=0.,
80-
layer_scale_init_value=1e-6):
81+
layer_scale_init_value=1e-6,
82+
with_cp=False):
8183
super().__init__()
84+
self.with_cp = with_cp
85+
8286
self.depthwise_conv = nn.Conv2d(
8387
in_channels,
8488
in_channels,
@@ -108,24 +112,33 @@ def __init__(self,
108112
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
109113

110114
def forward(self, x):
111-
shortcut = x
112-
x = self.depthwise_conv(x)
113-
x = self.norm(x)
114115

115-
if self.linear_pw_conv:
116-
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
116+
def _inner_forward(x):
117+
shortcut = x
118+
x = self.depthwise_conv(x)
119+
x = self.norm(x)
117120

118-
x = self.pointwise_conv1(x)
119-
x = self.act(x)
120-
x = self.pointwise_conv2(x)
121+
if self.linear_pw_conv:
122+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
121123

122-
if self.linear_pw_conv:
123-
x = x.permute(0, 3, 1, 2) # permute back
124+
x = self.pointwise_conv1(x)
125+
x = self.act(x)
126+
x = self.pointwise_conv2(x)
124127

125-
if self.gamma is not None:
126-
x = x.mul(self.gamma.view(1, -1, 1, 1))
128+
if self.linear_pw_conv:
129+
x = x.permute(0, 3, 1, 2) # permute back
130+
131+
if self.gamma is not None:
132+
x = x.mul(self.gamma.view(1, -1, 1, 1))
133+
134+
x = shortcut + self.drop_path(x)
135+
return x
136+
137+
if self.with_cp and x.requires_grad:
138+
x = cp.checkpoint(_inner_forward, x)
139+
else:
140+
x = _inner_forward(x)
127141

128-
x = shortcut + self.drop_path(x)
129142
return x
130143

131144

@@ -169,6 +182,8 @@ class ConvNeXt(BaseBackbone):
169182
gap_before_final_norm (bool): Whether to globally average the feature
170183
map before the final norm layer. In the official repo, it's only
171184
used in classification task. Defaults to True.
185+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
186+
memory while slowing down the training speed. Defaults to False.
172187
init_cfg (dict, optional): Initialization config dict
173188
""" # noqa: E501
174189
arch_settings = {
@@ -206,6 +221,7 @@ def __init__(self,
206221
out_indices=-1,
207222
frozen_stages=0,
208223
gap_before_final_norm=True,
224+
with_cp=False,
209225
init_cfg=None):
210226
super().__init__(init_cfg=init_cfg)
211227

@@ -288,8 +304,8 @@ def __init__(self,
288304
norm_cfg=norm_cfg,
289305
act_cfg=act_cfg,
290306
linear_pw_conv=linear_pw_conv,
291-
layer_scale_init_value=layer_scale_init_value)
292-
for j in range(depth)
307+
layer_scale_init_value=layer_scale_init_value,
308+
with_cp=with_cp) for j in range(depth)
293309
])
294310
block_idx += depth
295311

‎tests/test_models/test_backbones/test_convnext.py

+10
Original file line number | Diff line number | Diff line change
@@ -84,3 +84,13 @@ def test_convnext():
8484
for i in range(2, 4):
8585
assert model.downsample_layers[i].training
8686
assert model.stages[i].training
87+
88+
# Test Activation Checkpointing
89+
model = ConvNeXt(arch='tiny', out_indices=-1, with_cp=True)
90+
model.init_weights()
91+
model.train()
92+
93+
imgs = torch.randn(1, 3, 224, 224)
94+
feat = model(imgs)
95+
assert len(feat) == 1
96+
assert feat[0].shape == torch.Size([1, 768])

0 commit comments

Comments
 (0)
Please sign in to comment.