Skip to content

Commit

Permalink
[fix] model loading to be compatible with model lower than 3.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Oneplus committed Nov 10, 2014
1 parent 250485a commit 475727a
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 31 deletions.
8 changes: 4 additions & 4 deletions src/segmentor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ set_target_properties (otcws
# -----------------------------------------------
# TOOLKIT #2
# -----------------------------------------------
add_executable (otcws-inc otcws_inc.cpp ${segment_SRC} customized_segmentor.cpp)
add_executable (otcws-customized otcws_customized.cpp ${segment_SRC} customized_segmentor.cpp)

target_link_libraries (otcws-inc boost_regex_static_lib)
target_link_libraries (otcws-customized boost_regex_static_lib)

set_target_properties (otcws-inc
set_target_properties (otcws-customized
PROPERTIES
OUTPUT_NAME otcws-inc
OUTPUT_NAME otcws-customized
RUNTIME_OUTPUT_DIRECTORY ${TOOLS_DIR}/train/)


Expand Down
Empty file modified src/segmentor/decoder.h
100755 → 100644
Empty file.
41 changes: 33 additions & 8 deletions src/segmentor/model.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "model.h"
#include <cstring>

#define SEGMENTOR_MODEL "otcws" // for model version lower than 3.2.0
#define SEGMENTOR_MODEL_FULL "otcws-full"
#define SEGMENTOR_MODEL_MINIMAL "otcws-minimal"

namespace ltp {
namespace segmentor {
Expand All @@ -10,11 +15,20 @@ Model::~Model() {
}

void
Model::save(std::ostream & ofs) {
Model::save(std::ostream & ofs, bool full) {
// write a signature into the file
char chunk[16] = {'o','t','c','w','s', '\0'};
char chunk[16];
if (full) {
strncpy(chunk, SEGMENTOR_MODEL_FULL, 16);
} else {
strncpy(chunk, SEGMENTOR_MODEL_MINIMAL, 16);
}

ofs.write(chunk, 16);
ofs.write(reinterpret_cast<const char *>(&end_time), sizeof(int));

if (full) {
ofs.write(reinterpret_cast<const char *>(&end_time), sizeof(int));
}

int off = ofs.tellp();

Expand All @@ -38,7 +52,7 @@ Model::save(std::ostream & ofs) {
space.dump(ofs);

parameter_offset = ofs.tellp();
param.dump(ofs);
param.dump(ofs, full);

ofs.seekp(off);
write_uint(ofs, labels_offset);
Expand All @@ -52,12 +66,23 @@ Model::load(std::istream & ifs) {
char chunk[16];
ifs.read(chunk, 16);

if (strcmp(chunk, "otcws")) {
bool full = false;
if (!strcmp(chunk, SEGMENTOR_MODEL_FULL)) {
full = true;
} else if (!strcmp(chunk, SEGMENTOR_MODEL) ||
!strcmp(chunk, SEGMENTOR_MODEL_MINIMAL)) {
full = false;
} else {
return false;
}
ifs.read(reinterpret_cast<char *>(&end_time), sizeof(int));

unsigned labels_offset = read_uint(ifs);
if (full) {
ifs.read(reinterpret_cast<char *>(&end_time), sizeof(int));
} else {
end_time = 0;
}

unsigned labels_offset = read_uint(ifs);
unsigned lexicon_offset = read_uint(ifs);
unsigned feature_offset = read_uint(ifs);
unsigned parameter_offset = read_uint(ifs);
Expand All @@ -78,7 +103,7 @@ Model::load(std::istream & ifs) {
}

ifs.seekg(parameter_offset);
if (!param.load(ifs)) {
if (!param.load(ifs, full)) {
return false;
}

Expand Down
3 changes: 2 additions & 1 deletion src/segmentor/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class Model {
* save the model to a output stream
*
* @param[out] ofs the output stream
* @param[in] full use to specify if dump full model.
*/
void save(std::ostream & ofs);
void save(std::ostream & ofs, bool full);

/**
* load the model from an input stream
Expand Down
1 change: 1 addition & 0 deletions src/segmentor/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct TrainOptions {
int max_iter;
int display_interval;
int rare_feature_threshold;
bool enable_incremental_training;
};

struct TestOptions {
Expand Down
File renamed without changes.
43 changes: 34 additions & 9 deletions src/segmentor/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
#define __LTP_SEGMENTOR_PARAMETER_H__

#include <iostream>
#include <cstring>
#include "utils/math/sparsevec.h"
#include "utils/math/featurevec.h"

#define SEGMENTOR_PARAM "param" // for model version lower than 3.2.0
#define SEGMENTOR_PARAM_FULL "param-full"
#define SEGMENTOR_PARAM_MINIMAL "param-minimal"

namespace ltp {
namespace segmentor {

Expand Down Expand Up @@ -142,29 +147,49 @@ class Parameters {
}
}

void dump(std::ostream & out) {
char chunk[16] = {'p', 'a', 'r', 'a', 'm', 0};
//! Dump the model. since version 3.2.0, fully dumped model is supported.
//! using a tag full to distinguish between old and new model.
void dump(std::ostream & out, bool full) {
char chunk[16];

if (full) {
strncpy(chunk, SEGMENTOR_PARAM_FULL,16);
} else {
strncpy(chunk, SEGMENTOR_PARAM_MINIMAL, 16);
}

out.write(chunk, 16);
out.write(reinterpret_cast<const char *>(&_dim), sizeof(int));
if (_dim > 0) {
out.write(reinterpret_cast<const char *>(_W), sizeof(double) * _dim);
if (full) {
out.write(reinterpret_cast<const char *>(_W), sizeof(double) * _dim);
}
out.write(reinterpret_cast<const char *>(_W_sum), sizeof(double) * _dim);
}
}

bool load(std::istream & in) {
bool load(std::istream & in, bool full) {
char chunk[16];

in.read(chunk, 16);
if (strcmp(chunk, "param")) {
if ((!strcmp(chunk, SEGMENTOR_PARAM_FULL) && full) ||
(!strcmp(chunk, SEGMENTOR_PARAM_MINIMAL) && !full)) {
return false;
}

in.read(reinterpret_cast<char *>(&_dim), sizeof(int));

if (_dim > 0) {
_W = new double[_dim];
_W_sum = new double[_dim];
in.read(reinterpret_cast<char *>(_W), sizeof(double) * _dim);
in.read(reinterpret_cast<char *>(_W_sum), sizeof(double) * _dim);
if (full) {
_W = new double[_dim];
_W_sum = new double[_dim];
in.read(reinterpret_cast<char *>(_W), sizeof(double) * _dim);
in.read(reinterpret_cast<char *>(_W_sum), sizeof(double) * _dim);
} else {
_W = new double[_dim];
in.read(reinterpret_cast<char *>(_W), sizeof(double) * _dim);
_W_sum = _W;
}
}

return true;
Expand Down
7 changes: 6 additions & 1 deletion src/segmentor/segmentor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Segmentor::parse_cfg(utils::ConfigParser & cfg) {
train_opt->max_iter = 10;
train_opt->display_interval = 5000;
train_opt->rare_feature_threshold = 0;
train_opt->enable_incremental_training = 0;

if (cfg.has_section("train")) {
int intbuf;
Expand Down Expand Up @@ -153,6 +154,10 @@ Segmentor::parse_cfg(utils::ConfigParser & cfg) {
} else {
WARNING_LOG("max-iter is not configed, [10] is set as default.");
}

if (cfg.get_integer("train", "enable-incremental-training", intbuf)) {
train_opt->enable_incremental_training = (intbuf == 1);
}
}

test_opt->test_file = "";
Expand Down Expand Up @@ -771,7 +776,7 @@ Segmentor::train(void) {
std::ofstream ofs(saved_model_file.c_str(), std::ofstream::binary);

std::swap(model, new_model);
new_model->save(ofs);
new_model->save(ofs, train_opt->enable_incremental_training);
delete new_model;
TRACE_LOG("Model for iteration [%d] is saved to [%s]",
iter + 1,
Expand Down
File renamed without changes.
8 changes: 8 additions & 0 deletions tools/train/conf/cws/customized-cws.cnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[train]
train-file = sample/seg/example-train-2.seg
holdout-file = sample/seg/example-holdout.seg
algorithm = pa
model-name = build/cws/example-seg-customized
baseline-model-name = build/cws/example-seg.0.model
max-iter = 5
rare-feature-threshold = 0
8 changes: 0 additions & 8 deletions tools/train/conf/cws/personal.cnf

This file was deleted.

0 comments on commit 475727a

Please sign in to comment.