-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathdecoder.py
72 lines (55 loc) · 1.92 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules
class PSPBlock(nn.Module):
    """One pyramid-pooling branch: pool to a fixed grid, project, upsample back.

    The input is adaptively average-pooled to ``pool_size x pool_size``,
    passed through a 1x1 ``modules.Conv2dReLU`` that maps ``in_channels`` to
    ``out_channels``, and bilinearly resized to the input's spatial size.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the 1x1 projection.
        pool_size: Side length of the square grid the input is pooled to.
        use_bathcnorm: Forwarded to ``Conv2dReLU`` as ``use_batchnorm``
            (the misspelled parameter name is kept for compatibility).
            Ignored and replaced by ``False`` when ``pool_size == 1``.
    """

    def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True):
        super().__init__()

        # PyTorch does not support BatchNorm for 1x1 shape, so the 1x1 branch
        # always runs without it; any other value is passed through untouched.
        batchnorm_flag = False if pool_size == 1 else use_bathcnorm

        grid_pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))
        projection = modules.Conv2dReLU(
            in_channels,
            out_channels,
            (1, 1),
            use_batchnorm=batchnorm_flag,
        )
        # Order matters: the sub-modules are registered as "0" and "1".
        self.pool = nn.Sequential(grid_pool, projection)

    def forward(self, x):
        """Return the pooled and projected features at ``x``'s spatial size."""
        target_size = (x.size(2), x.size(3))
        pooled = self.pool(x)
        return F.interpolate(
            pooled,
            size=target_size,
            mode='bilinear',
            align_corners=True,
        )
class PSPModule(nn.Module):
    """Pyramid pooling: run several ``PSPBlock`` branches and concatenate.

    One ``PSPBlock`` is created per entry of ``sizes``; each branch outputs
    ``in_channels // len(sizes)`` channels. The forward pass concatenates all
    branch outputs followed by the unmodified input along the channel axis.

    Args:
        in_channels: Number of channels of the incoming feature map.
        sizes: Pooling grid sizes, one per branch.
        use_bathcnorm: Forwarded to every ``PSPBlock`` (the misspelled
            parameter name is kept for compatibility).
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True):
        super().__init__()

        self.blocks = nn.ModuleList()
        for pool_size in sizes:
            branch = PSPBlock(
                in_channels,
                in_channels // len(sizes),
                pool_size,
                use_bathcnorm=use_bathcnorm,
            )
            self.blocks.append(branch)

    def forward(self, x):
        """Return branch outputs and ``x`` concatenated on dimension 1."""
        stacked = [branch(x) for branch in self.blocks]
        stacked.append(x)  # the original features go last
        return torch.cat(stacked, dim=1)
class PSPDecoder(nn.Module):
    """PSP decoder head: pyramid pooling, 1x1 fusion conv, then 2D dropout.

    Only the last entry of ``encoder_channels`` (and the last feature map
    passed to ``forward``) is used.

    Args:
        encoder_channels: Channel counts of the encoder outputs; the last
            one is the width of the feature map fed to the decoder.
        use_batchnorm: Forwarded to the pyramid pooling module and to the
            fusion ``modules.Conv2dReLU``.
        out_channels: Number of channels produced by the fusion conv.
        dropout: Probability passed to ``nn.Dropout2d``.
    """

    def __init__(
            self,
            encoder_channels,
            use_batchnorm=True,
            out_channels=512,
            dropout=0.2,
    ):
        super().__init__()

        sizes = (1, 2, 3, 6)
        in_channels = encoder_channels[-1]

        self.psp = PSPModule(
            in_channels=in_channels,
            sizes=sizes,
            use_bathcnorm=use_batchnorm,
        )

        # PSPModule concatenates its input with one branch per pooling size,
        # each branch having ``in_channels // len(sizes)`` channels. That is
        # ``2 * in_channels`` only when in_channels is divisible by
        # len(sizes); compute the real width so other channel counts do not
        # cause a channel mismatch in the fusion conv.
        psp_channels = in_channels + len(sizes) * (in_channels // len(sizes))

        self.conv = modules.Conv2dReLU(
            in_channels=psp_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_batchnorm=use_batchnorm,
        )

        self.dropout = nn.Dropout2d(p=dropout)

    def forward(self, *features):
        """Decode the last of ``features`` and return the fused feature map."""
        x = features[-1]
        x = self.psp(x)
        x = self.conv(x)
        x = self.dropout(x)
        return x