
Commit 1e439af

model code
1 parent 5016de9 commit 1e439af

File tree

2 files changed: +295 -0 lines changed


efficientnet.py

+236
@@ -0,0 +1,236 @@
import math
import mlconfig
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

'''
### usage ###
import efficientnet
model = efficientnet.efficientnet_b0(pretrained=False, progress=True, num_classes=num_classes).to(device)
'''

model_urls = {
    'efficientnet_b0': 'https://www.dropbox.com/s/9wigibun8n260qm/efficientnet-b0-4cfa50.pth?dl=1',
    'efficientnet_b1': 'https://www.dropbox.com/s/6745ear79b1ltkh/efficientnet-b1-ef6aa7.pth?dl=1',
    'efficientnet_b2': 'https://www.dropbox.com/s/0dhtv1t5wkjg0iy/efficientnet-b2-7c98aa.pth?dl=1',
    'efficientnet_b3': 'https://www.dropbox.com/s/5uqok5gd33fom5p/efficientnet-b3-bdc7f4.pth?dl=1',
    'efficientnet_b4': 'https://www.dropbox.com/s/y2nqt750lixs8kc/efficientnet-b4-3e4967.pth?dl=1',
    'efficientnet_b5': 'https://www.dropbox.com/s/qxonlu3q02v9i47/efficientnet-b5-4c7978.pth?dl=1',
    'efficientnet_b6': None,
    'efficientnet_b7': None,
}

params = {
    'efficientnet_b0': (1.0, 1.0, 224, 0.2),
    'efficientnet_b1': (1.0, 1.1, 240, 0.2),
    'efficientnet_b2': (1.1, 1.2, 260, 0.3),
    'efficientnet_b3': (1.2, 1.4, 300, 0.3),
    'efficientnet_b4': (1.4, 1.8, 380, 0.4),
    'efficientnet_b5': (1.6, 2.2, 456, 0.4),
    'efficientnet_b6': (1.8, 2.6, 528, 0.5),
    'efficientnet_b7': (2.0, 3.1, 600, 0.5),
}
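# Each params entry is (width_mult, depth_mult, input resolution, dropout_rate).
# The resolution is recorded for reference only; the _efficientnet() factory
# below discards it when building a model.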

class Swish(nn.Module):

    def __init__(self, *args, **kwargs):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


# despite the name, the activation used here is Swish, not ReLU
class ConvBNReLU(nn.Sequential):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        padding = self._get_padding(kernel_size, stride)
        super(ConvBNReLU, self).__init__(
            nn.ZeroPad2d(padding),
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding=0, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes),
            Swish(),
        )

    def _get_padding(self, kernel_size, stride):
        # TensorFlow-style 'SAME' padding as (left, right, top, bottom),
        # putting the extra pixel on the right/bottom when p is odd
        p = max(kernel_size - stride, 0)
        return [p // 2, p - p // 2, p // 2, p - p // 2]


class SqueezeExcitation(nn.Module):

    def __init__(self, in_planes, reduced_dim):
        super(SqueezeExcitation, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, reduced_dim, 1),
            Swish(),
            nn.Conv2d(reduced_dim, in_planes, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)


class MBConvBlock(nn.Module):

    def __init__(self, in_planes, out_planes, expand_ratio, kernel_size, stride,
                 reduction_ratio=4, drop_connect_rate=0.2):
        super(MBConvBlock, self).__init__()
        self.drop_connect_rate = drop_connect_rate
        self.use_residual = in_planes == out_planes and stride == 1
        assert stride in [1, 2]
        assert kernel_size in [3, 5]

        hidden_dim = in_planes * expand_ratio
        reduced_dim = max(1, int(in_planes / reduction_ratio))

        layers = []
        # pw
        if in_planes != hidden_dim:
            layers += [ConvBNReLU(in_planes, hidden_dim, 1)]

        layers += [
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, kernel_size, stride=stride, groups=hidden_dim),
            # se
            SqueezeExcitation(hidden_dim, reduced_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, out_planes, 1, bias=False),
            nn.BatchNorm2d(out_planes),
        ]

        self.conv = nn.Sequential(*layers)

    def _drop_connect(self, x):
        if not self.training:
            return x
        keep_prob = 1.0 - self.drop_connect_rate
        batch_size = x.size(0)
        random_tensor = keep_prob
        random_tensor += torch.rand(batch_size, 1, 1, 1, device=x.device)
        binary_tensor = random_tensor.floor()
        return x.div(keep_prob) * binary_tensor

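    # _drop_connect implements stochastic depth: during training each sample's
    # residual branch survives with probability keep_prob (random_tensor floors
    # to 1 with exactly that probability); dividing the surviving activations
    # by keep_prob keeps their expected value constant, so no correction is
    # needed at inference time.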
    def forward(self, x):
        if self.use_residual:
            return x + self._drop_connect(self.conv(x))
        else:
            return self.conv(x)


def _make_divisible(value, divisor=8):
    # round to the nearest multiple of divisor, never going below
    # 90% of the original value
    new_value = max(divisor, int(value + divisor / 2) // divisor * divisor)
    if new_value < 0.9 * value:
        new_value += divisor
    return new_value


def _round_filters(filters, width_mult):
    if width_mult == 1.0:
        return filters
    return int(_make_divisible(filters * width_mult))


def _round_repeats(repeats, depth_mult):
    if depth_mult == 1.0:
        return repeats
    return int(math.ceil(depth_mult * repeats))

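# Worked example: for B3, width_mult = 1.2, so the network's final feature
# width is _round_filters(1280, 1.2) = _make_divisible(1536.0) = 1536, which
# is the in_features that the Linear heads in model.py rely on.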
@mlconfig.register
class EfficientNet(nn.Module):

    def __init__(self, width_mult=1.0, depth_mult=1.0, dropout_rate=0.2, num_classes=1000):
        super(EfficientNet, self).__init__()

        # yapf: disable
        settings = [
            # t,  c,  n, s, k
            [1,  16, 1, 1, 3],   # MBConv1_3x3, SE, 112 -> 112
            [6,  24, 2, 2, 3],   # MBConv6_3x3, SE, 112 ->  56
            [6,  40, 2, 2, 5],   # MBConv6_5x5, SE,  56 ->  28
            [6,  80, 3, 2, 3],   # MBConv6_3x3, SE,  28 ->  14
            [6, 112, 3, 1, 5],   # MBConv6_5x5, SE,  14 ->  14
            [6, 192, 4, 2, 5],   # MBConv6_5x5, SE,  14 ->   7
            [6, 320, 1, 1, 3],   # MBConv6_3x3, SE,   7 ->   7
        ]
        # yapf: enable

        out_channels = _round_filters(32, width_mult)
        features = [ConvBNReLU(3, out_channels, 3, stride=2)]  # input channels: 1 for grayscale, 3 for RGB

        in_channels = out_channels
        for t, c, n, s, k in settings:
            out_channels = _round_filters(c, width_mult)
            repeats = _round_repeats(n, depth_mult)
            for i in range(repeats):
                stride = s if i == 0 else 1
                features += [MBConvBlock(in_channels, out_channels, expand_ratio=t, stride=stride, kernel_size=k)]
                in_channels = out_channels

        last_channels = _round_filters(1280, width_mult)
        features += [ConvBNReLU(in_channels, last_channels, 1)]

        self.features = nn.Sequential(*features)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(last_channels, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                fan_out = m.weight.size(0)
                init_range = 1.0 / math.sqrt(fan_out)
                nn.init.uniform_(m.weight, -init_range, init_range)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])  # global average pooling
        x = self.classifier(x)
        return x


def _efficientnet(arch, pretrained, progress, **kwargs):
    width_mult, depth_mult, _, dropout_rate = params[arch]
    model = EfficientNet(width_mult, depth_mult, dropout_rate, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        if 'num_classes' in kwargs and kwargs['num_classes'] != 1000:
            # drop the stem conv and classifier weights whose shapes can differ;
            # strict=False below leaves those layers randomly initialized
            del state_dict['features.0.1.weight']
            del state_dict['classifier.1.weight']
            del state_dict['classifier.1.bias']
        model.load_state_dict(state_dict, strict=False)
    return model


@mlconfig.register
def efficientnet_b0(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b0', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b1(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b1', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b2(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b2', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b3(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b3', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b4(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b4', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b5(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b5', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b6(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b6', pretrained, progress, **kwargs)


@mlconfig.register
def efficientnet_b7(pretrained=False, progress=True, **kwargs):
    return _efficientnet('efficientnet_b7', pretrained, progress, **kwargs)
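A quick smoke test for the factory functions above (a minimal sketch, not part of this commit; it assumes the mlconfig package is installed, and num_classes=18 is an arbitrary example value):

import torch
from efficientnet import efficientnet_b0

model = efficientnet_b0(pretrained=False, progress=True, num_classes=18).eval()
x = torch.randn(2, 3, 224, 224)  # 224 is B0's native resolution (see params)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 18])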

model.py

+59
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from efficientnet import efficientnet_b3


class BaseModel(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.25)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout2(x)

        x = self.avgpool(x)
        x = x.view(-1, 128)  # flatten the pooled (N, 128, 1, 1) feature map
        return self.fc(x)


class EnsembleModel(nn.Module):

    def __init__(self, num_classes):
        super(EnsembleModel, self).__init__()
        # EfficientNet-B3 backbone; its final feature width is 1536
        self.feature = efficientnet_b3(pretrained=True, progress=True, num_classes=num_classes).features
        self.classifier1 = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1536, 3))  # mask classifier
        self.classifier2 = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1536, 3))  # gender classifier
        self.classifier3 = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1536, 3))  # age classifier

    def forward(self, x):
        x = self.feature(x)
        x = x.mean([2, 3])  # global average pooling -> (N, 1536)
        x1 = self.classifier1(x)
        x2 = self.classifier2(x)
        x3 = self.classifier3(x)
        return (x1, x2, x3)
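A minimal usage sketch for EnsembleModel (not part of this commit; constructing it triggers a download of the pretrained B3 weights from the model_urls entry above, and the global average pooling makes the input size flexible):

import torch
from model import EnsembleModel

model = EnsembleModel(num_classes=3).eval()
x = torch.randn(2, 3, 300, 300)  # 300 is B3's native resolution
with torch.no_grad():
    mask_logits, gender_logits, age_logits = model(x)
# each head returns a (2, 3) tensor of logits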
