Skip to content

Commit

Permalink
feat: correct bug in regression loss for SEGAN G in training (Def L1)
Browse files Browse the repository at this point in the history
  • Loading branch information
santi-pdp committed Feb 15, 2019
1 parent 8670800 commit 9ade47b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 1 addition & 6 deletions segan/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,7 @@ def __init__(self, opts, name='SEGAN',
super(SEGAN, self).__init__(name)
self.save_path = opts.save_path
self.preemph = opts.preemph
if hasattr(opts, 'l1_loss'):
self.l1_loss = opts.l1_loss
self.reg_loss = F.l1_loss
else:
self.l1_loss = False
self.reg_loss = F.mse_loss
self.reg_loss = getattr(F, opts.reg_loss)
if generator is None:
# Build G and D
self.G = Generator(1,
Expand Down
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def main(opts):
parser.add_argument('--no_bias', action='store_true', default=False,
help='Disable all biases in Generator')
parser.add_argument('--n_fft', type=int, default=2048)
parser.add_argument('--l1_loss', action='store_true', default=False)
parser.add_argument('--reg_loss', type=str, default='l1_loss',
help='Regression loss (l1_loss or mse_loss) in the '
'output of G (Def: l1_loss)')

# Skip connections options for G
parser.add_argument('--skip_merge', type=str, default='concat')
Expand Down

0 comments on commit 9ade47b

Please sign in to comment.