Skip to content

Commit

Permalink
fixed the bug in SRL training suite
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangfeng1124 committed Jun 13, 2014
1 parent cdf5044 commit ce32452
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/srl/lgsrl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ bool parse_cfg(ConfigParser & cfg)
return false;
}

if (cfg.get("train-srl", "dst-config-dir", strbuf)) {
train_opt.dst_config_dir = strbuf;
} else {
ERROR_LOG("[SRL] dst_config_dir config item is not found");
return false;
}

if (cfg.get_integer("train-srl", "solver-type", intbuf)) {
switch (intbuf) {
case 0: me_srl_param.solver_type = L1_OWLQN; break;
Expand Down Expand Up @@ -173,6 +180,13 @@ bool parse_cfg(ConfigParser & cfg)
return false;
}

if (cfg.get("train-prg", "dst-config-dir", strbuf)) {
train_opt.dst_config_dir = strbuf;
} else {
ERROR_LOG("[PRG] dst_config_dir config item is not found");
return false;
}

if (cfg.get_integer("train-prg", "solver-type", intbuf)) {
switch (intbuf) {
case 0: me_prg_param.solver_type = L1_OWLQN; break;
Expand Down Expand Up @@ -240,6 +254,28 @@ bool parse_cfg(ConfigParser & cfg)
return true;
}

bool copy_cfg(const string & src_cfg,
const string & dst_cfg)
{
ifstream fsrc(src_cfg.c_str());
ofstream fdst(dst_cfg.c_str());

if (!fdst)
{
ERROR_LOG("Cannot open [%s]", dst_cfg.c_str());
return false;
}

string line;
while (getline(fsrc, line))
fdst << line << endl;

fsrc.close();
fdst.close();

return true;
}

bool collect_prg_instances()
{
Configuration configuration(train_opt.core_config);
Expand Down Expand Up @@ -673,6 +709,8 @@ int main(int argc, char *argv[])
train(prg_model,
train_opt.prg_instance_file,
train_opt.prg_model_file);
copy_cfg(train_opt.core_config,
train_opt.dst_config_dir + "/Chinese.xml");
}

if (__TRAIN_SRL__) {
Expand All @@ -689,6 +727,8 @@ int main(int argc, char *argv[])
train(srl_model,
train_opt.srl_instance_file,
train_opt.srl_model_file);
copy_cfg(train_opt.srl_config,
train_opt.dst_config_dir + "/srl.cfg");
}

if (__TEST__) {
Expand Down
1 change: 1 addition & 0 deletions src/srl/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct TrainOptions {
std::string prg_instance_file;
std::string srl_model_file;
std::string prg_model_file;
std::string dst_config_dir; // destination cfgs
};

struct TestOptions {
Expand Down
1 change: 1 addition & 0 deletions tools/train/conf/srl/srl-prg.cnf
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ prg-train-file = sample/srl/example-train.srl
core-config = conf/srl/assets/Chinese.xml
prg-instance-file = build/srl/prg-instances.train/train.inst
prg-model-file = build/srl/prg.model
dst-config-dir = build/srl/
solver-type = 0 # L1-owlqn
#solver-type = 1 # L1-sgd
#solver-type = 2 # L2-lbfgs
Expand Down
1 change: 1 addition & 0 deletions tools/train/conf/srl/srl-srl.cnf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ srl-config = conf/srl/assets/srl.cfg
srl-feature-dir = build/srl/srl-features.train
srl-instance-file = build/srl/srl-instances.train/train.inst
srl-model-file = build/srl/srl.model
dst-config-dir = build/srl/
solver-type = 0 # L1-owlqn
#solver-type = 1 # L1-sgd
#solver-type = 2 # L2-lbfgs
Expand Down

0 comments on commit ce32452

Please sign in to comment.