Skip to content

Commit

Permalink
merge multithreaded parser modify by b.han
Browse files Browse the repository at this point in the history
  • Loading branch information
Oneplus committed Sep 27, 2013
1 parent bb9ca7c commit 26f8a1e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 32 deletions.
18 changes: 10 additions & 8 deletions src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,28 +435,30 @@ bool Parser::read_instances(const char * filename, vector<Instance *> & dat) {
return true;
}

void Parser::build_decoder(void) {
Decoder * Parser::build_decoder(void) {
Decoder * deco;
if (model_opt.decoder_name == "1o") {
if (!model_opt.labeled) {
decoder = new Decoder1O();
deco = new Decoder1O();
} else {
decoder = new Decoder1O(model->num_deprels());
deco = new Decoder1O(model->num_deprels());
}

} else if (model_opt.decoder_name == "2o-sib") {
if (!model_opt.labeled) {
decoder = new Decoder2O();
deco = new Decoder2O();
} else {
decoder = new Decoder2O(model->num_deprels());
deco = new Decoder2O(model->num_deprels());
}

} else if (model_opt.decoder_name == "2o-carreras") {
if (!model_opt.labeled) {
decoder = new Decoder2OCarreras();
deco = new Decoder2OCarreras();
} else {
decoder = new Decoder2OCarreras(model->num_deprels());
deco = new Decoder2OCarreras(model->num_deprels());
}
}
return deco;
}


Expand Down Expand Up @@ -765,7 +767,7 @@ void Parser::train(void) {
model->param.realloc(model->dim());
TRACE_LOG("Allocate a parameter vector of [%d] dimension.", model->dim());

build_decoder();
decoder= build_decoder();

for (int iter = 0; iter < train_opt.max_iter; ++ iter) {
TRACE_LOG("Start training epoch #%d.", (iter + 1));
Expand Down
3 changes: 1 addition & 2 deletions src/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ class Parser{
bool gold = false);

protected:
void build_decoder(void);

Decoder * build_decoder(void);
void extract_features(Instance * inst);

void calculate_score(Instance * inst, const Parameters& param, bool use_avg = false);
Expand Down
7 changes: 5 additions & 2 deletions src/parser/parser_dll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ParserWrapper : public ltp::parser::Parser {
return false;
}

ltp::parser::Parser::build_decoder();
// ltp::parser::Parser::build_decoder();

return true;
}
Expand All @@ -50,7 +50,9 @@ class ParserWrapper : public ltp::parser::Parser {

ltp::parser::Parser::extract_features(inst);
ltp::parser::Parser::calculate_score(inst, ltp::parser::Parser::model->param);
decoder->decode(inst);
ltp::parser::Decoder * deco;
deco=build_decoder();
deco->decode(inst);

int len = inst->size();
heads.resize(len - 1);
Expand All @@ -62,6 +64,7 @@ class ParserWrapper : public ltp::parser::Parser {
}

delete inst;
delete deco;
return inst->size();
}
};
Expand Down
34 changes: 21 additions & 13 deletions src/srl/DepSRL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
// Load necessary resources into memory
int DepSRL::LoadResource(const string &ConfigDir)
{
string configXml = ConfigDir + "/Chinese.xml";
string selectFeats = ConfigDir + "/srl.cfg";

m_srlBaseline = new SRLBaselineExt(configXml, selectFeats);

m_configXml = ConfigDir + "/Chinese.xml";
m_selectFeats = ConfigDir + "/srl.cfg";
// load srl and prg model
m_srlModel = new maxent::ME_Model;
m_srlModel->load(ConfigDir + "/srl.model");
Expand All @@ -44,7 +41,14 @@ int DepSRL::ReleaseResource()

return 1;
}

string DepSRL::GetConfigXml()
{
return m_configXml;
}
string DepSRL::GetSelectFeats()
{
return m_selectFeats;
}
int DepSRL::GetSRLResult(
const vector<string> &words,
const vector<string> &POSs,
Expand All @@ -63,15 +67,16 @@ int DepSRL::GetSRLResult(
// construct a DataPreProcess instance
DataPreProcess* dataPreProc = new DataPreProcess(&ltpData);

SRLBaselineExt * m_srlBaseline=new SRLBaselineExt(GetConfigXml(),GetSelectFeats());
// extract features !
m_srlBaseline->setDataPreProc(dataPreProc);

// GetPredicateFromSentence(POSs,predicates);
vector<int> predicates;
GetPredicateFromSentence(predicates);
GetPredicateFromSentence(predicates,m_srlBaseline);

// return GetSRLResult(words, POSs, NEs, parse, predicates, vecSRLResult);
return GetSRLResult(ltpData, predicates, vecSRLResult);
return GetSRLResult(ltpData, predicates, vecSRLResult,m_srlBaseline);
}

// produce DepSRL result for a sentence
Expand Down Expand Up @@ -101,7 +106,8 @@ int DepSRL::GetSRLResult(
int DepSRL::GetSRLResult(
const LTPData &ltpData,
const vector<int> &predicates,
vector< pair< int, vector< pair< string, pair< int, int > > > > > &vecSRLResult
vector< pair< int, vector< pair< string, pair< int, int > > > > > &vecSRLResult,
SRLBaselineExt * m_srlBaseline
)
{
vecSRLResult.clear();
Expand All @@ -122,7 +128,7 @@ int DepSRL::GetSRLResult(
vector< vector< pair<string, double> > > vecAllPairNextArgs;

// extract features
if (!ExtractSrlFeatures(ltpData, predicates,vecAllFeatures,vecAllPos))
if (!ExtractSrlFeatures(ltpData, predicates,vecAllFeatures,vecAllPos,m_srlBaseline))
return 0;

// predict
Expand All @@ -139,6 +145,7 @@ int DepSRL::GetSRLResult(

// rename arguments to short forms (ARGXYZ->AXYZ)
if (!RenameArguments(vecSRLResult)) return 0;
delete m_srlBaseline;

return 1;
}
Expand All @@ -147,7 +154,8 @@ int DepSRL::ExtractSrlFeatures(
const LTPData &ltpData,
const vector<int> &VecAllPredicates,
VecFeatForSent &vecAllFeatures,
VecPosForSent &vecAllPos
VecPosForSent &vecAllPos,
SRLBaselineExt* m_srlBaseline
)
{
vecAllFeatures.clear();
Expand Down Expand Up @@ -467,7 +475,7 @@ void DepSRL::GetParAndRel(const vector< pair<int, string> >& vecParser,
}

void DepSRL::GetPredicateFromSentence(const vector<string>& vecPos,
vector<int>& vecPredicate) const
vector<int>& vecPredicate,SRLBaselineExt* m_srlBaseline) const
{
int index;
vector<string>::const_iterator itPos;
Expand All @@ -485,7 +493,7 @@ void DepSRL::GetPredicateFromSentence(const vector<string>& vecPos,
}
}

void DepSRL::GetPredicateFromSentence(vector<int>& vecPredicate) const
void DepSRL::GetPredicateFromSentence(vector<int>& vecPredicate,SRLBaselineExt * m_srlBaseline) const
{
/* extract features for each word in sentence */
vector< vector<string> > vecFeatures;
Expand Down
17 changes: 10 additions & 7 deletions src/srl/DepSRL.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,20 @@ class DepSRL {
int GetSRLResult(
const LTPData &ltpData,
const vector<int> &predicates,
vector< pair< int, vector< pair< string, pair< int, int > > > > > &vecSRLResult
vector< pair< int, vector< pair< string, pair< int, int > > > > > &vecSRLResult,
SRLBaselineExt * m_srlBaseline
);

string GetConfigXml();
string GetSelectFeats();
private:
/* 1.Extract SRL Features from input
*/
int ExtractSrlFeatures(
const LTPData &ltpData,
const vector<int> &VecAllPredicates,
VecFeatForSent &vecAllFeatures,
VecPosForSent &vecAllPos
VecPosForSent &vecAllPos,
SRLBaselineExt* m_srlBaseline
);

/* 2.Predict with the maxent library
Expand Down Expand Up @@ -119,11 +122,11 @@ class DepSRL {
*/
void GetPredicateFromSentence(
const vector<string>& vecPos,
vector<int>& vecPredicate) const;
vector<int>& vecPredicate,SRLBaselineExt * m_srlBaseline) const;

/* Version 2: find predicates according to a MaxEnt model
*/
void GetPredicateFromSentence(vector<int>& vecPredicate) const;
void GetPredicateFromSentence(vector<int>& vecPredicate,SRLBaselineExt * m_srlBaselie) const;

void ProcessOnePredicate(
const vector<string>& vecWords,
Expand Down Expand Up @@ -240,8 +243,8 @@ class DepSRL {

private:
bool m_resourceLoaded;
SRLBaselineExt* m_srlBaseline;

string m_configXml;
string m_selectFeats;
maxent::ME_Model *m_srlModel; // for role labeling
maxent::ME_Model *m_prgModel; // for predicate recognition
};
Expand Down

0 comments on commit 26f8a1e

Please sign in to comment.