Skip to content

Commit 49d53d0

Browse files
committed Apr 8, 2020
alphabet is imported as a list rather than a string to allow for multi-byte characters and subword units
1 parent 431408f commit 49d53d0

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed
 

‎.gitignore

+1-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ build
22
_ext
33
ctc_decode
44
*.pyc
5-
5+
.idea
66
# Byte-compiled / optimized / DLL files
77
__pycache__/
88
*.py[cod]

‎ctcdecode/__init__.py

+4-3
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cu
99
self._beam_width = beam_width
1010
self._scorer = None
1111
self._num_processes = num_processes
12-
self._labels = ''.join(labels).encode()
12+
self._labels = labels
1313
self._num_labels = len(labels)
1414
self._blank_id = blank_id
1515
self._log_probs = 1 if log_probs_input else 0
@@ -33,9 +33,10 @@ def decode(self, probs, seq_lens=None):
3333
if self._scorer:
3434
ctc_decode.paddle_beam_decode_lm(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
3535
self._num_processes, self._cutoff_prob, self.cutoff_top_n, self._blank_id,
36-
self._log_probs ,self._scorer, output, timesteps, scores, out_seq_len)
36+
self._log_probs, self._scorer, output, timesteps, scores, out_seq_len)
3737
else:
38-
ctc_decode.paddle_beam_decode(probs, seq_lens, self._labels, self._num_labels, self._beam_width, self._num_processes,
38+
ctc_decode.paddle_beam_decode(probs, seq_lens, self._labels, self._num_labels, self._beam_width,
39+
self._num_processes,
3940
self._cutoff_prob, self.cutoff_top_n, self._blank_id, self._log_probs,
4041
output, timesteps, scores, out_seq_len)
4142

‎ctcdecode/src/binding.cpp

+27-21
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,38 @@
33
#include <string>
44
#include <vector>
55
#include <torch/torch.h>
6+
#include <memory>
67
#include "scorer.h"
78
#include "ctc_beam_search_decoder.h"
89
#include "utf8.h"
10+
#include "boost/shared_ptr.hpp"
11+
#include "boost/python.hpp"
12+
#include "boost/python/stl_iterator.hpp"
913

10-
int utf8_to_utf8_char_vec(const char* labels, std::vector<std::string>& new_vocab) {
11-
const char* str_i = labels;
12-
const char* end = str_i + strlen(labels)+1;
13-
do {
14-
char u[5] = {0,0,0,0,0};
15-
uint32_t code = utf8::next(str_i, end);
16-
if (code == 0) {
17-
continue;
18-
}
19-
utf8::append(code, u);
20-
new_vocab.push_back(std::string(u));
14+
using namespace std;
15+
16+
template<typename T>
17+
inline
18+
std::vector< T > py_list_to_std_vector( const boost::python::object& iterable )
19+
{
20+
return std::vector< T >( boost::python::stl_input_iterator< T >( iterable ),
21+
boost::python::stl_input_iterator< T >( ) );
22+
}
23+
24+
template <class T>
25+
inline
26+
boost::python::list std_vector_to_py_list(std::vector<T> vector) {
27+
typename std::vector<T>::iterator iter;
28+
boost::python::list list;
29+
for (iter = vector.begin(); iter != vector.end(); ++iter) {
30+
list.append(*iter);
2131
}
22-
while (str_i < end);
32+
return list;
2333
}
2434

2535
int beam_decode(at::Tensor th_probs,
2636
at::Tensor th_seq_lens,
27-
const char* labels,
37+
std::vector<std::string> new_vocab,
2838
int vocab_size,
2939
size_t beam_size,
3040
size_t num_processes,
@@ -38,8 +48,6 @@ int beam_decode(at::Tensor th_probs,
3848
at::Tensor th_scores,
3949
at::Tensor th_out_length)
4050
{
41-
std::vector<std::string> new_vocab;
42-
utf8_to_utf8_char_vec(labels, new_vocab);
4351
Scorer *ext_scorer = NULL;
4452
if (scorer != NULL) {
4553
ext_scorer = static_cast<Scorer *>(scorer);
@@ -67,7 +75,7 @@ int beam_decode(at::Tensor th_probs,
6775

6876
std::vector<std::vector<std::pair<double, Output>>> batch_results =
6977
ctc_beam_search_decoder_batch(inputs, new_vocab, beam_size, num_processes, cutoff_prob, cutoff_top_n, blank_id, log_input, ext_scorer);
70-
auto outputs_accessor = th_output.accessor<int, 3>();
78+
auto outputs_accessor = th_output.accessor<int, 3>();
7179
auto timesteps_accessor = th_timesteps.accessor<int, 3>();
7280
auto scores_accessor = th_scores.accessor<float, 2>();
7381
auto out_length_accessor = th_out_length.accessor<int, 2>();
@@ -93,7 +101,7 @@ int beam_decode(at::Tensor th_probs,
93101

94102
int paddle_beam_decode(at::Tensor th_probs,
95103
at::Tensor th_seq_lens,
96-
const char* labels,
104+
std::vector<std::string> labels,
97105
int vocab_size,
98106
size_t beam_size,
99107
size_t num_processes,
@@ -112,7 +120,7 @@ int paddle_beam_decode(at::Tensor th_probs,
112120

113121
int paddle_beam_decode_lm(at::Tensor th_probs,
114122
at::Tensor th_seq_lens,
115-
const char* labels,
123+
std::vector<std::string> labels,
116124
int vocab_size,
117125
size_t beam_size,
118126
size_t num_processes,
@@ -134,10 +142,8 @@ int paddle_beam_decode_lm(at::Tensor th_probs,
134142
void* paddle_get_scorer(double alpha,
135143
double beta,
136144
const char* lm_path,
137-
const char* labels,
145+
vector<std::string> new_vocab,
138146
int vocab_size) {
139-
std::vector<std::string> new_vocab;
140-
utf8_to_utf8_char_vec(labels, new_vocab);
141147
Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab);
142148
return static_cast<void*>(scorer);
143149
}

‎ctcdecode/src/binding.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
int paddle_beam_decode(THFloatTensor *th_probs,
22
THIntTensor *th_seq_lens,
3-
const char* labels,
3+
std::vector<std::string> labels,
44
int vocab_size,
55
size_t beam_size,
66
size_t num_processes,
@@ -13,9 +13,10 @@ int paddle_beam_decode(THFloatTensor *th_probs,
1313
THFloatTensor *th_scores,
1414
THIntTensor *th_out_length);
1515

16+
1617
int paddle_beam_decode_lm(THFloatTensor *th_probs,
1718
THIntTensor *th_seq_lens,
18-
const char* labels,
19+
std::vector<std::string> labels,
1920
int vocab_size,
2021
size_t beam_size,
2122
size_t num_processes,
@@ -32,7 +33,7 @@ int paddle_beam_decode_lm(THFloatTensor *th_probs,
3233
void* paddle_get_scorer(double alpha,
3334
double beta,
3435
const char* lm_path,
35-
const char* labels,
36+
std::vector<std::string> labels,
3637
int vocab_size);
3738

3839
void paddle_release_scorer(void* scorer);

0 commit comments

Comments
 (0)