forked from ai-dawang/PlugNPlay-Modules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
(ICCV2023)SAFM.py
47 lines (38 loc) · 1.35 KB
/
(ICCV2023)SAFM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
import torch.nn.functional as F
# Code: https://github.com/sunny2109/SAFMN
# Paper: https://arxiv.org/pdf/2302.13800
class SAFM(nn.Module):
    """Spatially-Adaptive Feature Modulation (SAFM) block.

    The input channels are split into ``n_levels`` equal groups:

    * Group 0 is filtered at full resolution by a depthwise 3x3 convolution.
    * Group ``i > 0`` takes three steps:
      1. It is adaptively max-pooled down to about ``1 / 2**i`` of the
         spatial size (at least 1x1).
      2. It is filtered by its own depthwise 3x3 convolution.
      3. It is upsampled back to the input size with nearest-neighbour
         interpolation.

    The groups are then concatenated and mixed by a 1x1 convolution. The
    result is passed through GELU and multiplied element-wise with the
    original input.

    Reference: https://github.com/sunny2109/SAFMN
    (paper: https://arxiv.org/pdf/2302.13800).

    Args:
        dim: Number of input (and output) channels. Must be a positive
            multiple of ``n_levels``.
        n_levels: Number of channel groups / pyramid levels. Must be >= 1.

    Raises:
        ValueError: If ``n_levels < 1`` or ``dim`` is not a positive multiple
            of ``n_levels``. (Such configurations previously failed later, in
            ``forward``, with an opaque shape error.)
    """

    def __init__(self, dim: int, n_levels: int = 4) -> None:
        super().__init__()
        if n_levels < 1:
            raise ValueError(f"n_levels must be >= 1, got {n_levels}")
        if dim < n_levels or dim % n_levels != 0:
            # x.chunk() in forward would otherwise produce groups whose size
            # does not match chunk_dim (or fewer than n_levels groups).
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of n_levels ({n_levels})"
            )
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial weighting: one depthwise 3x3 conv per channel group.
        self.mfr = nn.ModuleList(
            [
                nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim)
                for _ in range(self.n_levels)
            ]
        )
        # Feature aggregation: 1x1 conv mixing all groups back together.
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
        # Activation applied before modulating the input.
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Modulate ``x`` with multi-scale spatial features.

        Args:
            x: Tensor with ``dim`` channels along dim 1 and two trailing
                spatial dims. Presumably (batch, dim, H, W), as in the demo
                below; confirm for other callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h, w = x.shape[-2:]
        chunks = x.chunk(self.n_levels, dim=1)
        out = []
        for i, (chunk, conv) in enumerate(zip(chunks, self.mfr)):
            if i == 0:
                # Level 0 works at full resolution.
                s = conv(chunk)
            else:
                # Clamp to 1 so that inputs smaller than 2**i along a side do
                # not request a 0-sized pooling output (which crashes).
                p_size = (max(1, h // 2 ** i), max(1, w // 2 ** i))
                s = F.adaptive_max_pool2d(chunk, p_size)
                s = conv(s)
                s = F.interpolate(s, size=(h, w), mode="nearest")
            out.append(s)
        out = self.aggr(torch.cat(out, dim=1))
        return self.act(out) * x
if __name__ == '__main__':
    # Smoke test. Input layout: (batch, channels, height, width).
    x = torch.randn(3, 36, 64, 64)
    block = SAFM(dim=36)
    output = block(x)
    print(output.size())