-
Notifications
You must be signed in to change notification settings - Fork 9
/
unetd.py
60 lines (57 loc) · 2.31 KB
/
unetd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class UnetD(nn.Module):
def __init__(self,nefilters=24):
super(UnetD, self).__init__()
nlayers = 12
self.num_layers = nlayers
self.nefilters = nefilters
filter_size = 15
merge_filter_size = 5
self.upsampling = 'linear'
self.output_type = 'difference'
self.context = True
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
self.ebatch = nn.ModuleList()
self.dbatch = nn.ModuleList()
echannelin = [1] + [(i + 1) * nefilters for i in range(nlayers-1)]
echannelout = [(i + 1) * nefilters for i in range(nlayers)]
dchannelout = echannelout[::-1]
dchannelin = [dchannelout[0]*2]+[(i) * nefilters + (i - 1) * nefilters for i in range(nlayers,1,-1)]
for i in range(self.num_layers):
self.encoder.append(nn.Conv1d(echannelin[i],echannelout[i],filter_size,dilation=2,padding=filter_size//2*2))
self.decoder.append(nn.Conv1d(dchannelin[i],dchannelout[i],merge_filter_size,padding=merge_filter_size//2))
self.ebatch.append(nn.BatchNorm1d(echannelout[i]))
self.dbatch.append(nn.BatchNorm1d(dchannelout[i]))
self.middle=nn.Conv1d(echannelout[-1],echannelout[-1],filter_size,padding=filter_size//2)
self.out = nn.Conv1d(nefilters+1,1,1)
def forward(self,x):
encoder = list()
input = x
for i in range(self.num_layers):
x = self.encoder[i](x)
x = self.ebatch[i](x)
x = F.leaky_relu(x,0.1)
encoder.append(x)
x = x[:,:,::2]
x = self.middle(x)
x = F.leaky_relu(x, 0.1)
#print(x.shape)
for i in range(self.num_layers):
#print(x.shape)
#x = torch.unsqueeze(x, 2)
#print(x.shape)
x = F.upsample(x,scale_factor=2,mode='linear')
#print(i,x.shape,encoder[self.num_layers - i - 1].shape)
#x = torch.squeeze(x, 2)
x = torch.cat([x,encoder[self.num_layers - i - 1]],dim=1)
x = self.decoder[i](x)
x = self.dbatch[i](x)
x = F.leaky_relu(x,0.1)
x = torch.cat([x,input],dim=1)
x = self.out(x)
x = F.tanh(x)
return x