refactor: G and D without ARs
santi-pdp committed Nov 12, 2018
1 parent 387d145 commit 807c630
Showing 2 changed files with 9 additions and 175 deletions.
69 changes: 8 additions & 61 deletions segan/models/discriminator.py
@@ -6,11 +6,13 @@
from collections import OrderedDict
from torch.nn.modules import conv, Linear
try:
    from core import Model, LayerNorm, VirtualBatchNorm1d
    from core import Model, LayerNorm
    from modules import ResARModule, SincConv
except ImportError:
    from .core import Model, LayerNorm, VirtualBatchNorm1d
    from .core import Model, LayerNorm
    from .modules import ResARModule, SincConv

# BEWARE: PyTorch >= 0.4.1 REQUIRED
from torch.nn.utils.spectral_norm import spectral_norm


@@ -250,68 +252,13 @@ def forward(self, x):
        int_act['logit'] = y
        return y, int_act

class ARDiscriminator(Model):

    def __init__(self,
                 ninp=2,
                 dilations=[2, 4, 8, 16, 32],
                 kwidth=4,
                 fmaps=[256] * 5,
                 expansion_fmaps=128,
                 norm_type='snorm',
                 name='ARDiscriminator'):
        super().__init__(name=name)
        self.enc_blocks = nn.ModuleList()
        self.in_conv = nn.Conv1d(ninp, expansion_fmaps, 2)
        for pi, (fmap, dil) in enumerate(zip(fmaps,
                                             dilations),
                                         start=1):
            enc_block = ResARModule(expansion_fmaps, fmap,
                                    expansion_fmaps,
                                    kwidth=kwidth,
                                    dilation=dil,
                                    norm_type=norm_type)
            self.enc_blocks.append(enc_block)

        self.mlp = nn.Sequential(
            nn.PReLU(expansion_fmaps, init=0),
            nn.Conv1d(expansion_fmaps, expansion_fmaps,
                      1),
            nn.PReLU(expansion_fmaps, init=0),
            nn.Conv1d(expansion_fmaps, 1,
                      1)
        )

    def forward(self, x):
        x_p = F.pad(x, (1, 0))
        h = self.in_conv(x_p)
        skip = None
        int_act = {'in_conv':h}
        all_res = None
        for ei, enc_block in enumerate(self.enc_blocks):
            h, res = enc_block(h)
            if skip is None:
                skip = h
                all_res = res
            else:
                skip += h
                all_res += res
            int_act['skip_{}'.format(ei)] = h
        h = self.mlp(skip + all_res)
        int_act['logit'] = h
        return h, int_act
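

def _ar_disc_logits_sketch(wav_pair):
    # Illustrative helper (hypothetical usage sketch, not referenced elsewhere): drive
    # the ARDiscriminator defined above with a (batch, 2, time) clean/noisy waveform
    # pair and keep only the per-timestep logits, discarding the intermediate
    # activation dict it returns alongside them.
    disc = ARDiscriminator()
    logits, _ = disc(wav_pair)
    return logits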


if __name__ == '__main__':
    #disc = Discriminator(2, [16, 32, 32, 64, 64, 128, 128, 256,
    #                     256, 512, 1024], 31,
    #                     nn.LeakyReLU(0.3))
    #disc = BiDiscriminator([16, 32, 32, 64, 64, 128, 128, 256,
    #                        256, 512, 1024], 31,
    #                       nn.LeakyReLU(0.3))
    disc = ARDiscriminator()
    disc = Discriminator(2, [16, 32, 32, 64, 64, 128, 128, 256,
                             256, 512, 1024], 31,
                         nn.LeakyReLU(0.3))
    print(disc)
    print(disc.num_parameters())
    print('Num params: ', disc.get_n_params())
    from torch.autograd import Variable
    x = torch.randn(1, 2, 16384)
    y, _ = disc(x)
115 changes: 1 addition & 114 deletions segan/models/generator.py
@@ -14,109 +14,8 @@
from .modules import *
from .attention import *

#if int(torch.__version__[2]) > 4:
# BEWARE: PyTorch >= 0.4.1 REQUIRED
from torch.nn.utils.spectral_norm import spectral_norm
#else:
# from .spectral_norm import SpectralNorm as spectral_norm


class ARGenerator(Model):

    def __init__(self, ninputs, enc_fmaps, kwidth,
                 activations,
                 pooling=4, z_dim=1024,
                 cuda=False,
                 bias=False,
                 dilations=[2, 4, 8, 16, 32],
                 ar_kwidth=4,
                 ar_fmaps=[256] * 5,
                 expansion_fmaps=256,
                 norm_type='snorm',
                 name='ARGenerator'):
        super().__init__(name=name)
        self.z_dim = z_dim
        self.no_z = False
        # do not place any z
        self.do_cuda = cuda
        self.gen_enc = nn.ModuleList()

        skips = {}
        inp = ninputs
        # Build Encoder
        for layer_idx, (fmaps, pool, act) in enumerate(zip(enc_fmaps,
                                                           pooling,
                                                           activations)):
            self.gen_enc.append(GBlock(inp, fmaps, kwidth, act,
                                       padding=None, lnorm=False,
                                       dropout=0, pooling=pool,
                                       enc=True, bias=bias,
                                       aal_h=None,
                                       snorm=(norm_type == 'snorm'),
                                       convblock=False,
                                       satt=False))
            inp = fmaps
        ninp = inp
        self.dec_blocks = nn.ModuleList()
        self.in_conv = nn.Conv1d(2, expansion_fmaps, 4)
        for pi, (fmap, dil) in enumerate(zip(ar_fmaps,
                                             dilations),
                                         start=1):
            enc_block = ResARModule(expansion_fmaps, fmap,
                                    expansion_fmaps,
                                    kwidth=ar_kwidth,
                                    dilation=dil,
                                    norm_type=norm_type)
            self.dec_blocks.append(enc_block)

        self.mlp = nn.Sequential(
            nn.PReLU(expansion_fmaps, init=0),
            nn.Conv1d(expansion_fmaps, expansion_fmaps,
                      1),
            nn.PReLU(expansion_fmaps, init=0),
            nn.Conv1d(expansion_fmaps, 1,
                      1)
        )

    def forward(self, x, z=None, ret_hid=False, spkid=None,
                slice_idx=0, att_weight=0):
        bsz, nch, time = x.size()
        hall = {}
        hi = x
        for l_i, enc_layer in enumerate(self.gen_enc):
            hi, linear_hi = enc_layer(hi, att_weight=0)
            if ret_hid:
                hall['enc_{}'.format(l_i)] = hi
        # reshape tensor to match time resolution
        hi = hi.view(bsz, -1, time)
        # make z latent variable and concat
        z = torch.randn(bsz, *hi.size()[1:])
        if hi.is_cuda:
            z = z.to('cuda')
        if not hasattr(self, 'z'):
            self.z = z
        hi = torch.cat((z, hi), dim=1)
        if ret_hid:
            hall['enc_zc'] = hi
        h_p = F.pad(hi, (3, 0))
        h = self.in_conv(h_p)
        skip = None
        hall = {'in_conv':h}
        all_res = None
        for di, dec_block in enumerate(self.dec_blocks):
            h, res = dec_block(h)
            if skip is None:
                skip = h
                all_res = res
            else:
                skip += h
                all_res += res
            hall['skip_{}'.format(di)] = h
        h = self.mlp(skip + all_res)
        hall['logit'] = h
        if ret_hid:
            return h, hall
        else:
            return h
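

def _ar_decoder_input_sketch(x_enc, bsz, time):
    # Illustrative helper (hypothetical, mirroring the forward pass above): flatten the
    # encoder output back to the input waveform's time resolution, draw a noise code of
    # the same shape, and concatenate the two on the channel axis before feeding the
    # first causal convolution of the AR decoder.
    hi = x_enc.view(bsz, -1, time)            # (bsz, C, time)
    z = torch.randn(bsz, *hi.size()[1:])      # noise code with the same shape as hi
    if hi.is_cuda:
        z = z.to('cuda')
    return torch.cat((z, hi), dim=1)          # (bsz, 2 * C, time)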

class GSkip(nn.Module):

@@ -172,17 +71,6 @@ def forward(self, hj, hi):
        else:
            raise TypeError('Unrecognized skip merge mode: ', self.merge_mode)

class LinterpAffine(nn.Module):

    def __init__(self, num_params=1, std=1, bias=0):
        super().__init__()
        self.linterp_w = nn.Parameter(std * torch.randn(num_params))
        self.linterp_b = nn.Parameter(torch.ones(num_params) * bias)

    def forward(self, x):
        return self.linterp_w.view(1, -1, 1) * x + self.linterp_b.view(1, -1, 1)
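

def _linterp_affine_sketch(num_channels=64):
    # Illustrative example (hypothetical): LinterpAffine above applies a learnable
    # per-channel affine to a (batch, channels, time) tensor,
    # y[b, c, t] = linterp_w[c] * x[b, c, t] + linterp_b[c], keeping the input shape.
    aff = LinterpAffine(num_params=num_channels, std=0.1, bias=0)
    x = torch.randn(4, num_channels, 8192)
    return aff(x)  # shape (4, num_channels, 8192)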


class GBlock(nn.Module):

@@ -242,7 +130,6 @@ def __init__(self, ninputs, fmaps, kwidth,
                self.glu_conv = spectral_norm(self.glu_conv)
        else:
            if linterp:
                #self.linterp_aff = LinterpAffine(ninputs, std=0.1)
                self.linterp_norm = nn.InstanceNorm1d(ninputs)
            self.conv = nn.Conv1d(ninputs, fmaps, kwidth,
                                  stride=1, padding=kwidth//2,
