Skip to content

Commit

Permalink
Update conformer
Browse files Browse the repository at this point in the history
  • Loading branch information
sooftware committed Feb 3, 2021
1 parent bcb7c86 commit f753f2a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 11 deletions.
1 change: 1 addition & 0 deletions configs/model/conformer-large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ num_encoder_layers: 17
num_decoder_layers: 1
num_attention_heads: 8
decoder_rnn_type: lstm
decoder: None
1 change: 1 addition & 0 deletions configs/model/conformer-medium.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ decoder_dim: 640
num_encoder_layers: 17
num_decoder_layers: 1
num_attention_heads: 4
decoder: None
1 change: 1 addition & 0 deletions configs/model/conformer-small.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ decoder_dim: 320
num_encoder_layers: 16
num_decoder_layers: 1
num_attention_heads: 4
decoder: None
3 changes: 3 additions & 0 deletions kospeech/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def build_model(
conv_kernel_size=config.model.conv_kernel_size,
half_step_residual=config.model.half_step_residual,
device=device,
decoder=config.model.decoder,
)

elif config.model.architecture.lower() == 'rnnt':
Expand Down Expand Up @@ -187,6 +188,7 @@ def build_conformer(
conv_kernel_size: int,
half_step_residual: bool,
device: torch.device,
decoder: str,
) -> nn.DataParallel:
if input_dropout_p < 0.0:
raise ParameterError("dropout probability should be positive")
Expand Down Expand Up @@ -219,6 +221,7 @@ def build_conformer(
conv_kernel_size=conv_kernel_size,
half_step_residual=half_step_residual,
device=device,
decoder=decoder,
))


Expand Down
3 changes: 2 additions & 1 deletion kospeech/models/conformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class ConformerConfig(ModelConfig):
conv_kernel_size: int = 31
half_step_residual: bool = True
num_decoder_layers: int = 1
decoder_rnn_type: str = 'lstm'
decoder_rnn_type: str = "lstm"
decoder: str = None


@dataclass
Expand Down
63 changes: 53 additions & 10 deletions kospeech/models/conformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Conformer(TransducerModel):
conv_kernel_size (int or tuple, optional): Size of the convolving kernel
half_step_residual (bool): Flag indicating whether to use half step residual or not
device (torch.device): torch device (cuda or cpu)
decoder (str): Decoder type. If 'rnnt', an RNN-T decoder is built; if None, train with CTC decoding
Inputs: inputs
- **inputs** (batch, time, dim): Tensor containing input vector
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
conv_kernel_size: int = 31,
half_step_residual: bool = True,
device: torch.device = 'cuda',
decoder: str = None,
) -> None:
encoder = ConformerEncoder(
input_dim=input_dim,
Expand All @@ -88,15 +90,16 @@ def __init__(
half_step_residual=half_step_residual,
device=device,
)
decoder = DecoderRNNT(
num_classes=num_classes,
hidden_state_dim=decoder_dim,
output_dim=encoder_dim,
num_layers=num_decoder_layers,
rnn_type=decoder_rnn_type,
dropout_p=decoder_dropout_p,
)
super(Conformer, self).__init__(encoder, decoder, encoder_dim, num_classes)
if decoder == 'rnnt':
decoder = DecoderRNNT(
num_classes=num_classes,
hidden_state_dim=decoder_dim,
output_dim=encoder_dim,
num_layers=num_decoder_layers,
rnn_type=decoder_rnn_type,
dropout_p=decoder_dropout_p,
)
super(Conformer, self).__init__(encoder, decoder, encoder_dim >> 1, num_classes)

def forward(
self,
Expand All @@ -118,4 +121,44 @@ def forward(
Returns:
* predictions (torch.FloatTensor): Result of model predictions.
"""
return super().forward(inputs, input_lengths, targets, target_lengths)
if self.decoder is not None:
return super().forward(inputs, input_lengths, targets, target_lengths)
encoder_outputs, _ = self.encoder(inputs, input_lengths)
return self.fc(encoder_outputs).log_softmax(dim=-1)

@torch.no_grad()
def decode(self, encoder_outputs: Tensor, max_length: int = None) -> Tensor:
    """
    Decode `encoder_outputs`.

    When a decoder is attached, decoding is delegated unchanged to the parent
    class. When `self.decoder` is None (CTC mode), a greedy pick is made by
    taking the argmax over the last dimension.

    Args:
        encoder_outputs (torch.FloatTensor): With a decoder attached, the output sequence of the encoder,
            forwarded as-is to the parent `decode`. In CTC mode, the per-class scores to reduce over the
            last dimension (`recognize` passes log-probabilities here).
        max_length (int): max decoding time step. Only forwarded to the parent `decode`;
            it is not used in CTC mode.

    Returns:
        * predictions (torch.LongTensor): In CTC mode, the index of the highest-scoring class along the
          last dimension. No blank removal or repeat collapsing is done here.
          Otherwise, whatever the parent `decode` returns.
    """
    if self.decoder is not None:
        return super().decode(encoder_outputs, max_length)

    # Only the winning indices are needed, so ask for them directly instead of
    # computing (values, indices) with max() and discarding the values.
    return encoder_outputs.argmax(dim=-1)

@torch.no_grad()
def recognize(self, inputs: Tensor, input_lengths: Tensor) -> Tensor:
    """
    Recognize input speech.

    With a decoder attached, recognition is delegated to the parent class.
    Without one (`self.decoder` is None), the inputs are run through the
    encoder, projected by `self.fc`, turned into log-probabilities over the
    last dimension, and handed to `self.decode`.

    Args:
        inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
            `FloatTensor` of size ``(batch, seq_length, dimension)``.
        input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``

    Returns:
        * predictions: Result of model predictions, as produced by `self.decode`
          (decoder-less path) or by the parent `recognize`.
    """
    if self.decoder is None:
        # Decoder-less path: the second value returned by the encoder is not needed.
        encoded, _ = self.encoder(inputs, input_lengths)
        projected = self.fc(encoded)
        log_probs = projected.log_softmax(dim=-1)
        return self.decode(log_probs)

    return super().recognize(inputs, input_lengths)

0 comments on commit f753f2a

Please sign in to comment.