
Commit

feat: add pooling list to models to be flexible in design
santi-pdp committed Oct 16, 2018
1 parent a8e9979 commit cb4a1ed
Showing 3 changed files with 117 additions and 44 deletions.
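The change generalizes the pooling argument of the discriminator and generator from a single decimation factor shared by every conv block to an optional per-layer list; a scalar is still broadcast to all layers, and the generator's decoder now mirrors the encoder's list in reverse (up_poolings = pooling[::-1]). A standalone sketch of the new per-layer handling, mirroring the constructor logic in the diff below (the feature-map sizes and pooling factors are illustrative values, not taken from this repository's configs):

    d_fmaps = [16, 32, 64, 128]             # illustrative per-block feature maps
    pooling = [4, 4, 4, 1]                  # one decimation factor per block
    if not isinstance(pooling, list):
        pooling = [pooling] * len(d_fmaps)  # a scalar still applies to every block
    else:
        assert len(pooling) == len(d_fmaps), len(pooling)
    for d_fmap, pool in zip(d_fmaps, pooling):
        # each DiscBlock / GBlock now receives its own stride
        print('fmaps={}, stride={}'.format(d_fmap, pool))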
21 changes: 18 additions & 3 deletions segan/models/discriminator.py
@@ -77,6 +77,7 @@ def __init__(self, ninputs, kwidth, nfmaps,
dropout=0):
super().__init__()
self.kwidth = kwidth
self.pooling = pooling
seq_dict = OrderedDict()
self.conv = nn.Conv1d(ninputs, nfmaps, kwidth,
stride=pooling,
@@ -96,7 +97,10 @@ def __init__(self, ninputs, kwidth, nfmaps,
self.dout = nn.Dropout(dropout)

def forward(self, x):
x = F.pad(x, ((self.kwidth//2)-1, self.kwidth//2))
if self.pooling == 1:
x = F.pad(x, ((self.kwidth//2), self.kwidth//2))
else:
x = F.pad(x, ((self.kwidth//2)-1, self.kwidth//2))
conv_h = self.conv(x)
if self.bnorm:
conv_h = self.bn(conv_h)
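The forward change above switches to symmetric padding when a block does not decimate (pooling == 1), so a stride-1 conv keeps the temporal length exactly; the previous asymmetric pad is kept for strided blocks, where (for odd kwidth) it maps L samples to L/pooling whenever L is divisible by the pooling factor. A quick standalone check of that length arithmetic (kwidth=31 and L=100 are example values):

    import torch
    import torch.nn.functional as F

    kwidth, L = 31, 100
    x = torch.randn(1, 1, L)
    conv = torch.nn.Conv1d(1, 1, kwidth, stride=1)
    y = conv(F.pad(x, (kwidth // 2, kwidth // 2)))       # new symmetric pad
    assert y.size(-1) == L                                # length preserved
    y = conv(F.pad(x, (kwidth // 2 - 1, kwidth // 2)))   # old asymmetric pad
    assert y.size(-1) == L - 1                            # one sample short at stride 1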
@@ -211,16 +215,20 @@ def __init__(self, ninputs, d_fmaps, kwidth, activation,
if Genc is None:
if not isinstance(activation, list):
activation = [activation] * len(d_fmaps)
if not isinstance(pooling, list):
pooling = [pooling] * len(d_fmaps)
else:
assert len(pooling) == len(d_fmaps), len(pooling)
self.disc = nn.ModuleList()
for d_i, d_fmap in enumerate(d_fmaps):
for d_i, (d_fmap, pool) in enumerate(zip(d_fmaps, pooling)):
act = activation[d_i]
if d_i == 0:
inp = ninputs
else:
inp = d_fmaps[d_i - 1]
self.disc.append(DiscBlock(inp, kwidth, d_fmap,
act, bnorm,
pooling, SND,
pool, SND,
dropout))
else:
print('Assigning Genc to D')
@@ -261,6 +269,9 @@ def __init__(self, ninputs, d_fmaps, kwidth, activation,
elif pool_type == 'gmax':
self.gmax = nn.AdaptiveMaxPool1d(1)
self.fc = nn.Linear(d_fmaps[-1], 1, 1)
elif pool_type == 'gavg':
self.gavg = nn.AdaptiveAvgPool1d(1)
self.fc = nn.Linear(d_fmaps[-1], 1, 1)
else:
raise TypeError('Unrecognized pool type: ', pool_type)
outs = 1
@@ -311,6 +322,10 @@ def forward(self, x):
h = self.gmax(h)
h = h.view(h.size(0), -1)
y = self.fc(h)
elif self.pool_type == 'gavg':
h = self.gavg(h)
h = h.view(h.size(0), -1)
y = self.fc(h)
int_act['logit'] = y
#return F.sigmoid(y), int_act
return y, int_act
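The new 'gavg' pool_type mirrors the existing 'gmax' branch: the last feature map is collapsed to one value per channel before the final linear layer, by averaging over time instead of taking the max. A standalone sketch of that pooling head (batch size, channel count and length are example values):

    import torch
    import torch.nn as nn

    h = torch.randn(4, 128, 512)          # (batch, d_fmaps[-1], time)
    gavg = nn.AdaptiveAvgPool1d(1)        # average across the time axis -> length 1
    fc = nn.Linear(128, 1)
    y = fc(gavg(h).view(h.size(0), -1))   # one real-valued score per example, shape (4, 1)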
37 changes: 25 additions & 12 deletions segan/models/generator.py
@@ -208,8 +208,13 @@ def forward(self, x, att_weight=0.):
mode=self.linterp_mode)
x = self.linterp_aff(x)
if self.enc and self.padding == 0:
# apply proper padding
x = F.pad(x, ((self.kwidth//2)-1, self.kwidth//2))
if self.pooling == 1:
# apply proper padding
x = F.pad(x, ((self.kwidth//2), self.kwidth//2))
else:
# apply proper padding
x = F.pad(x, ((self.kwidth//2)-1, self.kwidth//2))

h = self.conv(x)
if not self.enc and not self.linterp and not self.convblock \
and self.kwidth % 2 != 0:
@@ -352,7 +357,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
dec_fmaps=None, up_poolings=None,
post_proc=False, out_gate=False,
linterp_mode='linear', hidden_comb=False,
big_out_filter=False):
big_out_filter=False, z_std=1):
# if num_spks is specified, do onehot coditioners in dec stages
# subract_mean: from output signal, get rif of mean by windows
# multilayer_out: add some convs in between gblocks in decoder
@@ -380,6 +385,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
self.wd = wd
self.no_tanh = no_tanh
self.skip_blacklist = skip_blacklist
self.z_std = z_std
self.gen_enc = nn.ModuleList()
if aal or aal_out:
# Make cheby1 filter to include into pytorch conv blocks
@@ -400,11 +406,13 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
activations = getattr(nn, activations)()
if not isinstance(activations, list):
activations = [activations] * len(enc_fmaps)

if not isinstance(pooling, list) or len(pooling) == 1:
pooling = [pooling] * len(enc_fmaps)
skips = {}
# Build Encoder
for layer_idx, (fmaps, act) in enumerate(zip(enc_fmaps,
activations)):
for layer_idx, (fmaps, pool, act) in enumerate(zip(enc_fmaps,
pooling,
activations)):
if layer_idx == 0:
inp = ninputs
else:
@@ -421,7 +429,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
setattr(self, 'alpha_{}'.format(l_i), skips[l_i]['alpha'])
self.gen_enc.append(GBlock(inp, fmaps, kwidth, act,
padding=None, lnorm=lnorm,
dropout=dropout, pooling=pooling,
dropout=dropout, pooling=pool,
enc=True, bias=bias,
aal_h=self.filter_h,
snorm=snorm, convblock=convblock,
@@ -434,9 +442,12 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
print(dec_fmaps)
up_poolings = [pooling] * (len(dec_fmaps) - 2) + [1] * 3
add_activations = [nn.PReLU(16), nn.PReLU(8), nn.PReLU(1)]
raise NotImplementedError('MLPconv is not useful and should be'
' deleted')
else:
dec_fmaps = enc_fmaps[:-1][::-1] + [1]
up_poolings = [pooling] * len(dec_fmaps)
up_poolings = pooling[::-1]
#up_poolings = [pooling] * len(dec_fmaps)
print('up_poolings: ', up_poolings)
self.up_poolings = up_poolings
else:
@@ -504,7 +515,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
comb=hidden_comb))
else:
self.gen_dec.append(GBlock(dec_inp,
fmaps, kwidth, act,
fmaps, dec_kwidth, act,
lnorm=lnorm,
dropout=dropout, pooling=1,
padding=kwidth//2,
@@ -567,7 +578,9 @@ def forward(self, x, z=None, ret_hid=False, spkid=None,
# MAKE DETERMINISTIC ZERO
h0 = Variable(torch.zeros(2, hi.size(0), hi.size(1)//2))
else:
h0 = Variable(torch.randn(2, hi.size(0), hi.size(1)//2))
h0 = Variable(self.z_std * torch.randn(2,
hi.size(0),
hi.size(1)//2))
c0 = Variable(torch.zeros(2, hi.size(0), hi.size(1)//2))
if self.do_cuda:
h0 = h0.cuda()
@@ -584,8 +597,8 @@ def forward(self, x, z=None, ret_hid=False, spkid=None,
if not self.no_z:
if z is None:
# make z
z = Variable(torch.randn(hi.size(0), self.z_dim,
*hi.size()[2:]))
z = Variable(self.z_std * torch.randn(hi.size(0), self.z_dim,
*hi.size()[2:]))
#print('Made z of dim: ', z.size())
if len(z.size()) != len(hi.size()):
raise ValueError('len(z.size) {} != len(hi.size) {}'
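The new z_std argument scales the generator's sampled noise, both the recurrent core's initial hidden state and the latent code z, so the draw is N(0, z_std^2) instead of always unit variance; the default z_std=1 reproduces the old behaviour. A standalone sketch of the sampling (tensor sizes are illustrative, and the Variable wrapper used by the repository for older PyTorch is omitted):

    import torch

    z_std = 0.5                               # example value; the default is 1
    hi = torch.randn(4, 1024, 16)             # stand-in for the encoder output
    z_dim = 256                               # example latent width
    z = z_std * torch.randn(hi.size(0), z_dim, *hi.size()[2:])
    # z.std() is roughly z_std; z_std = 1 matches the previous unscaled randn draw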
(diff for the third changed file did not load and is not shown)
