Modify ASMFNet
sstary committed Nov 17, 2024
1 parent 68fe43a commit 89a499a
Showing 1 changed file with 42 additions and 74 deletions.
@@ -325,16 +325,10 @@ def forward(self, x):
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C, concatenated along the channel dimension
# print("111")
# print(x0.shape)

x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C

x = self.norm(x)
# print("x1==")
# print(x.shape)
x = self.reduction(x)
# print("x==")
# print(x.shape)
return x

def extra_repr(self) -> str:
@@ -643,7 +637,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
# build encoder and bottleneck layers
self.layers = nn.ModuleList()
self.layersd = nn.ModuleList()
self.layerFu = nn.ModuleList(([ModifyFusion2(3136,96,0),ModifyFusion2(784,192,0), ModifyFusion2(196,384,0),ModifyFusion2(49,768,1)]))
self.layerFu = nn.ModuleList(([AMF(3136,96,0),AMF(784,192,0), AMF(196,384,0),AMF(49,768,1)]))
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
@@ -757,16 +751,11 @@ def forward_features(self, x, y):

x = self.pos_drop(x)
y = self.pos_drop(y)
# skip connection: add the original data back into the decoder
fusion_downsample = []
res = x
f_res=x
count=0
for layer, layerd,layerf in zip(self.layers, self.layersd,self.layerFu):
# print(layer)
# print(count)
# fuse the features for the skip connection
# feed them into SwinUNet for iteration
tempx=x
tempy=y
x = layer(x)
@@ -851,23 +840,15 @@ def __init__(self, channel, k_size=3):

def forward(self, x):
b, c, h,w = x.size()
# feature descriptor on the global spatial information
y = self.avg_pool(x)

# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

# Multi-scale information fusion
y = self.sigmoid(y)
y = x * y.expand_as(x)

#out = y + x
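# reshape B C H W -> B HW C so the output matches the transformer token layout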
return y.view(b,c,h*w).permute(0,2,1).contiguous()

class SpFusion(nn.Module):
class AdaptiveFusion(nn.Module):
def __init__(self, third):
# apply attention directly, then add it back
super(SpFusion, self).__init__()
super(AdaptiveFusion, self).__init__()


self.avg_pool = nn.AdaptiveAvgPool2d(1)
@@ -886,49 +867,39 @@ def __init__(self, third):
)


# output is B HW C
def forward(self, rgb, depth):
# input is B C H W
b,c,h,w=rgb.size()

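# channel gate (squeeze-and-excitation style): global-pool the concatenated pair, then the FC maps 2C -> C weights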
fusion = torch.cat([rgb,depth],dim = 1)
fusion = self.avg_pool(fusion)

fusion = fusion.view(b,2*c)

fusion = self.fc(fusion).view(b,c,1,1)

fusion = depth*fusion + rgb
fusion = self.conv1(fusion)

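# pair each modality with the fused map to build attention maps, normalized jointly with softmax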
f1 = torch.cat([rgb,fusion],dim = 1)
f2 = torch.cat([depth,fusion],dim = 1)

f1 = self.af1(f1)
f2 = self.af2(f2)


f = F.softmax(torch.cat([f1,f2],dim = 1), dim =1)

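# reweight both modalities with the shared attention map and sum them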
res1 = rgb*f
res2 = depth*f

res= res1+res2

return res

class ModifyFusion2(nn.Module):
class AMF(nn.Module):
def __init__(self, second, third,count1):
# apply attention directly, then add it back
super(ModifyFusion2, self).__init__()
super(AMF, self).__init__()

self.atten_rgb_1 = eca_layer(third)
self.atten_depth_1 = eca_layer(third)
self.fusion_conv1 = SpFusion(third) if count1 else None
self.AdaptiveFusion = AdaptiveFusion(third) if count1 else None
self.lin1 = nn.Linear(third, third)
self.lin2 = nn.Linear(third, third)
if count1 ==0:
self.cem1 = CEM(third, third<<1, int(math.sqrt(second)))
self.cem2 = CEM(third, third<<1, int(math.sqrt(second)))
self.hsa1 = HSA(third, third<<1, int(math.sqrt(second)))
self.hsa2 = HSA(third, third<<1, int(math.sqrt(second)))


# output is B HW C
@@ -942,59 +913,58 @@ def forward(self, rgb,rgb2, depth,depth2, count):
int(math.sqrt(depth2.size(1))))

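# count selects the fusion stage: stages 0-2 combine ECA channel attention with cross-scale HSA features; the final stage uses channel attention only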
if count == 0:

cem_rgb = self.cem1(rgb1,rgb2)
cem_depth = self.cem2(depth1,depth2)
hsa_rgb = self.hsa1(rgb1,rgb2)
hsa_depth = self.hsa2(depth1,depth2)
atten_rgb = self.atten_rgb_1(rgb1)
atten_depth = self.atten_depth_1(depth1)
if self.fusion_conv1 is not None:
cem_fused = self.fusion_conv1(cem_rgb,cem_depth)
if self.AdaptiveFusion is not None:
hsa_fused = self.AdaptiveFusion(hsa_rgb,hsa_depth)
else:
cem_fused = cem_rgb+cem_depth
cem_fused = cem_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()
hsa_fused = hsa_rgb+hsa_depth
hsa_fused = hsa_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()

m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(hsa_fused)

m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(cem_fused)
elif count == 1:

cem_rgb = self.cem1(rgb1,rgb2)
cem_depth = self.cem2(depth1,depth2)
hsa_rgb = self.hsa1(rgb1,rgb2)
hsa_depth = self.hsa2(depth1,depth2)
atten_rgb = self.atten_rgb_1(rgb1)
atten_depth = self.atten_depth_1(depth1)
if self.fusion_conv1 is not None:
cem_fused = self.fusion_conv1(cem_rgb,cem_depth)
if self.AdaptiveFusion is not None:
hsa_fused = self.AdaptiveFusion(hsa_rgb,hsa_depth)
else:
cem_fused = cem_rgb+cem_depth
cem_fused = cem_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()
hsa_fused = hsa_rgb+hsa_depth
hsa_fused = hsa_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()

m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(cem_fused)
m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(hsa_fused)

elif count == 2:

cem_rgb = self.cem1(rgb1,rgb2)
cem_depth = self.cem2(depth1,depth2)
hsa_rgb = self.hsa1(rgb1,rgb2)
hsa_depth = self.hsa2(depth1,depth2)
atten_rgb = self.atten_rgb_1(rgb1)
atten_depth = self.atten_depth_1(depth1)
if self.fusion_conv1 is not None:
cem_fused = self.fusion_conv1(cem_rgb,cem_depth)
if self.AdaptiveFusion is not None:
hsa_fused = self.AdaptiveFusion(hsa_rgb,hsa_depth)
else:
cem_fused = cem_rgb+cem_depth
cem_fused = cem_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()
hsa_fused = hsa_rgb+hsa_depth
hsa_fused = hsa_fused.view(rgb.size(0), rgb.size(2), -1).permute(0,2,1).contiguous()

m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(hsa_fused)

m1 = self.lin1(atten_rgb + atten_depth) + self.lin2(cem_fused)
else:
self.cem1 =nn.Sequential()
self.cem2 =nn.Sequential()
self.hsa1 =nn.Sequential()
self.hsa2 =nn.Sequential()
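# final stage: the unused cross-scale modules are replaced with empty no-ops and only channel attention is kept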
atten_rgb = self.atten_rgb_1(rgb1)
atten_depth = self.atten_depth_1(depth1)

m1 =atten_rgb + atten_depth

return m1

class CEM(nn.Module):
class HSA(nn.Module):

def __init__(self,feature1,feature2,size1):
super(CEM, self).__init__()
super(HSA, self).__init__()
self.size = size1
self.conv1 = nn.Conv2d(feature1, feature1, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(feature1*2, feature1, kernel_size=3, stride=1, padding=1)
@@ -1011,12 +981,12 @@ def __init__(self,feature1,feature2,size1):
self.countr =1
self.countd =1

def oneFusion(self, rgb1, rgb2):
c4 = rgb1
def CrossScaleFusion(self, mol1, mol2):
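# fuse two adjacent scales: mol1 keeps its resolution, mol2 is upsampled 2x to match before merging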
c4 = mol1
c4_lat = self.conv1(c4)

# upsample x2
c5 =rgb2
c5 =mol2
c5_lat = self.upSampling2x2(c5)
c5_lat = self.relu1(c5_lat)

@@ -1030,11 +1000,9 @@ def oneFusion(self, rgb1, rgb2):
return out


def forward(self, rgb1,rgb2):
b,c,h,w=rgb1.size()
rgb=self.oneFusion(rgb1,rgb2)
def forward(self, mol1,mol2):


rgb=self.CrossScaleFusion(mol1,mol2)
return rgb

