|
| 1 | +import math |
| 2 | +import mlconfig |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from torch.hub import load_state_dict_from_url |
| 6 | +''' |
| 7 | +### usage ### |
| 8 | +import efficientnet |
| 9 | +model = efficientnet.efficientnet_b0(pretrained=False, progress=True, num_classes=num_classes).to(device) |
| 10 | +''' |
| 11 | +model_urls = { |
| 12 | + 'efficientnet_b0': 'https://www.dropbox.com/s/9wigibun8n260qm/efficientnet-b0-4cfa50.pth?dl=1', |
| 13 | + 'efficientnet_b1': 'https://www.dropbox.com/s/6745ear79b1ltkh/efficientnet-b1-ef6aa7.pth?dl=1', |
| 14 | + 'efficientnet_b2': 'https://www.dropbox.com/s/0dhtv1t5wkjg0iy/efficientnet-b2-7c98aa.pth?dl=1', |
| 15 | + 'efficientnet_b3': 'https://www.dropbox.com/s/5uqok5gd33fom5p/efficientnet-b3-bdc7f4.pth?dl=1', |
| 16 | + 'efficientnet_b4': 'https://www.dropbox.com/s/y2nqt750lixs8kc/efficientnet-b4-3e4967.pth?dl=1', |
| 17 | + 'efficientnet_b5': 'https://www.dropbox.com/s/qxonlu3q02v9i47/efficientnet-b5-4c7978.pth?dl=1', |
| 18 | + 'efficientnet_b6': None, |
| 19 | + 'efficientnet_b7': None, |
| 20 | +} |
| 21 | + |
| 22 | +params = { |
| 23 | + 'efficientnet_b0': (1.0, 1.0, 224, 0.2), |
| 24 | + 'efficientnet_b1': (1.0, 1.1, 240, 0.2), |
| 25 | + 'efficientnet_b2': (1.1, 1.2, 260, 0.3), |
| 26 | + 'efficientnet_b3': (1.2, 1.4, 300, 0.3), |
| 27 | + 'efficientnet_b4': (1.4, 1.8, 380, 0.4), |
| 28 | + 'efficientnet_b5': (1.6, 2.2, 456, 0.4), |
| 29 | + 'efficientnet_b6': (1.8, 2.6, 528, 0.5), |
| 30 | + 'efficientnet_b7': (2.0, 3.1, 600, 0.5), |
| 31 | +} |
| 32 | + |
| 33 | +class Swish(nn.Module): |
| 34 | + def __init__(self, *args, **kwargs): |
| 35 | + super(Swish, self).__init__() |
| 36 | + |
| 37 | + def forward(self, x): |
| 38 | + return x * torch.sigmoid(x) |
| 39 | + |
| 40 | +class ConvBNReLU(nn.Sequential): |
| 41 | + def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1): |
| 42 | + padding = self._get_padding(kernel_size, stride) |
| 43 | + super(ConvBNReLU, self).__init__(nn.ZeroPad2d(padding), |
| 44 | + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding=0, groups=groups, bias=False), |
| 45 | + nn.BatchNorm2d(out_planes), |
| 46 | + Swish()) |
| 47 | + def _get_padding(self, kernel_size, stride): |
| 48 | + p = max(kernel_size - stride, 0) |
| 49 | + return [p//2, p-p//2, p//2, p-p//2] |
| 50 | + |
| 51 | +class SqueezeExcitation(nn.Module): |
| 52 | + |
| 53 | + def __init__(self, in_planes, reduced_dim): |
| 54 | + super(SqueezeExcitation, self).__init__() |
| 55 | + self.se = nn.Sequential( |
| 56 | + nn.AdaptiveAvgPool2d(1), |
| 57 | + nn.Conv2d(in_planes, reduced_dim, 1), |
| 58 | + Swish(), |
| 59 | + nn.Conv2d(reduced_dim, in_planes, 1), |
| 60 | + nn.Sigmoid(), |
| 61 | + ) |
| 62 | + |
| 63 | + def forward(self, x): |
| 64 | + return x * self.se(x) |
| 65 | + |
| 66 | +class MBConvBlock(nn.Module): |
| 67 | + def __init__(self, in_planes, out_planes, expand_ratio, kernel_size, stride, reduction_ratio=4, drop_connect_rate=0.2): |
| 68 | + super(MBConvBlock, self).__init__() |
| 69 | + self.drop_connect_rate = drop_connect_rate |
| 70 | + self.use_residual = in_planes == out_planes and stride == 1 |
| 71 | + assert stride in [1, 2] |
| 72 | + assert kernel_size in [3, 5] |
| 73 | + |
| 74 | + hidden_dim = in_planes * expand_ratio |
| 75 | + reduced_dim = max(1, int(in_planes / reduction_ratio)) |
| 76 | + |
| 77 | + layers = [] |
| 78 | + # pw |
| 79 | + if in_planes != hidden_dim: |
| 80 | + layers += [ConvBNReLU(in_planes, hidden_dim, 1)] |
| 81 | + |
| 82 | + layers += [ |
| 83 | + # dw |
| 84 | + ConvBNReLU(hidden_dim, hidden_dim, kernel_size, stride=stride, groups=hidden_dim), |
| 85 | + # se |
| 86 | + SqueezeExcitation(hidden_dim, reduced_dim), |
| 87 | + # pw-linear |
| 88 | + nn.Conv2d(hidden_dim, out_planes, 1, bias=False), |
| 89 | + nn.BatchNorm2d(out_planes) |
| 90 | + ] |
| 91 | + |
| 92 | + self.conv = nn.Sequential(*layers) |
| 93 | + |
| 94 | + def _drop_connect(self, x): |
| 95 | + if not self.training: |
| 96 | + return x |
| 97 | + keep_prob = 1.0 - self.drop_connect_rate |
| 98 | + batch_size = x.size(0) |
| 99 | + random_tensor = keep_prob |
| 100 | + random_tensor += torch.rand(batch_size, 1, 1, 1, device=x.device) |
| 101 | + binary_tensor = random_tensor.floor() |
| 102 | + return x.div(keep_prob) * binary_tensor |
| 103 | + |
| 104 | + def forward(self, x): |
| 105 | + if self.use_residual: |
| 106 | + return x + self._drop_connect(self.conv(x)) |
| 107 | + else: |
| 108 | + return self.conv(x) |
| 109 | + |
| 110 | +def _make_divisible(value, divisor=8): |
| 111 | + new_value = max(divisor, int(value + divisor / 2) // divisor * divisor) |
| 112 | + if new_value < 0.9 * value: |
| 113 | + new_value += divisor |
| 114 | + return new_value |
| 115 | + |
| 116 | +def _round_filters(filters, width_mult): |
| 117 | + if width_mult == 1.0: |
| 118 | + return filters |
| 119 | + return int(_make_divisible(filters * width_mult)) |
| 120 | + |
| 121 | +def _round_repeats(repeats, depth_mult): |
| 122 | + if depth_mult == 1.0: |
| 123 | + return repeats |
| 124 | + return int(math.ceil(depth_mult * repeats)) |
| 125 | + |
| 126 | +@mlconfig.register |
| 127 | +class EfficientNet(nn.Module): |
| 128 | + def __init__(self, width_mult=1.0, depth_mult=1.0, dropout_rate=0.2, num_classes=1000): |
| 129 | + super(EfficientNet, self).__init__() |
| 130 | + |
| 131 | + # yapf : disable |
| 132 | + settings = [ |
| 133 | + # t, c, n, s, k |
| 134 | + [1, 16, 1, 1, 3], # MBConv1_3x3, SE, 112 -> 112 |
| 135 | + [6, 24, 2, 2, 3], # MBConv6_3x3, SE, 112 -> 56 |
| 136 | + [6, 40, 2, 2, 5], # MBConv6_5x5, SE, 56 -> 28 |
| 137 | + [6, 80, 3, 2, 3], # MBConv6_3x3, SE, 28 -> 14 |
| 138 | + [6, 112, 3, 1, 5], # MBConv6_5x5, SE, 14 -> 14 |
| 139 | + [6, 192, 4, 2, 5], # MBConv6_5x5, SE, 14 -> 7 |
| 140 | + [6, 320, 1, 1, 3] # MBConv6_3x3, SE, 7 -> 7 |
| 141 | + ] |
| 142 | + # yapf : enable |
| 143 | + |
| 144 | + out_channels = _round_filters(32, width_mult) |
| 145 | + features = [ConvBNReLU(3, out_channels, 3, stride=2)] # gray = 1, RGB = 3 |
| 146 | + |
| 147 | + in_channels = out_channels |
| 148 | + for t, c, n, s, k in settings: |
| 149 | + out_channels = _round_filters(c, width_mult) |
| 150 | + repeats = _round_repeats(n, depth_mult) |
| 151 | + for i in range(repeats): |
| 152 | + stride = s if i == 0 else 1 |
| 153 | + features += [MBConvBlock(in_channels, out_channels, expand_ratio=t, stride=stride, kernel_size=k)] |
| 154 | + in_channels = out_channels |
| 155 | + |
| 156 | + last_channels = _round_filters(1280, width_mult) |
| 157 | + features += [ConvBNReLU(in_channels, last_channels, 1)] |
| 158 | + |
| 159 | + self.features = nn.Sequential(*features) |
| 160 | + self.classifier = nn.Sequential( |
| 161 | + nn.Dropout(dropout_rate), |
| 162 | + nn.Linear(last_channels, num_classes) |
| 163 | + ) |
| 164 | + |
| 165 | + # weight initialization |
| 166 | + for m in self.modules(): |
| 167 | + if isinstance(m, nn.Conv2d): |
| 168 | + nn.init.kaiming_normal_(m.weight, mode='fan_out') |
| 169 | + if m.bias is not None: |
| 170 | + nn.init.zeros_(m.bias) |
| 171 | + elif isinstance(m, nn.BatchNorm2d): |
| 172 | + nn.init.ones_(m.weight) |
| 173 | + nn.init.zeros_(m.bias) |
| 174 | + elif isinstance(m, nn.Linear): |
| 175 | + fan_out = m.weight.size(0) |
| 176 | + init_range = 1.0 / math.sqrt(fan_out) |
| 177 | + nn.init.uniform_(m.weight, -init_range, init_range) |
| 178 | + if m.bias is not None: |
| 179 | + nn.init.zeros_(m.bias) |
| 180 | + |
| 181 | + def forward(self, x): |
| 182 | + x = self.features(x) |
| 183 | + x = x.mean([2,3]) |
| 184 | + x = self.classifier(x) |
| 185 | + return x |
| 186 | + |
| 187 | +def _efficientnet(arch, pretrained, progress, **kwargs): |
| 188 | + width_mult, depth_mult, _, dropout_rate = params[arch] |
| 189 | + model = EfficientNet(width_mult, depth_mult, dropout_rate, **kwargs) |
| 190 | + if pretrained: |
| 191 | + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) |
| 192 | + if 'num_classes' in kwargs and kwargs['num_classes'] != 1000: |
| 193 | + del state_dict['features.0.1.weight'] |
| 194 | + del state_dict['classifier.1.weight'] |
| 195 | + del state_dict['classifier.1.bias'] |
| 196 | + model.load_state_dict(state_dict, strict=False) |
| 197 | + return model |
| 198 | + |
| 199 | +@mlconfig.register |
| 200 | +def efficientnet_b0(pretrained=False, progress=True, **kwargs): |
| 201 | + return _efficientnet('efficientnet_b0', pretrained, progress, **kwargs) |
| 202 | + |
| 203 | +@mlconfig.register |
| 204 | +def efficientnet_b1(pretrained=False, progress=True, **kwargs): |
| 205 | + return _efficientnet('efficientnet_b1', pretrained, progress, **kwargs) |
| 206 | + |
| 207 | + |
| 208 | +@mlconfig.register |
| 209 | +def efficientnet_b2(pretrained=False, progress=True, **kwargs): |
| 210 | + return _efficientnet('efficientnet_b2', pretrained, progress, **kwargs) |
| 211 | + |
| 212 | + |
| 213 | +@mlconfig.register |
| 214 | +def efficientnet_b3(pretrained=False, progress=True, **kwargs): |
| 215 | + return _efficientnet('efficientnet_b3', pretrained, progress, **kwargs) |
| 216 | + |
| 217 | + |
| 218 | +@mlconfig.register |
| 219 | +def efficientnet_b4(pretrained=False, progress=True, **kwargs): |
| 220 | + return _efficientnet('efficientnet_b4', pretrained, progress, **kwargs) |
| 221 | + |
| 222 | + |
| 223 | +@mlconfig.register |
| 224 | +def efficientnet_b5(pretrained=False, progress=True, **kwargs): |
| 225 | + return _efficientnet('efficientnet_b5', pretrained, progress, **kwargs) |
| 226 | + |
| 227 | + |
| 228 | +@mlconfig.register |
| 229 | +def efficientnet_b6(pretrained=False, progress=True, **kwargs): |
| 230 | + return _efficientnet('efficientnet_b6', pretrained, progress, **kwargs) |
| 231 | + |
| 232 | + |
| 233 | +@mlconfig.register |
| 234 | +def efficientnet_b7(pretrained=False, progress=True, **kwargs): |
| 235 | + return _efficientnet('efficientnet_b7', pretrained, progress, **kwargs) |
| 236 | + |
0 commit comments