-
Notifications
You must be signed in to change notification settings - Fork 70
/
FLAVR_arch.py
178 lines (119 loc) · 5.56 KB
/
FLAVR_arch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import math
import numpy as np
import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet_3D import SEGating
def joinTensors(X1, X2, type="concat"):
    """Merge two tensors according to the requested join mode.

    Args:
        X1: Primary tensor. Returned unchanged when ``type`` is not one of
            the recognised modes.
        X2: Tensor merged into ``X1``.
        type: ``"concat"`` concatenates ``X1`` and ``X2`` along dim 1,
            ``"add"`` sums them element-wise, any other value skips the
            merge entirely.

    Returns:
        The merged tensor, or ``X1`` itself when no merge is performed.
    """
    # `type` shadows the builtin, but callers pass it by keyword, so the
    # parameter name is kept and only aliased locally.
    mode = type
    if mode == "add":
        return X1 + X2
    if mode == "concat":
        return torch.cat((X1, X2), dim=1)
    # Unrecognised mode: pass the primary tensor through untouched.
    return X1
class Conv_2d(nn.Module):
    """A single ``nn.Conv2d`` optionally followed by ``nn.BatchNorm2d``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        bias: Whether the convolution has a bias term.
        batchnorm: Append a ``BatchNorm2d(out_ch)`` after the convolution.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=False, batchnorm=False):
        super().__init__()

        layers = [
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            ),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm2d(out_ch))

        # Stored under `conv` so parameter names stay `conv.<idx>.*`.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution (and batch norm, if configured) to ``x``."""
        return self.conv(x)
class upConv3D(nn.Module):
    """3-D up-convolution stage followed by ``SEGating``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Kernel size of the transposed convolution
            (only used when ``upmode == "transpose"``).
        stride: Stride of the transposed convolution
            (only used when ``upmode == "transpose"``).
        padding: Padding of the transposed convolution
            (only used when ``upmode == "transpose"``).
        upmode: ``"transpose"`` uses ``nn.ConvTranspose3d``; any other value
            uses trilinear upsampling with scale ``(1, 2, 2)`` followed by a
            1x1x1 ``nn.Conv3d``.
        batchnorm: Append a ``BatchNorm3d(out_ch)`` as the final layer.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
                SEGating(out_ch),
            ]
        else:
            # Fixed-factor upsampling: the first size axis is left as is and
            # the last two are doubled; kernel_size/stride/padding are unused.
            stages = [
                nn.Upsample(mode='trilinear', scale_factor=(1, 2, 2), align_corners=False),
                nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1),
                SEGating(out_ch),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm3d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the up-convolution pipeline."""
        return self.upconv(x)
class Conv_3d(nn.Module):
    """``nn.Conv3d`` followed by ``SEGating`` and optional batch norm.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        bias: Whether the convolution has a bias term.
        batchnorm: Append a ``BatchNorm3d(out_ch)`` as the final layer.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True, batchnorm=False):
        super().__init__()

        layers = [
            nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            SEGating(out_ch),
        ]
        if batchnorm:
            layers.append(nn.BatchNorm3d(out_ch))

        # Stored under `conv` so parameter names stay `conv.<idx>.*`.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, gating and (if configured) batch norm to ``x``."""
        return self.conv(x)
class upConv2D(nn.Module):
    """2-D up-convolution stage with optional batch normalisation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Kernel size of the transposed convolution
            (only used when ``upmode == "transpose"``).
        stride: Stride of the transposed convolution
            (only used when ``upmode == "transpose"``).
        padding: Padding of the transposed convolution
            (only used when ``upmode == "transpose"``).
        upmode: ``"transpose"`` uses ``nn.ConvTranspose2d``; any other value
            uses 2x bilinear upsampling followed by a 1x1 ``nn.Conv2d``.
        batchnorm: Append a ``BatchNorm2d(out_ch)`` as the final layer.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose", batchnorm=False):
        super().__init__()

        self.upmode = upmode

        if upmode == "transpose":
            stages = [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            ]
        else:
            # Fixed 2x upsampling; kernel_size/stride/padding are unused here.
            stages = [
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1),
            ]

        if batchnorm:
            stages.append(nn.BatchNorm2d(out_ch))

        self.upconv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the up-convolution pipeline."""
        return self.upconv(x)
class UNet_3D_3D(nn.Module):
    """U-Net with a 3-D encoder/decoder that predicts ``n_outputs`` frames.

    The encoder is looked up by name in the sibling ``resnet_3D`` module; the
    decoder mirrors it with ``Conv_3d`` / ``upConv3D`` stages and skip
    connections merged through ``joinTensors``.

    Args:
        block: Name of the encoder factory in ``resnet_3D``; it is called as
            ``factory(pretrained=False, bn=batchnorm)``.
        n_inputs: Number of input frames given to ``forward``.
        n_outputs: Number of frames to predict; the output convolution emits
            ``3 * n_outputs`` channels.
        batchnorm: Enable batch normalisation in the encoder and decoder.
        joinType: Skip-connection merge mode passed to ``joinTensors``
            (``"concat"`` doubles the decoder input widths).
        upmode: Upsampling mode passed to ``upConv3D``.
    """

    def __init__(self, block , n_inputs, n_outputs, batchnorm=False , joinType="concat" , upmode="transpose"):
        super().__init__()

        nf = [512 , 256 , 128 , 64]
        out_channels = 3*n_outputs
        self.joinType = joinType
        self.n_outputs = n_outputs

        # Concatenated skip connections double the channel count fed to the
        # next decoder stage; other join modes leave it unchanged.
        growth = 2 if joinType == "concat" else 1
        self.lrelu = nn.LeakyReLU(0.2, True)

        # Resolve the encoder module relative to this file's own package
        # (same anchor as the file-level `from .resnet_3D import SEGating`)
        # instead of a hard-coded package name, so the flag set below lands on
        # the same module object regardless of what the package is called.
        unet_3D = importlib.import_module(".resnet_3D" , __package__)
        if n_outputs > 1:
            # NOTE: this mutates module-level state in resnet_3D and is never
            # reset, so it also affects encoders built afterwards.
            unet_3D.useBias = True
        self.encoder = getattr(unet_3D , block)(pretrained=False , bn=batchnorm)

        self.decoder = nn.Sequential(
            Conv_3d(nf[0], nf[1] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[1]*growth, nf[2], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            upConv3D(nf[2]*growth, nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm),
            Conv_3d(nf[3]*growth, nf[3] , kernel_size=3, padding=1, bias=True, batchnorm=batchnorm),
            upConv3D(nf[3]*growth , nf[3], kernel_size=(3,4,4), stride=(1,2,2), padding=(1,1,1) , upmode=upmode, batchnorm=batchnorm)
        )

        # After the decoder, dim 2 is unbound and concatenated onto dim 1,
        # hence the `nf[3]*n_inputs` input width of the fusion layer.
        self.feature_fuse = Conv_2d(nf[3]*n_inputs , nf[3] , kernel_size=1 , stride=1, batchnorm=batchnorm)

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(nf[3], out_channels , kernel_size=7 , stride=1, padding=0)
        )

    def forward(self, images):
        """Predict output frames from a sequence of input frames.

        Args:
            images: Sequence of tensors of identical shape; they are stacked
                along a new dim 2. Presumably each is (N, 3, H, W) with
                ``len(images) == n_inputs`` -- TODO confirm against callers.

        Returns:
            List of tensors, one per 3-channel slice of the output
            convolution, each with the input mean added back.
        """
        images = torch.stack(images , dim=2)

        ## Batch mean normalization works slightly better than global mean normalization, thanks to https://github.com/myungsub/CAIN
        mean_ = images.mean(2, keepdim=True).mean(3, keepdim=True).mean(4,keepdim=True)
        images = images-mean_

        # The encoder is expected to return five feature maps, x_4 being the
        # one fed to the first decoder stage.
        x_0 , x_1 , x_2 , x_3 , x_4 = self.encoder(images)

        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = joinTensors(dx_3 , x_3 , type=self.joinType)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = joinTensors(dx_2 , x_2 , type=self.joinType)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = joinTensors(dx_1 , x_1 , type=self.joinType)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = joinTensors(dx_0 , x_0 , type=self.joinType)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold dim 2 into dim 1 so the remaining layers can be 2-D.
        dx_out = torch.cat(torch.unbind(dx_out , 2) , 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        # One 3-channel chunk per predicted frame.
        out = torch.split(out, dim=1, split_size_or_sections=3)
        # Drop the singleton dim 2 so the mean broadcasts over 4-D outputs.
        mean_ = mean_.squeeze(2)
        out = [o+mean_ for o in out]

        return out