Skip to content

Commit

Permalink
Replace simple_ctc with Python greedy decoder (pytorch#1558)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mthrok authored Jul 27, 2021
1 parent 1b52e72 commit d49e6e4
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 57 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,3 @@
path = third_party/kaldi/submodule
url = https://github.com/kaldi-asr/kaldi
ignore = dirty
[submodule "examples/libtorchaudio/simplectc"]
path = examples/libtorchaudio/simplectc
url = https://github.com/mthrok/ctcdecode
1 change: 0 additions & 1 deletion examples/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,5 @@ message("libtorchaudio CMakeLists: ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_subdirectory(../.. libtorchaudio)
add_subdirectory(simplectc)
add_subdirectory(augmentation)
add_subdirectory(speech_recognition)
1 change: 0 additions & 1 deletion examples/libtorchaudio/simplectc
Submodule simplectc deleted from b1a30d
4 changes: 2 additions & 2 deletions examples/libtorchaudio/speech_recognition/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
add_executable(transcribe transcribe.cpp)
add_executable(transcribe_list transcribe_list.cpp)
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}" "${CTCDECODE_LIBRARY}")
target_link_libraries(transcribe "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
target_link_libraries(transcribe_list "${TORCH_LIBRARIES}" "${TORCHAUDIO_LIBRARY}")
set_property(TARGET transcribe PROPERTY CXX_STANDARD 14)
set_property(TARGET transcribe_list PROPERTY CXX_STANDARD 14)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import torchaudio
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model
import fairseq
import simple_ctc

from greedy_decoder import Decoder

_LG = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,17 +78,7 @@ def __init__(self, encoder: torch.nn.Module):

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
result, _ = self.encoder(waveform)
return result


class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder

def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
return result[0]


def _get_decoder():
Expand Down Expand Up @@ -125,18 +116,7 @@ def _get_decoder():
"Q",
"Z",
]

return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
return Decoder(labels)


def _load_fairseq_model(input_file, data_dir=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils.import_huggingface import import_huggingface_model
import simple_ctc

from greedy_decoder import Decoder

_LG = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,19 +58,8 @@ def __init__(self, encoder: torch.nn.Module):
self.encoder = encoder

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
length = torch.tensor([waveform.shape[1]])
result, length = self.encoder(waveform, length)
return result


class Decoder(torch.nn.Module):
def __init__(self, decoder: torch.nn.Module):
super().__init__()
self.decoder = decoder

def forward(self, emission: torch.Tensor) -> str:
result = self.decoder.decode(emission)
return ''.join(result.label_sequences[0][0]).replace('|', ' ')
result, _ = self.encoder(waveform)
return result[0]


def _get_model(model_id):
Expand All @@ -84,17 +72,7 @@ def _get_model(model_id):


def _get_decoder(labels):
return Decoder(
simple_ctc.BeamSearchDecoder(
labels,
cutoff_top_n=40,
cutoff_prob=0.8,
beam_size=100,
num_processes=1,
blank_id=0,
is_nll=True,
)
)
return Decoder(labels)


def _main():
Expand Down
28 changes: 28 additions & 0 deletions examples/libtorchaudio/speech_recognition/greedy_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch


class Decoder(torch.nn.Module):
    """Greedy (best-path) decoder that turns per-frame label scores into text.

    The label with the highest score is picked at every frame, consecutive
    repeats are merged, the special tokens ``<s>`` and ``<pad>`` are dropped,
    and the word delimiter ``|`` is rendered as a space.
    """

    def __init__(self, labels):
        super().__init__()
        # Lookup table: label index -> label string.
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Decode a sequence of logits into the best-path transcript.

        Args:
            logits (Tensor): Logit tensor. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        # Highest-scoring label index for each frame. Shape `[num_seq,]`.
        indices = logits.argmax(dim=-1)
        # Merge runs of the same label into a single occurrence.
        indices = torch.unique_consecutive(indices, dim=-1)

        transcript = ''
        for idx in indices:
            token = self.labels[idx]
            # Special tokens never appear in the output text.
            if token == '<s>' or token == '<pad>':
                continue
            # '|' marks a word boundary in the label set.
            transcript += ' ' if token == '|' else token
        return transcript

0 comments on commit d49e6e4

Please sign in to comment.