Skip to content

Commit

Permalink
fix: dnorm_type defaults to bnorm as should
Browse files Browse the repository at this point in the history
  • Loading branch information
santi-pdp committed Nov 16, 2018
1 parent 32200f6 commit ae925c4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 32 deletions.
2 changes: 1 addition & 1 deletion run_segan+_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
python -u train.py --save_path ckpt_segan+ \
--clean_trainset data_veu4/expanded_segan1_additive/clean_trainset \
--noisy_trainset data_veu4/expanded_segan1_additive/noisy_trainset \
--cache_dir data_tmp --no_train_gen --batch_size 300
--cache_dir data_tmp --no_train_gen --batch_size 300 --no-bias
82 changes: 51 additions & 31 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn
from torch.utils.data import DataLoader
from segan.models import SEGAN, WSEGAN, AEWSEGAN
from segan.datasets import SEDataset, collate_fn
from segan.datasets import SEDataset, SEH5Dataset, collate_fn
from segan.utils import Additive
import numpy as np
import random
Expand Down Expand Up @@ -38,39 +38,53 @@ def main(opts):
segan.G.load_pretrained(opts.g_pretrained_ckpt, True)
if opts.d_pretrained_ckpt is not None:
segan.D.load_pretrained(opts.d_pretrained_ckpt, True)
# create dataset(s) and dataloader(s)
dset = SEDataset(opts.clean_trainset,
opts.noisy_trainset,
opts.preemph,
do_cache=True,
cache_dir=opts.cache_dir,
split='train',
stride=opts.data_stride,
slice_size=opts.slice_size,
max_samples=opts.max_samples,
verbose=True,
slice_workers=opts.slice_workers,
preemph_norm=opts.preemph_norm,
random_scale=opts.random_scale
)
# create Dataset(s) and Dataloader(s)
if opts.h5:
# H5 Dataset with processed speech chunks
if opts.h5_data_root is None:
raise ValueError('Please specify an H5 data root')
dset = SEH5Dataset(opts.h5_data_root, split='train',
preemph=opts.preemph,
verbose=True,
random_scale=opts.random_scale)
else:
# Directory Dataset from raw wav files
dset = SEDataset(opts.clean_trainset,
opts.noisy_trainset,
opts.preemph,
do_cache=True,
cache_dir=opts.cache_dir,
split='train',
stride=opts.data_stride,
slice_size=opts.slice_size,
max_samples=opts.max_samples,
verbose=True,
slice_workers=opts.slice_workers,
preemph_norm=opts.preemph_norm,
random_scale=opts.random_scale
)
dloader = DataLoader(dset, batch_size=opts.batch_size,
shuffle=True, num_workers=opts.num_workers,
pin_memory=CUDA,
collate_fn=collate_fn)
if opts.clean_valset is not None:
# validation dataset
va_dset = SEDataset(opts.clean_valset,
opts.noisy_valset,
opts.preemph,
do_cache=True,
cache_dir=opts.cache_dir,
split='valid',
stride=opts.data_stride,
slice_size=opts.slice_size,
max_samples=opts.max_samples,
verbose=True,
slice_workers=opts.slice_workers,
preemph_norm=opts.preemph_norm)
if opts.h5:
dset = SEH5Dataset(opts.h5_data_root, split='valid',
preemph=opts.preemph,
verbose=True)
else:
va_dset = SEDataset(opts.clean_valset,
opts.noisy_valset,
opts.preemph,
do_cache=True,
cache_dir=opts.cache_dir,
split='valid',
stride=opts.data_stride,
slice_size=opts.slice_size,
max_samples=opts.max_samples,
verbose=True,
slice_workers=opts.slice_workers,
preemph_norm=opts.preemph_norm)
va_dloader = DataLoader(va_dset, batch_size=300,
shuffle=False, num_workers=opts.num_workers,
pin_memory=CUDA,
Expand Down Expand Up @@ -103,6 +117,12 @@ def main(opts):
default=None)#'data/clean_valset')
parser.add_argument('--noisy_valset', type=str,
default=None)#'data/noisy_valset')
parser.add_argument('--h5_data_root', type=str, default=None,
help='H5 data root dir (Def: None). The '
'files will be found by split name '
'{train, valid, test}.h5')
parser.add_argument('--h5', action='store_true', default=False,
help='Activate H5 dataset mode (Def: False).')
parser.add_argument('--data_stride', type=float,
default=0.5, help='Stride in seconds for data read')
parser.add_argument('--seed', type=int, default=111,
Expand Down Expand Up @@ -169,7 +189,7 @@ def main(opts):
'3) constant: with alpha value, set values to' \
' not learnable, just fixed.\n(Def: alpha)')
parser.add_argument('--skip_init', type=str, default='one',
help='Way to init skip connections (Def: ones)')
help='Way to init skip connections (Def: one)')
parser.add_argument('--skip_kwidth', type=int, default=11)

# Generator parameters
Expand Down Expand Up @@ -215,7 +235,7 @@ def main(opts):
parser.add_argument('--denc_poolings', type=int, nargs='+',
default=[4, 4, 4, 4, 4],
help='(Def: [4, 4, 4, 4, 4])')
parser.add_argument('--dnorm_type', type=str, default=None,
parser.add_argument('--dnorm_type', type=str, default='bnorm',
help='Normalization to be used in D. Can '
'be: (1) snorm, (2) bnorm or (3) none '
'(Def: bnorm).')
Expand Down

0 comments on commit ae925c4

Please sign in to comment.