forked from z1069614715/objectdetection_script
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path MLCA.py
56 lines (43 loc) · 2.22 KB
/
MLCA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import math, torch
from torch import nn
import torch.nn.functional as F
class MLCA(nn.Module):
    """Mixed Local Channel Attention (MLCA).

    Two channel-attention branches are computed from a
    ``local_size x local_size`` adaptively average-pooled grid of the input:

    * local branch: every grid cell's channel vector is laid out one after
      another in a single 1-D sequence and passed through ``conv_local``;
    * global branch: the grid is averaged down to 1x1 and the channel vector
      is passed through ``conv``.

    Both results go through a sigmoid, are blended as
    ``local_weight * local + (1 - local_weight) * global`` on the local grid,
    resized to the input's spatial size with adaptive average pooling, and
    multiplied into the input.

    Args:
        in_size: Channel count used only to derive the 1-D convolution kernel
            size via the ECA rule; the convolutions themselves are
            1-in/1-out channel, so the module runs on any channel count.
        local_size: Side length (int) of the pooled local grid.
        gamma: Divisor in the ECA kernel-size rule.
        b: Offset in the ECA kernel-size rule.
        local_weight: Blend weight of the local branch; the global branch
            receives ``1 - local_weight``.
    """

    def __init__(self, in_size, local_size=5, gamma=2, b=1, local_weight=0.5):
        super().__init__()
        self.local_size = local_size
        self.gamma = gamma
        self.b = b
        # ECA kernel size: int(|log2(in_size) + b| / gamma), bumped to the next
        # odd number so that padding (k - 1) // 2 preserves sequence length.
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)
        k = t if t % 2 else t + 1
        # Creation order kept as in the original so parameter init consumes
        # the RNG identically.
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.local_weight = local_weight
        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        self.global_arv_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Scale ``x`` by the blended local/global channel-attention map.

        Args:
            x: 4-D tensor, unpacked below as (batch, channels, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        local_avg = self.local_arv_pool(x)            # (B, C, S, S)
        global_avg = self.global_arv_pool(local_avg)  # (B, C, 1, 1)
        batch, channels, height, width = x.shape
        cells = self.local_size * self.local_size

        # (B, C, S, S) -> (B, C, S*S) -> (B, S*S, C) -> (B, 1, S*S*C):
        # the C channels of each grid cell form one contiguous run.
        seq_local = local_avg.reshape(batch, channels, -1).transpose(1, 2).reshape(batch, 1, -1)
        # (B, C, 1, 1) -> (B, C, 1) -> (B, 1, C)
        seq_global = global_avg.reshape(batch, channels, -1).transpose(1, 2)

        y_local = self.conv_local(seq_local)
        y_global = self.conv(seq_global)

        # Undo the local flattening:
        # (B, 1, S*S*C) -> (B, S*S, C) -> (B, C, S*S) -> (B, C, S, S).
        # reshape (not view) because the tensor is non-contiguous after transpose.
        y_local = (
            y_local.reshape(batch, cells, channels)
            .transpose(1, 2)
            .reshape(batch, channels, self.local_size, self.local_size)
        )
        # (B, 1, C) -> (B, C, 1) -> (B, C, 1, 1)
        y_global = y_global.transpose(1, 2).unsqueeze(-1)

        # "Un-pooling": adaptive average pooling from 1x1 up to the local grid
        # (a broadcast), blend, then resize the blended grid to (height, width).
        att_local = y_local.sigmoid()
        att_global = F.adaptive_avg_pool2d(
            y_global.sigmoid(), [self.local_size, self.local_size]
        )
        att_all = F.adaptive_avg_pool2d(
            att_global * (1 - self.local_weight) + (att_local * self.local_weight),
            [height, width],
        )
        return x * att_all
if __name__ == '__main__':
    # Smoke test: build the module, push a random (2, 256, 16, 16) batch
    # through it and print the output shape. The module is built before the
    # input is sampled, exactly as before, so RNG consumption is unchanged.
    module = MLCA(in_size=256)
    sample = torch.randn((2, 256, 16, 16))
    print(module(sample).size())