Skip to content

Commit

Permalink
Update trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware committed Feb 3, 2021
1 parent ad36c04 commit 9caa4b9
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions kospeech/trainer/supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,27 +274,44 @@ def _train_epoches(
train_elapsed = (current_time - train_begin_time) / 3600.0

if self.architecture in ('rnnt', 'conformer'):
logger.info(self.log_format.format(
timestep, epoch_time_step, loss,
elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
if isinstance(model, nn.DataParallel):
if model.module.decoder is not None:
logger.info(self.rnnt_log_format.format(
timestep, epoch_time_step, loss,
elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
else:
logger.info(self.log_format.format(
timestep, epoch_time_step, loss,
cer, elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
else:
if model.module.decoder is not None:
logger.info(self.rnnt_log_format.format(
timestep, epoch_time_step, loss,
elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
else:
logger.info(self.log_format.format(
timestep, epoch_time_step, loss,
cer, elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
else:
if self.joint_ctc_attention:
logger.info(self.log_format.format(
timestep, epoch_time_step,
loss,
ctc_loss, cross_entropy_loss,
cer,
timestep, epoch_time_step, loss,
ctc_loss, cross_entropy_loss, cer,
elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
else:
logger.info(self.log_format.format(
timestep, epoch_time_step,
loss,
cer,
elapsed, epoch_elapsed, train_elapsed,
timestep, epoch_time_step, loss,
cer, elapsed, epoch_elapsed, train_elapsed,
self.optimizer.get_lr(),
))
begin_time = time.time()
Expand Down Expand Up @@ -462,7 +479,7 @@ def _save_result(self, target_list: list, predict_list: list) -> None:

results = pd.DataFrame(results)
results.to_csv(save_path, index=False, encoding='cp949')

def _save_epoch_result(self, train_result: list, valid_result: list) -> None:
""" Save result of epoch """
train_dict, train_loss, train_cer = train_result
Expand Down

0 comments on commit 9caa4b9

Please sign in to comment.