-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathRefinementNet_core.py
71 lines (58 loc) · 3.2 KB
/
RefinementNet_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch import nn
from DFNet_core import get_norm, get_activation, Conv2dSame, ConvTranspose2dSame, UpBlock, EncodeBlock, DecodeBlock
from builtins import *
class RefinementNet(nn.Module):
    """Encoder/decoder refinement network that fills the masked part of an image.

    ``mask`` and ``img`` are concatenated along ``dim=1`` (the channel
    dimension, assuming NCHW inputs). The result is passed through eight
    ``EncodeBlock`` stages and then eight ``DecodeBlock`` stages. Decoder stage
    ``k`` receives the output of encoder stage ``8 - k`` as a skip connection;
    the last decoder stage receives the raw concatenated input. A final
    ``Conv2dSame`` + ``nn.Sigmoid`` produces a 3-channel prediction. That
    prediction is blended with the first three channels of ``img`` using
    ``mask``.

    Args:
        c_img: Number of channels in ``img``.
        c_mask: Number of channels in ``mask``.
        mode: Upsampling mode forwarded to every ``DecodeBlock``.
        norm: Normalization name used by encoder stages 2-8 and all decoder
            stages.
        act_en: Activation name used by encoder stages 2-8.
        act_de: Activation name used by all decoder stages.
        en_ksize: Kernel sizes, one per encoder stage (the first eight are used).
        de_ksize: Kernel sizes, one per decoder stage (the first eight are used).

    NOTE(review): every encoder stage is built with stride 2, so the input
    height/width presumably must be divisible by 2**8 for the skip tensors to
    line up with the upsampled ones. Confirm against DFNet_core.
    """

    def __init__(
            self, c_img=19, c_mask=1,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=(7, 5, 5, 3, 3, 3, 3, 3), de_ksize=(3,) * 8):
        super().__init__()
        # Channels of the concatenated [mask, img] tensor fed to the network.
        c_in = c_img + c_mask

        # Encoder. The first stage passes (2, None, None) positionally:
        # presumably stride 2 with no normalization and no activation.
        # Verify against DFNet_core.EncodeBlock.
        self.en1 = EncodeBlock(c_in, 96, en_ksize[0], 2, None, None)
        self.en2 = EncodeBlock(96, 192, en_ksize[1], stride=2, normalization=norm, activation=act_en)
        self.en3 = EncodeBlock(192, 384, en_ksize[2], stride=2, normalization=norm, activation=act_en)
        self.en4 = EncodeBlock(384, 512, en_ksize[3], stride=2, normalization=norm, activation=act_en)
        self.en5 = EncodeBlock(512, 512, en_ksize[4], stride=2, normalization=norm, activation=act_en)
        self.en6 = EncodeBlock(512, 512, en_ksize[5], stride=2, normalization=norm, activation=act_en)
        self.en7 = EncodeBlock(512, 512, en_ksize[6], stride=2, normalization=norm, activation=act_en)
        self.en8 = EncodeBlock(512, 512, en_ksize[7], stride=2, normalization=norm, activation=act_en)

        # Decoder: DecodeBlock(up_channels, skip_channels, out_channels, mode,
        # kernel_size, ...). The argument meaning is inferred from how each
        # skip width matches its encoder stage; confirm in DFNet_core.
        # Kernel sizes come from de_ksize; they were previously hard-coded to 3,
        # which silently ignored the parameter.
        self.de1 = DecodeBlock(512, 512, 512, mode, de_ksize[0], scale=2, normalization=norm, activation=act_de)
        self.de2 = DecodeBlock(512, 512, 512, mode, de_ksize[1], scale=2, normalization=norm, activation=act_de)
        self.de3 = DecodeBlock(512, 512, 512, mode, de_ksize[2], scale=2, normalization=norm, activation=act_de)
        self.de4 = DecodeBlock(512, 512, 512, mode, de_ksize[3], scale=2, normalization=norm, activation=act_de)
        self.de5 = DecodeBlock(512, 384, 384, mode, de_ksize[4], scale=2, normalization=norm, activation=act_de)
        self.de6 = DecodeBlock(384, 192, 192, mode, de_ksize[5], scale=2, normalization=norm, activation=act_de)
        self.de7 = DecodeBlock(192, 96, 96, mode, de_ksize[6], scale=2, normalization=norm, activation=act_de)
        # The last stage skips to the raw input, which has c_in channels.
        # Its output feeds last_conv, which also expects c_in channels.
        # This was hard-coded to 20, which is correct only for the defaults.
        self.de8 = DecodeBlock(96, c_in, c_in, mode, de_ksize[7], scale=2, normalization=norm, activation=act_de)

        # Project to 3 channels and squash into (0, 1). The args (c_in, 3, 1, 1)
        # presumably mean a 1x1 kernel with stride 1.
        self.last_conv = nn.Sequential(Conv2dSame(c_in, 3, 1, 1), nn.Sigmoid())

    def forward(self, img, mask):
        """Predict the masked region and blend it with the unmasked image.

        Args:
            img: Image tensor; its first three channels are used in the blend.
            mask: Mask tensor. Where it is 1 the network prediction is used;
                where it is 0 the original ``img[:, :3]`` is kept. Fractional
                values blend the two linearly.

        Returns:
            ``mask * prediction + (1 - mask) * img[:, :3]``.
        """
        out = torch.cat([mask, img], dim=1)

        # Collect the raw input plus every encoder output for the skip links.
        skips = [out]
        for encoder in (self.en1, self.en2, self.en3, self.en4,
                        self.en5, self.en6, self.en7, self.en8):
            out = encoder(out)
            skips.append(out)

        # de1 pairs with en7's output, de2 with en6's, ..., de8 with the raw input.
        # The deepest output (en8) is not a skip, since it is already `out`.
        decoders = (self.de1, self.de2, self.de3, self.de4,
                    self.de5, self.de6, self.de7, self.de8)
        for decoder, skip in zip(decoders, reversed(skips[:-1])):
            out = decoder(out, skip)

        output = self.last_conv(out)
        return mask * output + (1 - mask) * img[:, :3]