Reorganize package structure.
dogancan committed Feb 25, 2019
1 parent 067721e commit bff0573
Showing 57 changed files with 4,824 additions and 4,737 deletions.
2 changes: 2 additions & 0 deletions kaldi/CMakeLists.txt
@@ -1,3 +1,5 @@
set(PACKAGE "kaldi")

add_subdirectory("base")
add_subdirectory("chain")
add_subdirectory("cudamatrix")
2 changes: 2 additions & 0 deletions kaldi/base/CMakeLists.txt
@@ -1,3 +1,5 @@
set(PACKAGE "${PACKAGE}.base")

add_pyclif_library("_kaldi_error" kaldi-error.clif
LIBRARIES kaldi-base
)
2 changes: 2 additions & 0 deletions kaldi/chain/CMakeLists.txt
@@ -1,3 +1,5 @@
set(PACKAGE "${PACKAGE}.chain")

add_pyclif_library("_chain_datastruct" chain-datastruct.clif)
add_pyclif_library("_chain_den_graph" chain-den-graph.clif
CLIF_DEPS _cu_matrixdim _cu_vector _vector_fst _transition_model _context_dep
2 changes: 2 additions & 0 deletions kaldi/cudamatrix/CMakeLists.txt
@@ -1,3 +1,5 @@
set(PACKAGE "${PACKAGE}.cudamatrix")

if(CUDA)
add_pyclif_library("_cu_device" cu-device.clif
LIBRARIES kaldi-base kaldi-cudamatrix
2 changes: 2 additions & 0 deletions kaldi/decoder/CMakeLists.txt
@@ -1,3 +1,5 @@
set(PACKAGE "${PACKAGE}.decoder")

add_pyclif_library("_grammar_fst" grammar-fst.clif
CLIF_DEPS _fst _const_fst _vector_fst
LIBRARIES kaldi-decoder
338 changes: 2 additions & 336 deletions kaldi/decoder/__init__.py
@@ -2,344 +2,10 @@
from ._decodable_matrix import *
from ._decodable_mapped import *
from ._decodable_sum import *
from ._faster_decoder import *
from ._biglm_faster_decoder import *
from ._lattice_faster_decoder import *
from ._lattice_faster_decoder_ext import *
from ._lattice_biglm_faster_decoder import *
from ._lattice_faster_online_decoder import *
from ._lattice_faster_online_decoder_ext import *
from ._training_graph_compiler import *
from ._training_graph_compiler_ext import *
from .. import fstext as _fst
from .. import lat as _lat
from ._decoder import *
from ._compiler import *


class _DecoderBase(object):
"""Base class defining the Python API for decoders."""

def get_best_path(self, use_final_probs=True):
"""Gets best path as a lattice.
Args:
use_final_probs (bool): If ``True`` and a final state of the graph
is reached, then the output will include final probabilities
given by the graph. Otherwise all final probabilities are
treated as one.
Returns:
LatticeVectorFst: The best path.
Raises:
RuntimeError: In the unusual circumstances where no tokens survive.
"""
ofst = _fst.LatticeVectorFst()
success = self._get_best_path(ofst, use_final_probs)
if not success:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst


class _LatticeDecoderBase(_DecoderBase):
"""Base class defining the Python API for lattice generating decoders."""

def get_raw_lattice(self, use_final_probs=True):
"""Gets raw state-level lattice.
The output raw lattice will be topologically sorted.
Args:
use_final_probs (bool): If ``True`` and a final state of the graph
is reached, then the output will include final probabilities
given by the graph. Otherwise all final probabilities are
treated as one.
Returns:
LatticeVectorFst: The state-level lattice.
Raises:
RuntimeError: In the unusual circumstances where no tokens survive.
"""
ofst = _fst.LatticeVectorFst()
success = self._get_raw_lattice(ofst, use_final_probs)
if not success:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst

def get_lattice(self, use_final_probs=True):
"""Gets the lattice-determinized compact lattice.
The output is a deterministic compact lattice with a unique path for
each word sequence.
Args:
use_final_probs (bool): If ``True`` and a final state of the graph
is reached, then the output will include final probabilities
given by the graph. Otherwise all final probabilities are
treated as one.
Returns:
CompactLatticeVectorFst: The lattice-determinized compact lattice.
Raises:
RuntimeError: In the unusual circumstances where no tokens survive.
"""
ofst = _fst.CompactLatticeVectorFst()
success = self._get_lattice(ofst, use_final_probs)
if not success:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst


class _LatticeOnlineDecoderBase(_LatticeDecoderBase):
"""Base class defining the Python API for lattice generating online decoders."""

def get_raw_lattice_pruned(self, beam, use_final_probs=True):
"""Prunes and returns raw state-level lattice.
Behaves like :meth:`get_raw_lattice` but only processes tokens whose
extra-cost is smaller than the best-cost plus the specified beam. It is
worthwhile to call this function only if :attr:`beam` is less than the
lattice-beam specified in the decoder options. Otherwise, it returns
essentially the same thing as :meth:`get_raw_lattice`, but more slowly.
The output raw lattice will be topologically sorted.
Args:
beam (float): Pruning beam.
use_final_probs (bool): If ``True`` and a final state of the graph
is reached, then the output will include final probabilities
given by the graph. Otherwise all final probabilities are
treated as one.
Returns:
LatticeVectorFst: The state-level lattice.
Raises:
RuntimeError: In the unusual circumstances where no tokens survive.
"""
ofst = _fst.LatticeVectorFst()
success = self._get_raw_lattice_pruned(ofst, use_final_probs, beam)
if not success:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst


class FasterDecoder(_DecoderBase, _faster_decoder.FasterDecoder):
"""Faster decoder.
Args:
fst (StdFst): Decoding graph `HCLG`.
opts (FasterDecoderOptions): Decoder options.
"""
def __init__(self, fst, opts):
super(FasterDecoder, self).__init__(fst, opts)
self._fst = fst # keep a reference to FST to keep it in scope


class BiglmFasterDecoder(_DecoderBase,
_biglm_faster_decoder.BiglmFasterDecoder):
"""Faster decoder for decoding with big language models.
This is as :class:`FasterDecoder`, but does online composition
between decoding graph :attr:`fst` and the difference language model
:attr:`lm_diff_fst`.
Args:
fst (StdFst): Decoding graph.
opts (BiglmFasterDecoderOptions): Decoder options.
lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
FST representing the difference in scores between the LM to decode
with and the LM the decoding graph :attr:`fst` was compiled with.
"""
def __init__(self, fst, opts, lm_diff_fst):
super(BiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
self._fst = fst # keep references to FSTs
self._lm_diff_fst = lm_diff_fst # to keep them in scope

class LatticeFasterDecoder(_LatticeDecoderBase,
_lattice_faster_decoder.LatticeFasterDecoder):
"""Lattice generating faster decoder.
Args:
fst (StdFst): Decoding graph `HCLG`.
opts (LatticeFasterDecoderOptions): Decoder options.
"""
def __init__(self, fst, opts):
super(LatticeFasterDecoder, self).__init__(fst, opts)
self._fst = fst # keep a reference to FST to keep it in scope

class LatticeFasterGrammarDecoder(
_LatticeDecoderBase,
_lattice_faster_decoder_ext.LatticeFasterGrammarDecoder):
"""Lattice generating faster grammar decoder.
Args:
fst (GrammarFst): Decoding graph `HCLG`.
opts (LatticeFasterDecoderOptions): Decoder options.
"""
def __init__(self, fst, opts):
super(LatticeFasterGrammarDecoder, self).__init__(fst, opts)
self._fst = fst # keep a reference to FST to keep it in scope

class LatticeBiglmFasterDecoder(
_LatticeDecoderBase,
_lattice_biglm_faster_decoder.LatticeBiglmFasterDecoder):
"""Lattice generating faster decoder for decoding with big language models.
This is as :class:`LatticeFasterDecoder`, but does online composition
between decoding graph :attr:`fst` and the difference language model
:attr:`lm_diff_fst`.
Args:
fst (StdFst): Decoding graph `HCLG`.
opts (LatticeFasterDecoderOptions): Decoder options.
lm_diff_fst (StdDeterministicOnDemandFst): The deterministic on-demand
FST representing the difference in scores between the LM to decode
with and the LM the decoding graph :attr:`fst` was compiled with.
"""
def __init__(self, fst, opts, lm_diff_fst):
super(LatticeBiglmFasterDecoder, self).__init__(fst, opts, lm_diff_fst)
self._fst = fst # keep references to FSTs
self._lm_diff_fst = lm_diff_fst # to keep them in scope


class LatticeFasterOnlineDecoder(
_LatticeOnlineDecoderBase,
_lattice_faster_online_decoder.LatticeFasterOnlineDecoder):
"""Lattice generating faster online decoder.
Similar to :class:`LatticeFasterDecoder` but computes the best path
without generating the entire raw lattice and finding the best path
through it. Instead, it traces back through the lattice.
Args:
fst (StdFst): Decoding graph `HCLG`.
opts (LatticeFasterDecoderOptions): Decoder options.
"""
def __init__(self, fst, opts):
super(LatticeFasterOnlineDecoder, self).__init__(fst, opts)
self._fst = fst # keep a reference to FST to keep it in scope

# This method is missing from the C++ class so we implement it here.
def _get_lattice(self, use_final_probs=True):
raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
lat_opts = _lat.DeterminizeLatticePrunedOptions()
config = self.get_options()
lat_opts.max_mem = config.det_opts.max_mem
ofst = _fst.CompactLatticeVectorFst()
_lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
ofst, lat_opts)
ofst.connect()
if ofst.num_states() == 0:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst


class LatticeFasterOnlineGrammarDecoder(
_LatticeOnlineDecoderBase,
_lattice_faster_online_decoder_ext.LatticeFasterOnlineGrammarDecoder):
"""Lattice generating faster online grammar decoder.
Similar to :class:`LatticeFasterGrammarDecoder` but computes the best path
without generating the entire raw lattice and finding the best path
through it. Instead, it traces back through the lattice.
Args:
fst (GrammarFst): Decoding graph `HCLG`.
opts (LatticeFasterDecoderOptions): Decoder options.
"""
def __init__(self, fst, opts):
super(LatticeFasterOnlineGrammarDecoder, self).__init__(fst, opts)
self._fst = fst # keep a reference to FST to keep it in scope

# This method is missing from the C++ class so we implement it here.
def _get_lattice(self, use_final_probs=True):
raw_fst = self.get_raw_lattice(use_final_probs).invert().arcsort()
lat_opts = _lat.DeterminizeLatticePrunedOptions()
config = self.get_options()
lat_opts.max_mem = config.det_opts.max_mem
ofst = _fst.CompactLatticeVectorFst()
_lat.determinize_lattice_pruned(raw_fst, config.lattice_beam,
ofst, lat_opts)
ofst.connect()
if ofst.num_states() == 0:
raise RuntimeError("Decoding failed. No tokens survived.")
return ofst


class TrainingGraphCompiler(_training_graph_compiler_ext.TrainingGraphCompiler):
"""Training graph compiler."""
def __init__(self, trans_model, ctx_dep, lex_fst, disambig_syms, opts):
"""
Args:
trans_model (TransitionModel): Transition model `H`.
ctx_dep (ContextDependency): Context dependency model `C`.
lex_fst (StdVectorFst): Lexicon `L`.
disambig_syms (List[int]): Disambiguation symbols.
opts (TrainingGraphCompilerOptions): Compiler options.
"""
super(TrainingGraphCompiler, self).__init__(
trans_model, ctx_dep, lex_fst, disambig_syms, opts)
# keep references to these objects to keep them in scope
self._trans_model = trans_model
self._ctx_dep = ctx_dep
self._lex_fst = lex_fst

def compile_graph(self, word_fst):
"""Compiles a single training graph from a weighted acceptor.
Args:
word_fst (StdVectorFst): Weighted acceptor `G` at the word level.
Returns:
StdVectorFst: The training graph `HCLG`.
"""
ofst = super(TrainingGraphCompiler, self).compile_graph(word_fst)
return _fst.StdVectorFst(ofst)

def compile_graphs(self, word_fsts):
"""Compiles training graphs from weighted acceptors.
Args:
word_fsts (List[StdVectorFst]): Weighted acceptors at the word level.
Returns:
List[StdVectorFst]: The training graphs.
"""
ofsts = super(TrainingGraphCompiler, self).compile_graphs(word_fsts)
for i, fst in enumerate(ofsts):
ofsts[i] = _fst.StdVectorFst(fst)
return ofsts

def compile_graph_from_text(self, transcript):
"""Compiles a single training graph from a transcript.
Args:
transcript (List[int]): The input transcript.
Returns:
StdVectorFst: The training graph `HCLG`.
"""
ofst = super(TrainingGraphCompiler,
self).compile_graph_from_text(transcript)
return _fst.StdVectorFst(ofst)

def compile_graphs_from_text(self, transcripts):
"""Compiles training graphs from transcripts.
Args:
transcripts (List[List[int]]): The input transcripts.
Returns:
List[StdVectorFst]: The training graphs.
"""
ofsts = super(TrainingGraphCompiler,
self).compile_graphs_from_text(transcripts)
for i, fst in enumerate(ofsts):
ofsts[i] = _fst.StdVectorFst(fst)
return ofsts

__all__ = [name for name in dir()
if name[0] != '_'
and not name.endswith('Base')]
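
For readers skimming the diff, the wrappers removed from this file (and reorganized by this commit) all share the small API defined by the base classes above: construct the decoder with a decoding graph and an options object, run it over a decodable, then retrieve the result with `get_best_path`, `get_raw_lattice` or `get_lattice`. The sketch below illustrates that flow. It is not part of the commit; the `LatticeFasterDecoder` name comes from this file, while the decodable class name, the `read_fst_kaldi` loader, the option fields and the file path are assumptions made for illustration based on the `_decodable_matrix` and `fstext` imports shown above.

# Usage sketch for the decoder API documented above -- not part of this
# commit. Names marked "assumed" may differ in the actual PyKaldi API.
from kaldi.decoder import (LatticeFasterDecoder, LatticeFasterDecoderOptions,
                           DecodableMatrixScaled)    # decodable name assumed
from kaldi.fstext import read_fst_kaldi               # assumed HCLG loader
from kaldi.matrix import Matrix

opts = LatticeFasterDecoderOptions()
opts.beam = 13.0          # decoding beam (field name assumed)
opts.lattice_beam = 6.0   # lattice pruning beam (field name assumed)

graph = read_fst_kaldi("HCLG.fst")            # decoding graph `HCLG`
decoder = LatticeFasterDecoder(graph, opts)

# Placeholder acoustic scores: one row per frame, one column per
# transition-id expected by the graph. In practice these come from an
# acoustic model; zeros are used here only to keep the sketch short.
loglikes = Matrix(100, 2000)
decodable = DecodableMatrixScaled(loglikes, 1.0)

decoder.decode(decodable)
best_path = decoder.get_best_path()   # LatticeVectorFst
lattice = decoder.get_lattice()       # CompactLatticeVectorFst

The grammar, big-LM and online variants follow the same pattern, differing only in the graph type and the extra constructor arguments documented in their docstrings.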