 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.utils.checkpoint as cp
 from mmcv.cnn.bricks import (NORM_LAYERS, DropPath, build_activation_layer,
                              build_norm_layer)
 from mmcv.runner import BaseModule
@@ -77,8 +78,11 @@ def __init__(self,
                  mlp_ratio=4.,
                  linear_pw_conv=True,
                  drop_path_rate=0.,
-                 layer_scale_init_value=1e-6):
+                 layer_scale_init_value=1e-6,
+                 with_cp=False):
         super().__init__()
+        self.with_cp = with_cp
+
         self.depthwise_conv = nn.Conv2d(
             in_channels,
             in_channels,
@@ -108,24 +112,33 @@ def __init__(self,
             drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
     def forward(self, x):
-        shortcut = x
-        x = self.depthwise_conv(x)
-        x = self.norm(x)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        def _inner_forward(x):
+            shortcut = x
+            x = self.depthwise_conv(x)
+            x = self.norm(x)
 
-        x = self.pointwise_conv1(x)
-        x = self.act(x)
-        x = self.pointwise_conv2(x)
+            if self.linear_pw_conv:
+                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
 
-        if self.linear_pw_conv:
-            x = x.permute(0, 3, 1, 2)  # permute back
+            x = self.pointwise_conv1(x)
+            x = self.act(x)
+            x = self.pointwise_conv2(x)
 
-        if self.gamma is not None:
-            x = x.mul(self.gamma.view(1, -1, 1, 1))
+            if self.linear_pw_conv:
+                x = x.permute(0, 3, 1, 2)  # permute back
+
+            if self.gamma is not None:
+                x = x.mul(self.gamma.view(1, -1, 1, 1))
+
+            x = shortcut + self.drop_path(x)
+            return x
+
+        if self.with_cp and x.requires_grad:
+            x = cp.checkpoint(_inner_forward, x)
+        else:
+            x = _inner_forward(x)
 
-        x = shortcut + self.drop_path(x)
         return x
 
 
@@ -169,6 +182,8 @@ class ConvNeXt(BaseBackbone):
         gap_before_final_norm (bool): Whether to globally average the feature
             map before the final norm layer. In the official repo, it's only
             used in classification task. Defaults to True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Defaults to False.
         init_cfg (dict, optional): Initialization config dict
     """  # noqa: E501
     arch_settings = {
@@ -206,6 +221,7 @@ def __init__(self,
                  out_indices=-1,
                  frozen_stages=0,
                  gap_before_final_norm=True,
+                 with_cp=False,
                  init_cfg=None):
         super().__init__(init_cfg=init_cfg)
 
@@ -288,8 +304,8 @@ def __init__(self,
                     norm_cfg=norm_cfg,
                     act_cfg=act_cfg,
                     linear_pw_conv=linear_pw_conv,
-                    layer_scale_init_value=layer_scale_init_value)
-                for j in range(depth)
+                    layer_scale_init_value=layer_scale_init_value,
+                    with_cp=with_cp) for j in range(depth)
             ])
             block_idx += depth
 
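For context, the change wraps the block body in torch.utils.checkpoint: when with_cp=True and the input requires gradients, each ConvNeXtBlock drops its intermediate activations after the forward pass and recomputes them during backward, saving memory at the cost of training speed, while the x.requires_grad guard leaves inference untouched. Below is a minimal usage sketch; the `from mmcls.models import ConvNeXt` import path and the 'tiny' arch name are assumptions about the surrounding codebase, not part of this diff.

import torch

from mmcls.models import ConvNeXt  # assumed import path for the backbone

# Build the backbone with gradient checkpointing enabled in every ConvNeXtBlock.
# 'tiny' is assumed to be one of the keys in ConvNeXt.arch_settings.
model = ConvNeXt(arch='tiny', with_cp=True)
model.train()

imgs = torch.randn(2, 3, 224, 224)
feats = model(imgs)  # blocks keep fewer activations during the forward pass
loss = sum(f.sum() for f in feats)
loss.backward()  # activations are recomputed here by torch.utils.checkpoint

# In eval mode under no_grad, the intermediate features do not require
# gradients, so each block takes the plain _inner_forward path instead.
model.eval()
with torch.no_grad():
    feats = model(imgs)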