diff --git a/src/srl/lgsrl.cpp b/src/srl/lgsrl.cpp index f72907af1..00e7a0e81 100644 --- a/src/srl/lgsrl.cpp +++ b/src/srl/lgsrl.cpp @@ -104,6 +104,13 @@ bool parse_cfg(ConfigParser & cfg) return false; } + if (cfg.get("train-srl", "dst-config-dir", strbuf)) { + train_opt.dst_config_dir = strbuf; + } else { + ERROR_LOG("[SRL] dst_config_dir config item is not found"); + return false; + } + if (cfg.get_integer("train-srl", "solver-type", intbuf)) { switch (intbuf) { case 0: me_srl_param.solver_type = L1_OWLQN; break; @@ -173,6 +180,13 @@ bool parse_cfg(ConfigParser & cfg) return false; } + if (cfg.get("train-prg", "dst-config-dir", strbuf)) { + train_opt.dst_config_dir = strbuf; + } else { + ERROR_LOG("[PRG] dst_config_dir config item is not found"); + return false; + } + if (cfg.get_integer("train-prg", "solver-type", intbuf)) { switch (intbuf) { case 0: me_prg_param.solver_type = L1_OWLQN; break; @@ -240,6 +254,28 @@ bool parse_cfg(ConfigParser & cfg) return true; } +bool copy_cfg(const string & src_cfg, + const string & dst_cfg) +{ + ifstream fsrc(src_cfg.c_str()); + ofstream fdst(dst_cfg.c_str()); + + if (!fdst) + { + ERROR_LOG("Cannot open [%s]", dst_cfg.c_str()); + return false; + } + + string line; + while (getline(fsrc, line)) + fdst << line << endl; + + fsrc.close(); + fdst.close(); + + return true; +} + bool collect_prg_instances() { Configuration configuration(train_opt.core_config); @@ -673,6 +709,8 @@ int main(int argc, char *argv[]) train(prg_model, train_opt.prg_instance_file, train_opt.prg_model_file); + copy_cfg(train_opt.core_config, + train_opt.dst_config_dir + "/Chinese.xml"); } if (__TRAIN_SRL__) { @@ -689,6 +727,8 @@ int main(int argc, char *argv[]) train(srl_model, train_opt.srl_instance_file, train_opt.srl_model_file); + copy_cfg(train_opt.srl_config, + train_opt.dst_config_dir + "/srl.cfg"); } if (__TEST__) { diff --git a/src/srl/options.h b/src/srl/options.h index 5452460db..009515834 100644 --- a/src/srl/options.h +++ b/src/srl/options.h @@ -14,6 +14,7 @@ struct TrainOptions { std::string prg_instance_file; std::string srl_model_file; std::string prg_model_file; + std::string dst_config_dir; // destination cfgs }; struct TestOptions { diff --git a/tools/train/conf/srl/srl-prg.cnf b/tools/train/conf/srl/srl-prg.cnf index 83bae58c9..eb9cd9e23 100644 --- a/tools/train/conf/srl/srl-prg.cnf +++ b/tools/train/conf/srl/srl-prg.cnf @@ -3,6 +3,7 @@ prg-train-file = sample/srl/example-train.srl core-config = conf/srl/assets/Chinese.xml prg-instance-file = build/srl/prg-instances.train/train.inst prg-model-file = build/srl/prg.model +dst-config-dir = build/srl/ solver-type = 0 # L1-owlqn #solver-type = 1 # L1-sgd #solver-type = 2 # L2-lbfgs diff --git a/tools/train/conf/srl/srl-srl.cnf b/tools/train/conf/srl/srl-srl.cnf index 5f543349c..c58aea977 100644 --- a/tools/train/conf/srl/srl-srl.cnf +++ b/tools/train/conf/srl/srl-srl.cnf @@ -5,6 +5,7 @@ srl-config = conf/srl/assets/srl.cfg srl-feature-dir = build/srl/srl-features.train srl-instance-file = build/srl/srl-instances.train/train.inst srl-model-file = build/srl/srl.model +dst-config-dir = build/srl/ solver-type = 0 # L1-owlqn #solver-type = 1 # L1-sgd #solver-type = 2 # L2-lbfgs