Merge branch 'master' of github.com:santi-pdp/segan_pytorch
santi-pdp committed Sep 3, 2019
2 parents 100bcd2 + f2259ce commit 27580b7
Showing 10 changed files with 195 additions and 40 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -23,6 +23,10 @@ Latest denoising audio samples with baselines can be found in the [segan+ sample

The voicing/dewhispering audio samples can be found in the [whispersegan samples website](http://veu.talp.cat/whispersegan). Artifacts can now be mitigated a bit further with `--interf_pair` fake signals, more data than we had available (just 20 mins with 1 speaker per model), and longer training sessions iterating beyond `100 epoch`.

### Pretrained Models

SEGAN+ generator weights are released and can be downloaded from [this link](http://veu.talp.cat/seganp/release_weights/segan+_generator.ckpt). Make sure you place this file in the `ckpt_segan+` directory so that it works with the matching `train.opts` config file in that folder. The script `run_segan+_clean.sh` will read the ckpt from that directory, as it is configured to use this referenced file.
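
As a minimal sketch, assuming `wget` is available (any downloader works), the download-and-place step could look like this:

```bash
# Fetch the released SEGAN+ generator weights and drop them into the
# directory that run_segan+_clean.sh and train.opts expect.
mkdir -p ckpt_segan+
wget -O ckpt_segan+/segan+_generator.ckpt \
    http://veu.talp.cat/seganp/release_weights/segan+_generator.ckpt
```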

### Introduction to scripts

Two models are ready to train and use to make wav2wav speech enhancement conversions. SEGAN+ is an
Expand Down
92 changes: 92 additions & 0 deletions ckpt_segan+/train.opts
@@ -0,0 +1,92 @@
{
"save_path": "ckpt_segan+",
"d_pretrained_ckpt": null,
"g_pretrained_ckpt": null,
"cache_dir": "data_tmp",
"clean_trainset": "data_veu4/expanded_segan1_additive/clean_trainset",
"noisy_trainset": "data_veu4/expanded_segan1_additive/noisy_trainset",
"clean_valset": null,
"noisy_valset": null,
"h5_data_root": null,
"h5": false,
"data_stride": 0.5,
"seed": 111,
"epoch": 100,
"patience": 100,
"batch_size": 300,
"save_freq": 50,
"slice_size": 16384,
"opt": "rmsprop",
"l1_dec_epoch": 100,
"l1_weight": 100,
"l1_dec_step": 1e-05,
"g_lr": 5e-05,
"d_lr": 5e-05,
"preemph": 0.95,
"max_samples": null,
"eval_workers": 2,
"slice_workers": 1,
"num_workers": 3,
"no_cuda": false,
"random_scale": [
1
],
"no_train_gen": true,
"preemph_norm": false,
"wsegan": false,
"aewsegan": false,
"vanilla_gan": false,
"no_bias": false,
"n_fft": 2048,
"l1_loss": false,
"skip_merge": "concat",
"skip_type": "alpha",
"skip_init": "one",
"skip_kwidth": 11,
"gkwidth": 31,
"genc_fmaps": [
64,
128,
256,
512,
1024
],
"genc_poolings": [
4,
4,
4,
4,
4
],
"z_dim": 1024,
"gdec_fmaps": null,
"gdec_poolings": null,
"gdec_kwidth": null,
"gnorm_type": null,
"no_z": false,
"no_skip": false,
"pow_weight": 0.001,
"misalign_pair": false,
"interf_pair": false,
"denc_fmaps": [
64,
128,
256,
512,
1024
],
"dpool_type": "none",
"dpool_slen": 16,
"dkwidth": null,
"denc_poolings": [
4,
4,
4,
4,
4
],
"dnorm_type": "bnorm",
"phase_shift": 5,
"sinc_conv": false,
"bias": true
}
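
Since `train.opts` is plain JSON, here is a minimal loading sketch (the `Namespace` round-trip is an assumption for illustration, not necessarily how the repo's scripts parse it):

```python
import json
from argparse import Namespace

# Load the serialized training options back into an attribute-style object,
# e.g. opts.batch_size == 300 for the config shown above.
with open('ckpt_segan+/train.opts', 'r') as f:
    opts = Namespace(**json.load(f))

print(opts.g_lr, opts.genc_fmaps)
```
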
28 changes: 19 additions & 9 deletions clean.py
@@ -43,19 +43,28 @@ def main(opts):
if opts.cuda:
segan.cuda()
segan.G.eval()
# process every wav in the test_files
if len(opts.test_files) == 1:
# assume we read directory
twavs = glob.glob(os.path.join(opts.test_files[0], '*.wav'))
if opts.h5:
with h5py.File(opts.test_files[0], 'r') as f:
twavs = f['data'][:]
else:
# assume we have list of files in input
twavs = opts.test_files
# process every wav in the test_files
if len(opts.test_files) == 1:
# assume we read directory
twavs = glob.glob(os.path.join(opts.test_files[0], '*.wav'))
else:
# assume we have list of files in input
twavs = opts.test_files
print('Cleaning {} wavs'.format(len(twavs)))
beg_t = timeit.default_timer()
for t_i, twav in enumerate(twavs, start=1):
tbname = os.path.basename(twav)
rate, wav = wavfile.read(twav)
wav = normalize_wave_minmax(wav)
if not opts.h5:
tbname = os.path.basename(twav)
rate, wav = wavfile.read(twav)
wav = normalize_wave_minmax(wav)
else:
tbname = 'tfile_{}.wav'.format(t_i)
wav = twav
twav = tbname
wav = pre_emphasize(wav, opts.preemph)
pwav = torch.FloatTensor(wav).view(1,1,-1)
if opts.cuda:
@@ -76,6 +85,7 @@ def main(opts):
parser = argparse.ArgumentParser()
parser.add_argument('--g_pretrained_ckpt', type=str, default=None)
parser.add_argument('--test_files', type=str, nargs='+', default=None)
parser.add_argument('--h5', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=111,
help="Random seed (Def: 111).")
parser.add_argument('--synthesis_path', type=str, default='segan_samples',
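
The new `--h5` flag lets `clean.py` read test slices from an HDF5 `data` dataset instead of globbing a wav directory. A hedged invocation sketch (`noisy_test.h5` is a hypothetical filename; `--cfg_file` usage mirrors `run_segan+_clean.sh`):

```bash
# Default path: glob every *.wav under a directory.
python clean.py --g_pretrained_ckpt ckpt_segan+/segan+_generator.ckpt \
    --cfg_file ckpt_segan+/train.opts --test_files noisy_testset/

# New path: read pre-sliced signals from an HDF5 file's 'data' dataset;
# outputs are named tfile_<i>.wav since utterance names are unknown.
python clean.py --g_pretrained_ckpt ckpt_segan+/segan+_generator.ckpt \
    --cfg_file ckpt_segan+/train.opts --test_files noisy_test.h5 --h5
```
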
36 changes: 36 additions & 0 deletions purge_ckpts.py
@@ -0,0 +1,36 @@
import argparse
import json
import glob
import os


def clean(opts):
logs = glob.glob(os.path.join(opts.ckpt_dir, '*checkpoint*'))
print(logs)
for log in logs:
with open(log, 'r') as log_f:
log_ = json.load(log_f)
# first assertive check that all files exist, no mismatch
# b/w log and filenames existence
for fname in log_['latest']:
fpath = os.path.join(opts.ckpt_dir, 'weights_' + fname)
assert os.path.exists(fpath), fpath
to_rm = [l for l in log_['latest'][:-1] if l != log_['current']]
to_kp = log_['latest'][-1]
for fname in to_rm:
fpath = os.path.join(opts.ckpt_dir, 'weights_' + fname)
os.unlink(fpath)
print('Removed file ', fpath)
print('Kept file ', os.path.join(opts.ckpt_dir, 'weights_' + to_kp))
# re-write log
with open(log, 'w') as log_f:
log_['latest'] = [log_['latest'][-1]]
log_f.write(json.dumps(log_, indent=2))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('ckpt_dir', type=str, default=None)
opts = parser.parse_args()

clean(opts)
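
A possible invocation of the new script, assuming checkpoints follow the `weights_*` naming and `*checkpoint*` log files that it expects:

```bash
# Delete all but the newest checkpoint referenced by each log file in
# ckpt_segan+, then rewrite the log to list only the kept checkpoint.
python purge_ckpts.py ckpt_segan+
```
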
6 changes: 3 additions & 3 deletions run_segan+_clean.sh
@@ -4,15 +4,15 @@ CKPT_PATH="ckpt_segan+"

# please specify the path to your G model checkpoint
# as in weights_G-EOE_<iter>.ckpt
G_PRETRAINED_CKPT=""
G_PRETRAINED_CKPT="segan+_generator.ckpt"

# please specify the path to your folder containing
# noisy test files, each wav in there will be processed
TEST_FILES_PATH=""
TEST_FILES_PATH="data_veu4/expanded_segan1_additive/noisy_testset/"

# please specify the output folder where cleaned files
# will be saved
SAVE_PATH=""
SAVE_PATH="synth_segan+"

python -u clean.py --g_pretrained_ckpt $CKPT_PATH/$G_PRETRAINED_CKPT \
--test_files $TEST_FILES_PATH --cfg_file $CKPT_PATH/train.opts \
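
With the placeholders filled in, the script now runs end to end without edits; a sketch of the expected flow:

```bash
# Reads ckpt_segan+/segan+_generator.ckpt plus train.opts, cleans every
# wav under the noisy test folder, and writes results to synth_segan+/.
bash run_segan+_clean.sh
```
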
11 changes: 5 additions & 6 deletions segan/datasets/se_dataset.py
Expand Up @@ -76,8 +76,8 @@ def slice_signal_index(path, window_size, stride):
n_samples = signal.shape[0]
slices = []
offset = int(window_size * stride)
for beg_i in range(0, n_samples - (offset), offset):
#for beg_i in range(0, n_samples - window_size + 1, offset):
#for beg_i in range(0, n_samples - (offset), offset):
for beg_i in range(0, n_samples - window_size + 1, offset):
end_i = beg_i + window_size
#if end_i >= n_samples:
# last slice is offset to past to fit full window
@@ -531,14 +531,13 @@ class SEH5Dataset(Dataset):
to fixed size).
"""
def __init__(self, data_root, split, preemph,
max_samples=None, verbose=False,
verbose=False,
preemph_norm=False,
random_scale=[1]):
super().__init__()
self.data_root = data_root
self.split = split
self.preemph = preemph
self.max_samples = max_samples
self.verbose = verbose
self.random_scale = random_scale
h5_file = os.path.join(data_root, split + '.h5')
@@ -561,8 +560,8 @@ def __getitem__(self, index):
c_slice = rscale * c_slice
n_slice = rscale * n_slice
# uttname not known with H5
returns = ['N/A', torch.FloatTensor(c_slice),
torch.FloatTensor(n_slice), 0]
returns = ['N/A', torch.FloatTensor(c_slice).squeeze(-1),
torch.FloatTensor(n_slice).squeeze(-1), 0]
return returns

def __len__(self):
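
The first hunk above swaps the loop bound from `n_samples - (offset)` to `n_samples - window_size + 1`, so no slice can extend past the end of the signal. A standalone sketch of the corrected indexing (a simplified reimplementation for illustration, not the repo function):

```python
# window_size=16384 with stride=0.5 gives offset=8192; the corrected
# bound guarantees beg_i + window_size <= n_samples for every slice.
def slice_indices(n_samples, window_size=16384, stride=0.5):
    offset = int(window_size * stride)
    return [(beg_i, beg_i + window_size)
            for beg_i in range(0, n_samples - window_size + 1, offset)]

print(slice_indices(40000))
# -> [(0, 16384), (8192, 24576), (16384, 32768)]
```
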
4 changes: 1 addition & 3 deletions segan/models/generator.py
@@ -38,9 +38,8 @@ def __init__(self, skip_type, size, skip_init, skip_dropout=0,
self.skip_k = nn.Parameter(alpha_.view(1, -1, 1))
else:
# constant, not learnable
self.skip_k = alpha_
self.skip_k = nn.Parameter(alpha_.view(1, -1, 1))
self.skip_k.requires_grad = False
self.skip_k = self.skip_k.view(1, -1, 1)
elif skip_type == 'conv':
if kwidth > 1:
pad = kwidth // 2
@@ -230,7 +229,6 @@ def forward(self, x, z=None, ret_hid=False):
else:
return hi


class Generator1D(Model):

def __init__(self, ninputs, enc_fmaps, kwidth,
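
The generator hunk registers the constant skip scale as an `nn.Parameter` with `requires_grad = False` instead of a bare tensor, so it follows the module across `.cuda()`/`.to()` calls and is saved in the `state_dict`. A minimal sketch of the pattern (a hypothetical simplified module, not the repo class):

```python
import torch
import torch.nn as nn

class ConstantAlphaSkip(nn.Module):
    # Hypothetical minimal module illustrating the non-trainable
    # parameter pattern used for the 'alpha' skip connection above.
    def __init__(self, size):
        super().__init__()
        alpha_ = torch.ones(size)
        self.skip_k = nn.Parameter(alpha_.view(1, -1, 1))
        self.skip_k.requires_grad = False  # constant, not learnable

    def forward(self, hj):
        # Broadcast the (1, size, 1) scale over a (B, size, T) activation.
        return self.skip_k * hj
```
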
30 changes: 15 additions & 15 deletions segan/models/model.py
@@ -76,6 +76,7 @@ def __init__(self, opts, name='SEGAN',
super(SEGAN, self).__init__(name)
self.save_path = opts.save_path
self.preemph = opts.preemph
self.reg_loss = getattr(F, opts.reg_loss)
if generator is None:
# Build G and D
self.G = Generator(1,
@@ -112,7 +113,7 @@ def __init__(self, opts, name='SEGAN',
self.D.apply(weights_init)
print('Discriminator: ', self.D)

def generate(self, inwav, z = None):
def generate(self, inwav, z = None, device='cpu'):
self.G.eval()
N = 16384
x = np.zeros((1, 1, N))
@@ -127,11 +128,13 @@ def generate(self, inwav, z = None):
pad = 0
if pad > 0:
x[0, 0] = torch.cat((inwav[0, 0, beg_i:beg_i + length],
torch.zeros(pad)), dim=0)
torch.zeros(pad).to(device)), dim=0)
else:
x[0, 0] = inwav[0, 0, beg_i:beg_i + length]
x = torch.FloatTensor(x)
#canvas_w, hall = self.G(x, z=z, ret_hid=True)
#x = torch.FloatTensor(x)
if isinstance(x, np.ndarray):
x = torch.FloatTensor(x)
x = x.to(device)
canvas_w, hall = self.infer_G(x, z=z, ret_hid=True)
nums = []
for k in hall.keys():
Expand All @@ -143,7 +146,7 @@ def generate(self, inwav, z = None):
z = self.G.z
if pad > 0:
canvas_w = canvas_w[0, 0, :-pad]
canvas_w = canvas_w.data.numpy().squeeze()
canvas_w = canvas_w.data.cpu().numpy().squeeze()
if c_res is None:
c_res = canvas_w
else:
@@ -225,7 +228,7 @@ def build_optimizers(self, opts):
return Gopt, Dopt

def train(self, opts, dloader, criterion, l1_init, l1_dec_step,
l1_dec_epoch, log_freq, va_dloader=None,
l1_dec_epoch, log_freq, va_dloader=None,
device='cpu'):
""" Train the SEGAN """

@@ -311,7 +314,8 @@ def train(self, opts, dloader, criterion, l1_init, l1_dec_step,
lab = label.fill_(1)
d_fake_, _ = self.infer_D(Genh, noisy)
g_adv_loss = criterion(d_fake_.view(-1), lab)
g_l1_loss = l1_weight * F.l1_loss(Genh, clean)
#g_l1_loss = l1_weight * F.l1_loss(Genh, clean)
g_l1_loss = l1_weight * self.reg_loss(Genh, clean)
g_loss = g_adv_loss + g_l1_loss
g_loss.backward()
Gopt.step()
@@ -597,7 +601,7 @@ def train(self, opts, dloader, criterion, l1_init, l1_dec_step,
d_fake_shuf, _ = self.infer_D(clean, clean_shuf)
d_fake_shuf_loss = cost(d_fake_shuf, fk_lab)
d_weight = 1 / 3 # count 3 components now
d_loss + d_fake_shuf_loss
d_loss += d_fake_shuf_loss

if self.interf_pair:
# put interferring squared signals with random amplitude and
@@ -769,10 +773,6 @@ class AEWSEGAN(WSEGAN):
def __init__(self, opts, name='AEWSEGAN',
generator=None,
discriminator=None):
if hasattr(opts, 'l1_loss'):
self.l1_loss = opts.l1_loss
else:
self.l1_loss = False
super().__init__(opts, name=name, generator=generator,
discriminator=discriminator)
# delete discriminator
@@ -811,12 +811,12 @@ def train(self, opts, dloader, criterion, l1_init, l1_dec_step,
best_val_obj = np.inf
# acumulator for exponential avg of valid curve
acum_val_obj = 0
alpha_val = opts.alpha_val
G = self.G

for iteration in range(1, opts.epoch * len(dloader) + 1):
beg_t = timeit.default_timer()
uttname, clean, noisy, slice_idx = self.sample_dloader(dloader)
uttname, clean, noisy, slice_idx = self.sample_dloader(dloader,
device)
bsz = clean.size(0)
Genh = self.infer_G(noisy, clean)
Gopt.zero_grad()
@@ -867,7 +867,7 @@ def train(self, opts, dloader, criterion, l1_init, l1_dec_step,
''.format(timings[-1],
np.mean(timings))
print(log)
self.writer.add_scalar('g_l2_loss', loss.item(),
self.writer.add_scalar('g_l2/l1_loss', loss.item(),
iteration)
self.writer.add_scalar('G_pow_loss', pow_loss.item(),
iteration)
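
The model hunks replace the hard-coded `F.l1_loss` with a loss looked up by name via `getattr(F, opts.reg_loss)`. A minimal sketch of that selection pattern (the `'smooth_l1_loss'` choice and tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Resolve the regression loss by name from torch.nn.functional, as
# SEGAN.__init__ now does with opts.reg_loss ('l1_loss', 'mse_loss', ...).
reg_loss = getattr(F, 'smooth_l1_loss')

Genh = torch.randn(4, 1, 16384)   # generator output (illustrative)
clean = torch.randn(4, 1, 16384)  # clean reference (illustrative)
g_reg_loss = 100 * reg_loss(Genh, clean)  # l1_weight = 100, as in train.opts
```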