feat: latest AR training
santi-pdp committed Oct 22, 2018
1 parent b8e94e7 commit 43fbccb
Showing 6 changed files with 59 additions and 24 deletions.
10 changes: 10 additions & 0 deletions segan/models/core.py
@@ -196,6 +196,16 @@ def activation(self, name):
def parameters(self):
return filter(lambda p: p.requires_grad, super().parameters())

# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/5
def get_n_params(self):
pp=0
for p in list(self.parameters()):
nn=1
for s in list(p.size()):
nn = nn*s
pp += nn
return pp

class LayerNorm(nn.Module):

def __init__(self, *args):
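A note on the new get_n_params helper in core.py: it multiplies out each parameter's shape by hand. An equivalent and more compact count (not part of this commit) can rely on Tensor.numel():

def count_trainable_params(model):
    # Same count as get_n_params(): total elements over trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)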
7 changes: 5 additions & 2 deletions segan/models/discriminator.py
@@ -279,14 +279,17 @@ def forward(self, x):
h = self.in_conv(x_p)
skip = None
int_act = {'in_conv':h}
all_res = None
for ei, enc_block in enumerate(self.enc_blocks):
h = enc_block(h)
h, res = enc_block(h)
if skip is None:
skip = h
all_res = res
else:
skip += h
all_res += res
int_act['skip_{}'.format(ei)] = h
h = self.mlp(skip)
h = self.mlp(skip + all_res)
int_act['logit'] = h
return h, int_act

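The discriminator's forward now expects each encoder block to return a pair (output, residual) and sums the two streams separately before the MLP. A minimal sketch of that aggregation pattern, with a hypothetical ToyResBlock standing in for the repository's encoder blocks:

import torch
import torch.nn as nn

class ToyResBlock(nn.Module):
    # Stand-in for an encoder block that returns (output, residual).
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, 3, padding=1)

    def forward(self, x):
        res = self.conv(x)
        return x + res, res

blocks = nn.ModuleList([ToyResBlock(16) for _ in range(3)])
x = torch.randn(4, 16, 100)
h, skip, all_res = x, None, None
for block in blocks:
    h, res = block(h)
    skip = h if skip is None else skip + h
    all_res = res if all_res is None else all_res + res
merged = skip + all_res  # what the new forward feeds into self.mlp
print(merged.shape)      # torch.Size([4, 16, 100])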
55 changes: 35 additions & 20 deletions segan/models/generator.py
@@ -35,6 +35,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
name='ARGenerator'):
super().__init__(name=name)
self.z_dim = z_dim
self.no_z = False
# do not place any z
self.do_cuda = cuda
self.gen_enc = nn.ModuleList()
@@ -56,7 +56,7 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
inp = fmaps
ninp = inp
self.dec_blocks = nn.ModuleList()
self.in_conv = nn.Conv1d(ninp, expansion_fmaps, 4)
self.in_conv = nn.Conv1d(2, expansion_fmaps, 4)
for pi, (fmap, dil) in enumerate(zip(ar_fmaps,
dilations),
start=1):
@@ -78,41 +78,49 @@ def __init__(self, ninputs, enc_fmaps, kwidth,

def forward(self, x, z=None, ret_hid=False, spkid=None,
slice_idx=0, att_weight=0):
bsz, nch, time = x.size()
hall = {}
hi = x
skips = self.skips
for l_i, enc_layer in enumerate(self.gen_enc):
hi, linear_hi = enc_layer(hi, att_weight=0)
if ret_hid:
hall['enc_{}'.format(l_i)] = hi
z = torch.randn(hi.size(0), self.z_dim,
*hi.size()[2:])
# reshape tensor to match time resolution
hi = hi.view(bsz, -1, time)
# make z latent variable and concat
z = torch.randn(bsz, *hi.size()[1:])
if hi.is_cuda:
z = z.to('cuda')
if not hasattr(self, 'z'):
self.z = z
hi = torch.cat((z, hi), dim=1)
if ret_hid:
hall['enc_zc'] = hi
print('hi after enc: ', hi.size())
raise NotImplementedError
x_p = F.pad(x, (3, 0))
h = self.in_conv(x_p)
h_p = F.pad(hi, (3, 0))
h = self.in_conv(h_p)
skip = None
int_act = {'in_conv':h}
for ei, enc_block in enumerate(self.enc_blocks):
h = enc_block(h)
hall = {'in_conv':h}
all_res = None
for di, dec_block in enumerate(self.dec_blocks):
h, res = dec_block(h)
if skip is None:
skip = h
all_res = res
else:
skip += h
int_act['skip_{}'.format(ei)] = h
h = self.mlp(skip)
int_act['logit'] = h
return h, int_act
all_res += res
hall['skip_{}'.format(di)] = h
h = self.mlp(skip + all_res)
hall['logit'] = h
if ret_hid:
return h, hall
else:
return h
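In the rewritten ARGenerator.forward above, the encoder output is viewed back at the input time resolution, concatenated with a noise tensor of the same shape (hence the 2-channel in_conv), and left-padded by kwidth - 1 = 3 samples so the 4-tap input convolution is causal and length-preserving. A small shape check of that step; the channel and length values here are illustrative, not taken from the commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, ch, time = 2, 1, 160
hi = torch.randn(bsz, ch, time)        # encoder output viewed at input time resolution
z = torch.randn(bsz, *hi.size()[1:])   # noise with the same shape as hi
hi = torch.cat((z, hi), dim=1)         # 2 channels in: matches nn.Conv1d(2, expansion_fmaps, 4)
in_conv = nn.Conv1d(2, 64, kernel_size=4)
h = in_conv(F.pad(hi, (3, 0)))         # left-pad by kwidth-1=3: causal conv, length preserved
print(h.shape)                         # torch.Size([2, 64, 160])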

class GSkip(nn.Module):

def __init__(self, skip_type, size, skip_init, skip_dropout=0,
merge_mode='sum', cuda=False):
merge_mode='sum', cuda=False, kwidth=11):
# skip_init only applies to alpha skips
super().__init__()
self.merge_mode = merge_mode
@@ -135,8 +135,12 @@ def __init__(self, skip_type, size, skip_init, skip_dropout=0,
self.skip_k = Variable(alpha_, requires_grad=False)
self.skip_k = self.skip_k.view(1, -1, 1)
elif skip_type == 'conv':
self.skip_k = nn.Conv1d(size, size, 11, stride=1,
padding=11//2)
if kwidth > 1:
pad = kwidth // 2
else:
pad = 0
self.skip_k = nn.Conv1d(size, size, kwidth, stride=1,
padding=pad)
else:
raise TypeError('Unrecognized GSkip scheme: ', skip_type)
self.skip_type = skip_type
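The conv-type skip now takes its kernel width from the new kwidth argument; with padding = kwidth // 2 (and 0 for kwidth = 1), any odd kernel width, including the old hard-coded 11, leaves the sequence length unchanged. A quick check of that assumption:

import torch
import torch.nn as nn

x = torch.randn(1, 32, 200)
for kwidth in (1, 3, 11):
    pad = kwidth // 2 if kwidth > 1 else 0
    skip_k = nn.Conv1d(32, 32, kwidth, stride=1, padding=pad)
    print(kwidth, skip_k(x).shape)  # sequence length stays 200 for every odd kwidth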
@@ -448,12 +461,13 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
post_proc=False, out_gate=False,
linterp_mode='linear', hidden_comb=False,
big_out_filter=False, z_std=1,
freeze_enc=False):
freeze_enc=False, skip_kwidth=11):
# if num_spks is specified, do onehot conditioners in dec stages
# subract_mean: from output signal, get rid of mean by windows
# multilayer_out: add some convs in between gblocks in decoder
super().__init__(name='Generator1D')
self.dec_kwidth = dec_kwidth
self.skip_kwidth = skip_kwidth
self.skip = skip
self.skip_init = skip_init
self.skip_dropout = skip_dropout
Expand Down Expand Up @@ -516,7 +530,8 @@ def __init__(self, ninputs, enc_fmaps, kwidth,
skip_init,
skip_dropout,
merge_mode=skip_merge,
cuda=self.do_cuda)
cuda=self.do_cuda,
kwidth=self.skip_kwidth)
skips[l_i] = {'alpha':gskip}
setattr(self, 'alpha_{}'.format(l_i), skips[l_i]['alpha'])
self.gen_enc.append(GBlock(inp, fmaps, kwidth, act,
7 changes: 6 additions & 1 deletion segan/models/model.py
@@ -121,6 +121,10 @@ def __init__(self, opts, name='SEGAN',
self.d_pool_type = opts.d_pool_type
else:
self.d_pool_type = 'conv'
if hasattr(opts, 'skip_kwidth'):
self.skip_kwidth = opts.skip_kwidth
else:
self.skip_kwidth = 11
if hasattr(opts, 'post_skip'):
self.post_skip = opts.post_skip
else:
@@ -247,7 +251,8 @@ def __init__(self, opts, name='SEGAN',
linterp_mode=self.linterp_mode,
hidden_comb=self.hidden_comb,
z_std=self.z_std,
freeze_genc=self.freeze_genc)
freeze_enc=self.freeze_genc,
skip_kwidth=self.skip_kwidth)

else:
self.G = generator
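The hasattr/else branch added here keeps older option sets (which have no skip_kwidth field) working. For reference, getattr with a default is an equivalent one-liner, not what the commit uses:

from argparse import Namespace

opts = Namespace()                               # older option set without skip_kwidth
skip_kwidth = getattr(opts, 'skip_kwidth', 11)   # falls back to the old default of 11
print(skip_kwidth)                               # 11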
2 changes: 1 addition & 1 deletion segan/models/modules.py
@@ -183,7 +183,7 @@ def forward(self, x):
# conv 1x1 to make residual connection
h = self.conv_1x1_skip(h)
# normalization if applies
h = self.forward_norm(h, self.conv_1x1_norm)
h = self.forward_norm(h, self.conv_1x1_skip_norm)
# return with skip connection
y = x + h
# also return res connection (going to further net point directly)
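The one-line fix above normalizes the 1x1 skip projection with its own norm module (conv_1x1_skip_norm) instead of the norm belonging to the other 1x1 conv. A stripped-down block illustrating that wiring; layer names here are hypothetical and the block is not the repository's actual module:

import torch
import torch.nn as nn

class Res1x1Block(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, 3, padding=1)
        self.conv_1x1_skip = nn.Conv1d(channels, channels, 1)
        self.conv_1x1_skip_norm = nn.BatchNorm1d(channels)  # norm tied to the skip conv

    def forward(self, x):
        h = torch.relu(self.conv(x))
        h = self.conv_1x1_skip(h)
        h = self.conv_1x1_skip_norm(h)  # use the skip conv's own norm (the fix)
        y = x + h                       # residual connection
        return y, h                     # also expose the res branch, as in the diff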
2 changes: 2 additions & 0 deletions train.py
@@ -18,6 +18,7 @@ def main(opts):
segan = WSEGAN(opts)
else:
segan = SEGAN(opts)
print('Total model parameters: ', segan.get_n_params())
if opts.g_pretrained_ckpt is not None:
segan.G.load_pretrained(opts.g_pretrained_ckpt, True)
if opts.d_pretrained_ckpt is not None:
@@ -214,6 +215,7 @@ def main(opts):
parser.add_argument('--ardiscriminator', action='store_true',
default=False)
parser.add_argument('--n_fft', type=int, default=2048)
parser.add_argument('--skip_kwidth', type=int, default=11)

opts = parser.parse_args()
opts.d_bnorm = not opts.no_dbnorm
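With the new argument in place, --skip_kwidth defaults to 11 and only needs to be passed when a different skip-convolution width is wanted; a minimal, isolated check of the parser behaviour:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--skip_kwidth', type=int, default=11)
print(parser.parse_args([]).skip_kwidth)                      # 11
print(parser.parse_args(['--skip_kwidth', '7']).skip_kwidth)  # 7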
