Skip to content

Commit

Permalink
ctcdecode binding fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
reuben committed Nov 8, 2018
1 parent eee92c2 commit 3255f23
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 20 deletions.
7 changes: 6 additions & 1 deletion native_client/ctcdecode/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ include ../definitions.mk

NUM_PROCESSES ?= 1

# ARM64 can't find the proper libm.so without this
ifeq ($(TARGET),rpi3-armv8)
LDFLAGS_NEEDED += $(RASPBIAN)/lib/aarch64-linux-gnu/libm-2.24.so
endif

all: bindings

clean:
Expand All @@ -12,7 +17,7 @@ clean:

bindings: clean
pip install --quiet $(PYTHON_PACKAGES) wheel==0.31.0 setuptools==39.1.0
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py build_ext --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
find temp_build -type f -name "*.o" -delete
AS=$(AS) CC=$(CC) CXX=$(CXX) LD=$(LD) CFLAGS="$(CFLAGS) $(CXXFLAGS)" LDFLAGS="$(LDFLAGS_NEEDED)" $(PYTHON_PATH) $(NUMPY_INCLUDE) python ./setup.py bdist_wheel --num_processes $(NUM_PROCESSES) $(PYTHON_PLATFORM_NAME) $(SETUP_FLAGS)
rm -rf temp_build
22 changes: 10 additions & 12 deletions native_client/ctcdecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def ctc_beam_search_decoder(probs_seq,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
scorer=None):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
Expand All @@ -45,17 +45,16 @@ def ctc_beam_search_decoder(probs_seq,
characters with highest probs in alphabet will be
used in beam search, default 40.
:type cutoff_top_n: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet.config_file(), beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func)
scorer)
beam_results = [(res.probability, alphabet.decode(res.tokens)) for res in beam_results]
return beam_results

Expand All @@ -67,7 +66,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
scorer=None):
"""Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
Expand All @@ -88,17 +87,16 @@ def ctc_beam_search_decoder_batch(probs_seq,
:type cutoff_top_n: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(
probs_seq, seq_lengths, alphabet.config_file(), beam_size, num_processes,
cutoff_prob, cutoff_top_n, ext_scoring_func)
cutoff_prob, cutoff_top_n, scorer)
batch_beam_results = [
[(res.probability, alphabet.decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results
Expand Down
12 changes: 11 additions & 1 deletion native_client/ctcdecode/scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ Scorer::Scorer(double alpha,
setup(lm_path, trie_path);
}

// Convenience overload: takes the alphabet as a path to its config file
// instead of an already-built Alphabet. It constructs an Alphabet from
// alphabet_config_path and forwards every other argument unchanged to the
// Alphabet-taking constructor.
// NOTE(review): the Alphabet temporary only lives until the delegated
// constructor finishes; this assumes that constructor does not retain a
// reference to it -- confirm.
Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::string& trie_path,
               const char* alphabet_config_path)
    : Scorer(alpha, beta, lm_path, trie_path, Alphabet(alphabet_config_path)) {}

// The destructor performs no explicit teardown of its own.
Scorer::~Scorer() = default;

Expand Down Expand Up @@ -114,7 +124,7 @@ void Scorer::save_dictionary(const std::string& path) {
}

double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
double cond_prob;
double cond_prob = OOV_SCORE;
lm::ngram::State state, tmp_state, out_state;
// avoid inserting <s> at the beginning
language_model_->NullContextWrite(&state);
Expand Down
5 changes: 5 additions & 0 deletions native_client/ctcdecode/scorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class Scorer {
const std::string &lm_path,
const std::string &trie_path,
const Alphabet &alphabet);
Scorer(double alpha,
double beta,
const std::string &lm_path,
const std::string &trie_path,
const char* alphabet_config_path);
~Scorer();

double get_log_cond_prob(const std::vector<std::string> &words);
Expand Down
6 changes: 0 additions & 6 deletions native_client/ctcdecode/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,9 @@ def _single_compile(obj):
'unittest.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11',
'-Wno-unused-local-typedef', '-Wno-sign-compare']


decoder_module = Extension(
name='ds_ctcdecoder._swigwrapper',
sources=['swigwrapper.i'] + FILES + glob.glob('*.cpp'),
Expand All @@ -104,7 +99,6 @@ def _single_compile(obj):
'third_party/openfst-1.6.7/src/include',
'third_party/ThreadPool',
],
libraries=LIBS,
extra_compile_args=ARGS
)

Expand Down
4 changes: 4 additions & 0 deletions native_client/ctcdecode/swigwrapper.i
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
%}

%include "pyabc.i"
%include "std_string.i"
%include "std_vector.i"
%include "numpy.i"

Expand Down Expand Up @@ -61,6 +62,9 @@ mod_decoder_batch(const double *probs,
}
%}


%ignore Scorer::dictionary;

%include "output.h"
%include "scorer.h"
%include "ctc_beam_search_decoder.h"
Expand Down

0 comments on commit 3255f23

Please sign in to comment.