Skip to content

Commit

Permalink
Update trainer
Browse files: browse the repository at this point in the history
  • Loading branch information
sooftware committed Feb 3, 2021
1 parent 3041b68 commit 62f5a78
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
8 changes: 5 additions & 3 deletions kospeech/models/conformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import torch
from torch import Tensor
from typing import Tuple, Optional

from kospeech.models.conformer.encoder import ConformerEncoder
from kospeech.models.model import TransducerModel
Expand Down Expand Up @@ -109,7 +110,7 @@ def forward(
input_lengths: Tensor,
targets: Tensor,
target_lengths: Tensor,
) -> Tensor:
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Forward propagate a `inputs` and `targets` pair for training.
Expand All @@ -125,8 +126,9 @@ def forward(
"""
if self.decoder is not None:
return super().forward(inputs, input_lengths, targets, target_lengths)
encoder_outputs, _ = self.encoder(inputs, input_lengths)
return self.fc(encoder_outputs).log_softmax(dim=-1)
encoder_outputs, output_lengths = self.encoder(inputs, input_lengths)
outputs = self.fc(encoder_outputs).log_softmax(dim=-1)
return outputs, output_lengths

@torch.no_grad()
def decode(self, encoder_outputs: Tensor, max_length: int = None) -> Tensor:
Expand Down
25 changes: 22 additions & 3 deletions kospeech/trainer/supervised_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def __init__(
"cer: {:.2f}, elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.6f}"

if self.architecture in ('rnnt', 'conformer'):
self.log_format = "step: {:4d}/{:4d}, loss: {:.6f}, " \
"elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.6f}"
self.rnnt_log_format = "step: {:4d}/{:4d}, loss: {:.6f}, " \
"elapsed: {:.2f}s {:.2f}m {:.2f}h, lr: {:.6f}"

def train(
self,
Expand Down Expand Up @@ -422,7 +422,26 @@ def _model_forward(
outputs, output_lengths = model(inputs, input_lengths)
loss = self.criterion(outputs.transpose(0, 1), targets[:, 1:], output_lengths, target_lengths)

elif self.architecture in ('conformer', 'rnnt'):
elif self.architecture == 'conformer':
if isinstance(model, nn.DataParallel):
if model.module.decoder is not None:
outputs, output_lengths = model(inputs, input_lengths)
loss = self.criterion(outputs.transpose(0, 1), targets[:, 1:], output_lengths, target_lengths)
else:
outputs = model(inputs, input_lengths, targets, target_lengths)
loss = self.criterion(
outputs, targets[:, 1:].contiguous().int(), input_lengths.int(), target_lengths.int()
)
else:
if model.module.decoder is not None:
outputs, output_lengths = model(inputs, input_lengths)
loss = self.criterion(outputs.transpose(0, 1), targets[:, 1:], output_lengths, target_lengths)
else:
outputs = model(inputs, input_lengths, targets, target_lengths)
loss = self.criterion(
outputs, targets[:, 1:].contiguous().int(), input_lengths.int(), target_lengths.int()
)
elif self.architecture in 'rnnt':
outputs = model(inputs, input_lengths, targets, target_lengths)
loss = self.criterion(
outputs, targets[:, 1:].contiguous().int(), input_lengths.int(), target_lengths.int()
Expand Down

0 comments on commit 62f5a78

Please sign in to comment.