forked from HIT-SCIR/ltp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
postag_dll.cpp
105 lines (87 loc) · 2.73 KB
/
postag_dll.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include "postagger/postag_dll.h"
#include "postagger/postagger.h"
#include "postagger/settings.h"
#include "postagger/decoder.h"
#include "postagger/extractor.h"
#include "utils/logging.hpp"
#include "utils/codecs.hpp"
#include "utils/sbcdbc.hpp"
#include "utils/tinybitset.hpp"
#include <iostream>
#include <fstream>
/**
 * Wrapper used by the C-style DLL entry points in this file.
 *
 * Extends ltp::postagger::Postagger with model/lexicon loading from file
 * paths and a single-call postag() routine that runs feature extraction,
 * scoring and decoding for one sentence.
 */
class __ltp_dll_postagger_wrapper : public ltp::postagger::Postagger {
private:
  // Optional external lexicon; only consulted by postag() when
  // lex.success() reports true.
  ltp::postagger::PostaggerLexicon lex;

public:
  __ltp_dll_postagger_wrapper() {}
  ~__ltp_dll_postagger_wrapper() {}

  /**
   * Load the tagging model and, optionally, an external lexicon.
   *
   * @param model_file    Path to the binary model file. A NULL path is
   *                      rejected.
   * @param lexicon_file  Optional path to a lexicon file; may be NULL. A
   *                      lexicon that cannot be opened is silently skipped.
   * @return true when the model was loaded, false otherwise.
   */
  bool load(const char* model_file, const char* lexicon_file = NULL) {
    // Constructing a stream from a NULL path is not safe, so refuse early.
    if (NULL == model_file) {
      return false;
    }

    std::ifstream mfs(model_file, std::ifstream::binary);
    if (!mfs) {
      return false;
    }

    model = new ltp::framework::Model(ltp::postagger::Extractor::num_templates());
    if (!model->load(ltp::postagger::Postagger::model_header, mfs)) {
      delete model;
      // Clear the member so no later code sees a dangling pointer.
      // NOTE(review): the base-class destructor presumably releases `model`;
      // without this reset that would be a double delete -- confirm.
      model = 0;
      return false;
    }

    // The lexicon is optional; the explicit NULL check is required
    // (original note: MSVC needs it).
    if (NULL != lexicon_file) {
      std::ifstream lfs(lexicon_file);
      if (lfs.good()) {
        lex.load(lfs, model->labels);
      }
    }

    return true;
  }

  /**
   * Tag one sentence.
   *
   * @param words  The words of the sentence.
   * @param tags   Output; filled by Postagger::build_labels().
   * @return tags.size() after labelling, converted to int.
   */
  int postag(const std::vector<std::string> & words,
      std::vector<std::string> & tags) {
    ltp::framework::ViterbiFeatureContext ctx;
    ltp::framework::ViterbiScoreMatrix scm;
    ltp::framework::ViterbiDecoder decoder;
    ltp::postagger::Instance inst;

    // Each form is produced from the matching input word by sbc2dbc_x.
    // NOTE(review): presumably a full-width -> half-width normalisation;
    // confirm against utils/sbcdbc.hpp.
    inst.forms.resize(words.size());
    for (size_t i = 0; i < words.size(); ++ i) {
      ltp::strutils::chartypes::sbc2dbc_x(words[i], inst.forms[i]);
    }

    extract_features(inst, &ctx, false);
    calculate_scores(inst, ctx, true, &scm);

    if (lex.success()) {
      // The constraint is built from the caller's original words, not from
      // the converted forms.
      ltp::postagger::PostaggerLexiconConstrain con = lex.get_con(words);
      decoder.decode(scm, con, inst.predict_tagsidx);
    } else {
      decoder.decode(scm, inst.predict_tagsidx);
    }

    ltp::postagger::Postagger::build_labels(inst, tags);
    // Explicit cast: the interface returns int while size() is unsigned.
    return static_cast<int>(tags.size());
  }
};
/**
 * Create a postagger and load its model (and optional lexicon).
 *
 * @param path          Path to the model file. A NULL path yields a NULL
 *                      handle instead of being passed on to the file stream.
 * @param lexicon_file  Optional lexicon path, forwarded to the wrapper's
 *                      load(); may be NULL.
 * @return An opaque handle to the wrapper on success, or 0 when the path is
 *         NULL or loading fails. Release it with
 *         postagger_release_postagger().
 */
void * postagger_create_postagger(const char* path, const char* lexicon_file) {
  // Reject a missing path before allocating anything.
  if (!path) {
    return 0;
  }

  __ltp_dll_postagger_wrapper* wrapper = new __ltp_dll_postagger_wrapper();

  if (!wrapper->load(path, lexicon_file)) {
    delete wrapper;
    return 0;
  }

  return reinterpret_cast<void *>(wrapper);
}
int postagger_release_postagger(void * postagger) {
if (!postagger) {
return -1;
}
delete reinterpret_cast<__ltp_dll_postagger_wrapper*>(postagger);
return 0;
}
/**
 * Tag a sentence through an opaque postagger handle.
 *
 * @param postagger  Opaque handle; it is cast to the wrapper type.
 * @param words      The words of the sentence.
 * @param tags       Output vector handed to the wrapper's postag(). It is
 *                   left untouched on every early return below.
 * @return The wrapper's postag() result, or 0 when the handle is NULL, the
 *         word list is empty, or any word is an empty string.
 */
int postagger_postag(void * postagger,
    const std::vector<std::string> & words,
    std::vector<std::string> & tags) {
  // A NULL handle cannot be dereferenced; report that nothing was tagged.
  if (!postagger) {
    return 0;
  }

  if (words.empty()) {
    return 0;
  }

  // Refuse the whole sentence if any word is empty. The index is size_t so
  // it matches the unsigned type of words.size().
  for (size_t i = 0; i < words.size(); ++ i) {
    if (words[i].empty()) {
      return 0;
    }
  }

  __ltp_dll_postagger_wrapper* wrapper =
    reinterpret_cast<__ltp_dll_postagger_wrapper*>(postagger);
  return wrapper->postag(words, tags);
}