Skip to content

Commit

Permalink
faster viterbi decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
endyul committed Aug 20, 2015
1 parent 25c09ea commit 5f96f06
Showing 1 changed file with 64 additions and 93 deletions.
157 changes: 64 additions & 93 deletions src/framework/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,13 @@
#include "utils/math/mat.h"
#include "utils/math/sparsevec.h"
#include "utils/math/featurevec.h"
#include "utils/logging.hpp"
#include <cmath>
#include <limits>


namespace ltp {
namespace framework {

struct ViterbiLatticeItem {
ViterbiLatticeItem (const size_t& _i, const size_t& _l, const double& _score,
const ViterbiLatticeItem* _prev)
: i(_i), l(_l), score(_score), prev(_prev) {}

ViterbiLatticeItem (const size_t& _l, const double& _score)
: i(0), l(_l), score(_score), prev(0) {}

size_t i;
size_t l;
double score;
const ViterbiLatticeItem* prev;
};

class ViterbiDecodeConstrain {
public:
virtual bool can_emit(const size_t& i, const size_t& j) const {
Expand Down Expand Up @@ -157,117 +144,101 @@ class ViterbiDecoder {

init_lattice(L, T);

for (size_t i = 0; i < L; ++ i) {
for (size_t t = 0; t < T; ++ t) {
if (i == 0) {
ViterbiLatticeItem* item = new ViterbiLatticeItem(i, t, scm.emit(i, t), NULL);
lattice_insert(lattice[i][t], item);
} else {
for (size_t pt = 0; pt < T; ++ pt) {
const ViterbiLatticeItem* prev = lattice[i-1][pt];
if (!prev) { continue; }
for (size_t t = 0; t < T; ++t) {
state[0][t] = scm.emit(0, t);
}

double s = scm.emit(i, t) + scm.tran(pt, t) + prev->score;
ViterbiLatticeItem* item = new ViterbiLatticeItem(i, t, s, prev);
lattice_insert(lattice[i][t], item);
double best = DOUBLE_MIN;
for (size_t i = 1; i < L; ++ i) {
for (size_t t = 0; t < T; ++ t) {
best = DOUBLE_MIN;
for (size_t pt = 0; pt < T; ++ pt) {
double s = state[i-1][pt] + scm.tran(pt, t);
if (s > best) {
best = s;
back[i][t] = pt;
}
}
state[i][t] = best + scm.emit(i, t);
}
}

get_result(L-1, output);
free_lattice();
get_result(output);
}

void decode(const ViterbiScoreMatrix& scm,
const ViterbiDecodeConstrain& con,
std::vector<int>& output) {

size_t L = scm.length();
size_t T = scm.labels();

init_lattice(L, T);

for (size_t i = 0; i < L; ++ i) {
for (size_t t = 0; t < T; ++ t) {
if (!con.can_emit(i, t)) { continue; }

if (i == 0) {
ViterbiLatticeItem* item = new ViterbiLatticeItem(i, t, scm.emit(i, t), NULL);
lattice_insert(lattice[i][t], item);
} else {
for (size_t pt = 0; pt < T; ++ pt) {
if (!con.can_emit(i-1, pt) || !con.can_tran(pt, t)) { continue; }

const ViterbiLatticeItem* prev = lattice[i-1][pt];
if (!prev) { continue; }
for (size_t t = 0; t < T; ++t) {
if (!con.can_emit(0, t)) continue;
state[0][t] = scm.emit(0, t);
}

double s = scm.emit(i, t) + scm.tran(pt, t) + prev->score;
ViterbiLatticeItem* item = new ViterbiLatticeItem(i, t, s, prev);
lattice_insert(lattice[i][t], item);
double best = DOUBLE_MIN;
for (size_t i = 1; i < L; ++ i) {
for (size_t t = 0; t < T; ++ t) {
if (!con.can_emit(i, t)) continue;
best = DOUBLE_MIN;
for (size_t pt = 0; pt < T; ++ pt) {
if (!con.can_emit(i-1, pt) || !con.can_tran(pt, t)) continue;
double s = state[i-1][pt] + scm.tran(pt, t);
if (s > best) {
best = s;
back[i][t] = pt;
}
}
state[i][t] = best + scm.emit(i, t);
}
}
get_result(L-1, output);
free_lattice();

get_result(output);
}

protected:
void init_lattice(const size_t& L, const size_t& T) {
lattice.resize(L, T);
lattice = NULL;
back.resize(L, T);
back = -1;

state.resize(L, T);
state = DOUBLE_MIN;
}

void get_result(std::vector<int>& output) {
size_t L = lattice.nrows();
get_result(L- 1, output);
size_t L = back.nrows();
get_result(L-1, output);
}

void get_result(const size_t& p, std::vector<int>& output) {
size_t T = lattice.ncols();

const ViterbiLatticeItem* best = NULL;
for (size_t t = 0; t < T; ++ t) {
if (!lattice[p][t]) {
continue;
}

if (best == NULL || lattice[p][t]->score > best->score) {
best = lattice[p][t];
}
}
size_t T = back.ncols();

output.resize(p+1);
while (best) {
output[best->i] = best->l;
best = best->prev;
}
}
double best = DOUBLE_MIN;

void free_lattice() {
size_t L = lattice.total_size();
const ViterbiLatticeItem ** p = lattice.c_buf();
for (size_t i = 0; i < L; ++ i) {
if (p[i]) {
delete p[i];
p[i] = 0;
for (size_t t = 0; t < T; ++t) {
double s = state[p][t];
if (s > best) {
best = s;
output[p] = t;
}
}
}

void lattice_insert(const ViterbiLatticeItem* &position,
const ViterbiLatticeItem * const item) {
if (position == NULL) {
position = item;
} else if (position->score < item->score) {
delete position;
position = item;
} else {
delete item;
for (int i = p-1; i >= 0; --i) {
output[i] = back[i+1][output[i+1]];
}
}


math::Mat< const ViterbiLatticeItem * > lattice;
const double DOUBLE_MIN = std::numeric_limits<double>::lowest();
math::Mat<int> back;
math::Mat<double> state;
math::Mat<bool> can_emit;
math::Mat<bool> can_tran;
};


Expand Down Expand Up @@ -428,9 +399,9 @@ class ViterbiDecoderWithMarginal : public ViterbiDecoder {
if (!con.can_emit(0, j)) { continue; }
alpha_score[0][j] = exp_emit[0][j];
}
double sum = row_sum_withcon(alpha_score, 0, con);
double sum = row_sum(alpha_score, 0, con);
scale[0] = (sum == 0.) ? 1. : 1. / sum;
row_scale_withcon(alpha_score, 0, scale[0], con);
row_scale(alpha_score, 0, scale[0], con);

for (size_t i = 1; i < L; ++i) {
for (size_t t = 0; t < T; ++t) {
Expand All @@ -441,9 +412,9 @@ class ViterbiDecoderWithMarginal : public ViterbiDecoder {
}
alpha_score[i][t] *= exp_emit[i][t];
}
sum = row_sum_withcon(alpha_score, i, con);
sum = row_sum(alpha_score, i, con);
scale[i] = (sum == 0.) ? 1. : 1. / sum;
row_scale_withcon(alpha_score, i, scale[i], con);
row_scale(alpha_score, i, scale[i], con);
}
}

Expand Down Expand Up @@ -473,7 +444,7 @@ class ViterbiDecoderWithMarginal : public ViterbiDecoder {
beta_score[i][t] += tmp_row[nt] * exp_tran[t][nt];
}
}
row_scale_withcon(beta_score, i, scale[i], con);
row_scale(beta_score, i, scale[i], con);
}

}
Expand All @@ -492,7 +463,7 @@ class ViterbiDecoderWithMarginal : public ViterbiDecoder {
}
}

double row_sum_withcon(const math::Mat<double>& mat,
double row_sum(const math::Mat<double>& mat,
int i,
const ViterbiDecodeConstrain& con) const {
double sum = 0.;
Expand All @@ -503,7 +474,7 @@ class ViterbiDecoderWithMarginal : public ViterbiDecoder {
return sum;
}

void row_scale_withcon(math::Mat<double>& mat,
void row_scale(math::Mat<double>& mat,
int i,
double scale,
const ViterbiDecodeConstrain& con) {
Expand Down

0 comments on commit 5f96f06

Please sign in to comment.