Skip to content

Commit

Permalink
[fix] multi-threaded bug in segmentor
Browse files Browse the repository at this point in the history
  • Loading branch information
Oneplus committed Nov 13, 2014
1 parent 280fdc8 commit 9bc1a9c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 97 deletions.
30 changes: 15 additions & 15 deletions src/segmentor/segment_dll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

#include <iostream>

class SegmentorWrapper : public ltp::segmentor::Segmentor {
namespace seg = ltp::segmentor;

class SegmentorWrapper : public seg::Segmentor {
public:
SegmentorWrapper() :
beg_tag0(-1),
Expand All @@ -22,7 +24,7 @@ class SegmentorWrapper : public ltp::segmentor::Segmentor {
return false;
}

model = new ltp::segmentor::Model;
model = new seg::Model;
if (!model->load(mfs)) {
delete model;
model = 0;
Expand All @@ -46,7 +48,7 @@ class SegmentorWrapper : public ltp::segmentor::Segmentor {

// don't need to allocate a decoder
// one sentence, one decoder
baseAll = new ltp::segmentor::rulebase::RuleBase(model->labels);
baseAll = new seg::rulebase::RuleBase(model->labels);

beg_tag0 = model->labels.index( ltp::segmentor::__b__ );
beg_tag1 = model->labels.index( ltp::segmentor::__s__ );
Expand All @@ -56,9 +58,9 @@ class SegmentorWrapper : public ltp::segmentor::Segmentor {

int segment(const char * str,
std::vector<std::string> & words) {
ltp::segmentor::Instance * inst = new ltp::segmentor::Instance;
seg::Instance * inst = new seg::Instance;
// ltp::strutils::codecs::decode(str, inst->forms);
int ret = ltp::segmentor::rulebase::preprocess(str,
int ret = seg::rulebase::preprocess(str,
inst->raw_forms,
inst->forms,
inst->chartypes);
Expand All @@ -69,21 +71,19 @@ class SegmentorWrapper : public ltp::segmentor::Segmentor {
return 0;
}

ltp::segmentor::Segmentor::extract_features(inst);
ltp::segmentor::Segmentor::calculate_scores(inst, true);
seg::DecodeContext* ctx = new seg::DecodeContext;
seg::Segmentor::extract_features(inst, ctx);
seg::Segmentor::calculate_scores(inst, ctx, true);

// allocate a new decoder so that the segmentor support multithreaded
// decoding. this modification was committed by niuox
ltp::segmentor::Decoder deco(model->num_labels(), *baseAll);
seg::Decoder decoder(model->num_labels(), *baseAll);

deco.decode(inst);
ltp::segmentor::Segmentor::build_words(inst,
inst->predicted_tagsidx,
words,
beg_tag0,
beg_tag1);
decoder.decode(inst);
seg::Segmentor::build_words(inst, inst->predicted_tagsidx,
words, beg_tag0, beg_tag1);

ltp::segmentor::Segmentor::cleanup_decode_context();
delete ctx;
delete inst;
return words.size();
}
Expand Down
118 changes: 53 additions & 65 deletions src/segmentor/segmentor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Segmentor::Segmentor()
test_opt(0),
dump_opt(0),
decoder(0),
decode_context(0),
baseAll(0),
__TRAIN__(false),
__TEST__(false),
Expand All @@ -39,6 +40,7 @@ Segmentor::Segmentor()
Segmentor::Segmentor(utils::ConfigParser & cfg) :
model(0),
decoder(0),
decode_context(0),
baseAll(0),
__TRAIN__(false),
__TEST__(false),
Expand Down Expand Up @@ -74,7 +76,9 @@ Segmentor::~Segmentor() {
delete baseAll;
}

cleanup_decode_context();
if (decode_context) {
decode_context->clear();
}
}

void
Expand Down Expand Up @@ -289,31 +293,9 @@ Segmentor::build_configuration(void) {
}

void
Segmentor::cleanup_decode_context(void) {
if (uni_features.total_size() > 0) {
int d1 = uni_features.nrows();
int d2 = uni_features.ncols();

for (int i = 0; i < d1; ++ i) {
if (uni_features[i][0]) {
uni_features[i][0]->clear();
}

for (int j = 0; j < d2; ++ j) {
if (uni_features[i][j]) {
delete uni_features[i][j];
}
}
}
}

uni_features.dealloc();
correct_features.zero();
predicted_features.zero();
}

void
Segmentor::extract_features(Instance * inst, bool create) {
Segmentor::extract_features(Instance * inst,
DecodeContext* ctx,
bool create) {
const int N = Extractor::num_templates();
const int L = model->num_labels();

Expand All @@ -324,8 +306,8 @@ Segmentor::extract_features(Instance * inst, bool create) {
int len = inst->size();

// allocate the uni_features
uni_features.resize(len, L);
uni_features = 0;
ctx->uni_features.resize(len, L);
ctx->uni_features = 0;

// cache lexicon features.
if (0 == inst->lexicon_match_state.size()) {
Expand Down Expand Up @@ -395,18 +377,18 @@ Segmentor::extract_features(Instance * inst, bool create) {
idx[j] = cache_again[j];
}

uni_features[pos][l] = new math::FeatureVector;
uni_features[pos][l]->n = num_feat;
uni_features[pos][l]->val = 0;
uni_features[pos][l]->loff = 0;
uni_features[pos][l]->idx = idx;
ctx->uni_features[pos][l] = new math::FeatureVector;
ctx->uni_features[pos][l]->n = num_feat;
ctx->uni_features[pos][l]->val = 0;
ctx->uni_features[pos][l]->loff = 0;
ctx->uni_features[pos][l]->idx = idx;

for (l = 1; l < L; ++ l) {
uni_features[pos][l] = new math::FeatureVector;
uni_features[pos][l]->n = num_feat;
uni_features[pos][l]->idx = idx;
uni_features[pos][l]->val = 0;
uni_features[pos][l]->loff = l;
ctx->uni_features[pos][l] = new math::FeatureVector;
ctx->uni_features[pos][l]->n = num_feat;
ctx->uni_features[pos][l]->idx = idx;
ctx->uni_features[pos][l]->val = 0;
ctx->uni_features[pos][l]->loff = l;
}
}
}
Expand Down Expand Up @@ -444,8 +426,8 @@ Segmentor::build_feature_space(void) {
model->space.set_num_labels(L);

for (int i = 0; i < train_dat.size(); ++ i) {
extract_features(train_dat[i], true);
cleanup_decode_context();
extract_features(train_dat[i], decode_context, true);
decode_context->clear();

if ((i + 1) % train_opt->display_interval == 0) {
TRACE_LOG("[%d] instances is extracted.", (i+1));
Expand All @@ -454,7 +436,8 @@ Segmentor::build_feature_space(void) {
}

void
Segmentor::calculate_scores(Instance * inst, bool use_avg) {
Segmentor::calculate_scores(Instance * inst, const DecodeContext* ctx,
bool use_avg) {
int len = inst->size();
int L = model->num_labels();

Expand All @@ -463,12 +446,12 @@ Segmentor::calculate_scores(Instance * inst, bool use_avg) {

for (int i = 0; i < len; ++ i) {
for (int l = 0; l < L; ++ l) {
math::FeatureVector * fv = uni_features[i][l];
math::FeatureVector * fv = ctx->uni_features[i][l];
if (!fv) {
continue;
}

inst->uni_scores[i][l] = model->param.dot(uni_features[i][l], use_avg);
inst->uni_scores[i][l] = model->param.dot(ctx->uni_features[i][l], use_avg);
}
}

Expand All @@ -481,7 +464,7 @@ Segmentor::calculate_scores(Instance * inst, bool use_avg) {
}

void
Segmentor::collect_features(const math::Mat< math::FeatureVector* >& uni_features,
Segmentor::collect_features(const math::Mat< math::FeatureVector* >& features,
Model * model,
Instance * inst,
const std::vector<int> & tagsidx,
Expand All @@ -491,7 +474,7 @@ Segmentor::collect_features(const math::Mat< math::FeatureVector* >& uni_feature
vec.zero();
for (int i = 0; i < len; ++ i) {
int l = tagsidx[i];
const math::FeatureVector * fv = uni_features[i][l];
const math::FeatureVector * fv = features[i][l];

if (!fv) {
continue;
Expand Down Expand Up @@ -619,20 +602,22 @@ Segmentor::erase_rare_features(const int * feature_updated_times) {

void
Segmentor::collect_correct_and_predicted_features(Instance* inst) {
collect_features(uni_features, model, inst, inst->tagsidx, correct_features);
collect_features(uni_features, model, inst, inst->predicted_tagsidx, predicted_features);

updated_features.zero();
updated_features.add(correct_features, 1.);
updated_features.add(predicted_features, -1.);
collect_features(decode_context->uni_features,
model, inst, inst->tagsidx, decode_context->correct_features);
collect_features(decode_context->uni_features,
model, inst, inst->predicted_tagsidx, decode_context->predicted_features);

decode_context->updated_features.zero();
decode_context->updated_features.add(decode_context->correct_features, 1.);
decode_context->updated_features.add(decode_context->predicted_features, -1.);
}

void
Segmentor::train_passive_aggressive(int nr_errors) {
//double error = train_dat[i]->num_errors();
double error = nr_errors;
double score = model->param.dot(updated_features, false);
double norm = updated_features.L2();
double score = model->param.dot(decode_context->updated_features, false);
double norm = decode_context->updated_features.L2();

double step = 0.;
if (norm < EPS) {
Expand All @@ -641,12 +626,12 @@ Segmentor::train_passive_aggressive(int nr_errors) {
step = (error - score) / norm;
}

model->param.add(updated_features, timestamp, step);
model->param.add(decode_context->updated_features, timestamp, step);
}

void
Segmentor::train_averaged_perceptron() {
model->param.add(updated_features, timestamp, 1.);
model->param.add(decode_context->updated_features, timestamp, 1.);
}

bool
Expand Down Expand Up @@ -717,6 +702,7 @@ Segmentor::train(void) {
} else {
// use pa or average perceptron algorithm
rulebase::RuleBase base(model->labels);
decode_context = new DecodeContext;
decoder = new Decoder(model->num_labels(), base);
TRACE_LOG("Allocated plain decoder");

Expand All @@ -732,14 +718,14 @@ Segmentor::train(void) {
set_timestamp(iter * train_dat.size() + i + 1);

Instance * inst = train_dat[i];
extract_features(inst);
calculate_scores(inst, false);
extract_features(inst, decode_context);
calculate_scores(inst, decode_context, false);
decoder->decode(inst);

collect_correct_and_predicted_features(inst);

if (feature_group_updated_time) {
increase_group_updated_time(updated_features,
increase_group_updated_time(decode_context->updated_features,
feature_group_updated_time);
}

Expand All @@ -749,7 +735,7 @@ Segmentor::train(void) {
train_averaged_perceptron();
}

cleanup_decode_context();
decode_context->clear();

if ((i+1) % train_opt->display_interval == 0) {
TRACE_LOG("[%d] instances is trained.", i+1);
Expand Down Expand Up @@ -821,17 +807,19 @@ Segmentor::evaluate(double &p, double &r, double &f) {
int beg_tag0 = model->labels.index( __b__ );
int beg_tag1 = model->labels.index( __s__ );

decode_context = new DecodeContext;

while ((inst = reader.next())) {
int len = inst->size();
inst->tagsidx.resize(len);
for (int i = 0; i < len; ++ i) {
inst->tagsidx[i] = model->labels.index(inst->tags[i]);
}

extract_features(inst);
calculate_scores(inst, true);
extract_features(inst, decode_context);
calculate_scores(inst, decode_context, true);
decoder->decode(inst);
cleanup_decode_context();
decode_context->clear();

if (inst->words.size() == 0) {
build_words(inst, inst->tagsidx, inst->words, beg_tag0, beg_tag1);
Expand Down Expand Up @@ -931,10 +919,10 @@ Segmentor::test(void) {
int len = inst->size();
inst->tagsidx.resize(len);

extract_features(inst);
calculate_scores(inst, true);
extract_features(inst, decode_context);
calculate_scores(inst, decode_context, true);
decoder->decode(inst);
cleanup_decode_context();
decode_context->clear();

build_words(inst,
inst->predicted_tagsidx,
Expand Down
Loading

0 comments on commit 9bc1a9c

Please sign in to comment.