diff --git a/configs/model/conformer-large.yaml b/configs/model/conformer-large.yaml
index 13a6905c..d1a7c7aa 100644
--- a/configs/model/conformer-large.yaml
+++ b/configs/model/conformer-large.yaml
@@ -17,3 +17,4 @@ num_encoder_layers: 17
 num_decoder_layers: 1
 num_attention_heads: 8
 decoder_rnn_type: lstm
+decoder: None
diff --git a/configs/model/conformer-medium.yaml b/configs/model/conformer-medium.yaml
index 359dc7e4..4626a94f 100644
--- a/configs/model/conformer-medium.yaml
+++ b/configs/model/conformer-medium.yaml
@@ -16,3 +16,4 @@ decoder_dim: 640
 num_encoder_layers: 17
 num_decoder_layers: 1
 num_attention_heads: 4
+decoder: None
diff --git a/configs/model/conformer-small.yaml b/configs/model/conformer-small.yaml
index 82f1efdd..6de0b47f 100644
--- a/configs/model/conformer-small.yaml
+++ b/configs/model/conformer-small.yaml
@@ -16,3 +16,4 @@ decoder_dim: 320
 num_encoder_layers: 16
 num_decoder_layers: 1
 num_attention_heads: 4
+decoder: None
diff --git a/kospeech/model_builder.py b/kospeech/model_builder.py
index e2debb61..39f244b4 100644
--- a/kospeech/model_builder.py
+++ b/kospeech/model_builder.py
@@ -109,6 +109,7 @@ def build_model(
             conv_kernel_size=config.model.conv_kernel_size,
             half_step_residual=config.model.half_step_residual,
             device=device,
+            decoder=config.model.decoder,
         )

     elif config.model.architecture.lower() == 'rnnt':
@@ -187,6 +188,7 @@ def build_conformer(
         conv_kernel_size: int,
         half_step_residual: bool,
         device: torch.device,
+        decoder: str,
 ) -> nn.DataParallel:
     if input_dropout_p < 0.0:
         raise ParameterError("dropout probability should be positive")
@@ -219,6 +221,7 @@ def build_conformer(
         conv_kernel_size=conv_kernel_size,
         half_step_residual=half_step_residual,
         device=device,
+        decoder=decoder,
     ))


diff --git a/kospeech/models/conformer/__init__.py b/kospeech/models/conformer/__init__.py
index 1d386f54..c0d0a3a9 100644
--- a/kospeech/models/conformer/__init__.py
+++ b/kospeech/models/conformer/__init__.py
@@ -31,7 +31,8 @@ class ConformerConfig(ModelConfig):
     conv_kernel_size: int = 31
     half_step_residual: bool = True
     num_decoder_layers: int = 1
-    decoder_rnn_type: str = 'lstm'
+    decoder_rnn_type: str = "lstm"
+    decoder: str = None


 @dataclass
diff --git a/kospeech/models/conformer/model.py b/kospeech/models/conformer/model.py
index 57d44546..7448a714 100644
--- a/kospeech/models/conformer/model.py
+++ b/kospeech/models/conformer/model.py
@@ -44,6 +44,7 @@ class Conformer(TransducerModel):
         conv_kernel_size (int or tuple, optional): Size of the convolving kernel
         half_step_residual (bool): Flag indication whether to use half step residual or not
         device (torch.device): torch device (cuda or cpu)
+        decoder (str): If decoder is None, train with CTC decoding

     Inputs: inputs
         - **inputs** (batch, time, dim): Tensor containing input vector
@@ -72,6 +73,7 @@ def __init__(
             conv_kernel_size: int = 31,
             half_step_residual: bool = True,
             device: torch.device = 'cuda',
+            decoder: str = None,
     ) -> None:
         encoder = ConformerEncoder(
             input_dim=input_dim,
@@ -88,15 +90,16 @@ def __init__(
             half_step_residual=half_step_residual,
             device=device,
         )
-        decoder = DecoderRNNT(
-            num_classes=num_classes,
-            hidden_state_dim=decoder_dim,
-            output_dim=encoder_dim,
-            num_layers=num_decoder_layers,
-            rnn_type=decoder_rnn_type,
-            dropout_p=decoder_dropout_p,
-        )
-        super(Conformer, self).__init__(encoder, decoder, encoder_dim, num_classes)
+        if decoder == 'rnnt':
+            decoder = DecoderRNNT(
+                num_classes=num_classes,
+                hidden_state_dim=decoder_dim,
+                output_dim=encoder_dim,
+                num_layers=num_decoder_layers,
+                rnn_type=decoder_rnn_type,
+                dropout_p=decoder_dropout_p,
+            )
+        super(Conformer, self).__init__(encoder, decoder, encoder_dim >> 1, num_classes)

     def forward(
             self,
@@ -118,4 +121,44 @@ def forward(
         Returns:
             * predictions (torch.FloatTensor): Result of model predictions.
         """
-        return super().forward(inputs, input_lengths, targets, target_lengths)
+        if self.decoder is not None:
+            return super().forward(inputs, input_lengths, targets, target_lengths)
+        encoder_outputs, _ = self.encoder(inputs, input_lengths)
+        return self.fc(encoder_outputs).log_softmax(dim=-1)
+
+    @torch.no_grad()
+    def decode(self, encoder_outputs: Tensor, max_length: int = None) -> Tensor:
+        """
+        Decode `encoder_outputs`.
+
+        Args:
+            encoder_outputs (torch.FloatTensor): An output sequence of encoder. `FloatTensor` of size
+                ``(seq_length, dimension)``
+            max_length (int): max decoding time step
+
+        Returns:
+            * predicted_log_probs (torch.FloatTensor): Log probability of model predictions.
+        """
+        if self.decoder is not None:
+            return super().decode(encoder_outputs, max_length)
+        return encoder_outputs.max(-1)[1]
+
+    @torch.no_grad()
+    def recognize(self, inputs: Tensor, input_lengths: Tensor) -> Tensor:
+        """
+        Recognize input speech.
+
+        Args:
+            inputs (torch.FloatTensor): An input sequence passed to encoder. Typically for inputs this will be a padded
+                `FloatTensor` of size ``(batch, seq_length, dimension)``.
+            input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
+
+        Returns:
+            * predictions (torch.FloatTensor): Result of model predictions.
+        """
+        if self.decoder is not None:
+            return super().recognize(inputs, input_lengths)
+
+        encoder_outputs, _ = self.encoder(inputs, input_lengths)
+        predicted_log_probs = self.fc(encoder_outputs).log_softmax(dim=-1)
+        return self.decode(predicted_log_probs)
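For reviewers, a minimal usage sketch of the new option (not part of the patch; parameter names are taken from the constructor above, while the class count, shapes, and loss hookup are illustrative). Note that the CTC branch triggers on `self.decoder is not None`, so the Python value None has to reach the constructor; depending on how the YAML configs are loaded, an unquoted `None` may arrive as the string 'None' instead. The `encoder_dim >> 1` passed to `super().__init__()` appears to compensate for `TransducerModel` sizing its output projection for the concatenated joint input, so `self.fc` ends up matching the raw encoder output dimension.

import torch
from kospeech.models.conformer.model import Conformer

device = torch.device('cpu')

# decoder=None: no DecoderRNNT is built; forward() returns frame-level
# log-probabilities from self.fc, ready for CTC training.
ctc_model = Conformer(num_classes=10, input_dim=80, device=device, decoder=None)

# decoder='rnnt': a DecoderRNNT is attached and the inherited
# TransducerModel forward/decode/recognize paths are used unchanged.
rnnt_model = Conformer(num_classes=10, input_dim=80, device=device, decoder='rnnt')

inputs = torch.randn(2, 160, 80)              # (batch, time, dim)
input_lengths = torch.LongTensor([160, 120])

# The CTC path ignores targets/target_lengths, so None is fine here.
log_probs = ctc_model(inputs, input_lengths, None, None)   # (batch, reduced_time, num_classes)

# Greedy decoding per the new decode(): argmax over the class dimension.
predictions = ctc_model.recognize(inputs, input_lengths)

# Standard torch.nn.CTCLoss hookup. A real training loop should use the
# encoder's downsampled output lengths; the padded maximum is used here
# only to keep the sketch short.
targets = torch.randint(1, 10, (2, 30), dtype=torch.long)  # index 0 reserved for the blank
target_lengths = torch.LongTensor([30, 25])
output_lengths = torch.full((2,), log_probs.size(1), dtype=torch.long)
criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = criterion(log_probs.transpose(0, 1), targets, output_lengths, target_lengths)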