diff --git a/kospeech/trainer/supervised_trainer.py b/kospeech/trainer/supervised_trainer.py index 78d99a2c..faf4ab5a 100644 --- a/kospeech/trainer/supervised_trainer.py +++ b/kospeech/trainer/supervised_trainer.py @@ -274,27 +274,44 @@ def _train_epoches( train_elapsed = (current_time - train_begin_time) / 3600.0 if self.architecture in ('rnnt', 'conformer'): - logger.info(self.log_format.format( - timestep, epoch_time_step, loss, - elapsed, epoch_elapsed, train_elapsed, - self.optimizer.get_lr(), - )) + if isinstance(model, nn.DataParallel): + if model.module.decoder is not None: + logger.info(self.rnnt_log_format.format( + timestep, epoch_time_step, loss, + elapsed, epoch_elapsed, train_elapsed, + self.optimizer.get_lr(), + )) + else: + logger.info(self.log_format.format( + timestep, epoch_time_step, loss, + cer, elapsed, epoch_elapsed, train_elapsed, + self.optimizer.get_lr(), + )) + else: + if model.module.decoder is not None: + logger.info(self.rnnt_log_format.format( + timestep, epoch_time_step, loss, + elapsed, epoch_elapsed, train_elapsed, + self.optimizer.get_lr(), + )) + else: + logger.info(self.log_format.format( + timestep, epoch_time_step, loss, + cer, elapsed, epoch_elapsed, train_elapsed, + self.optimizer.get_lr(), + )) else: if self.joint_ctc_attention: logger.info(self.log_format.format( - timestep, epoch_time_step, - loss, - ctc_loss, cross_entropy_loss, - cer, + timestep, epoch_time_step, loss, + ctc_loss, cross_entropy_loss, cer, elapsed, epoch_elapsed, train_elapsed, self.optimizer.get_lr(), )) else: logger.info(self.log_format.format( - timestep, epoch_time_step, - loss, - cer, - elapsed, epoch_elapsed, train_elapsed, + timestep, epoch_time_step, loss, + cer, elapsed, epoch_elapsed, train_elapsed, self.optimizer.get_lr(), )) begin_time = time.time() @@ -462,7 +479,7 @@ def _save_result(self, target_list: list, predict_list: list) -> None: results = pd.DataFrame(results) results.to_csv(save_path, index=False, encoding='cp949') - + def _save_epoch_result(self, train_result: list, valid_result: list) -> None: """ Save result of epoch """ train_dict, train_loss, train_cer = train_result