Skip to content

Commit

Permalink
Merge pull request HIT-SCIR#218 from liu946/develop
Browse files Browse the repository at this point in the history
BiLSTM SRL module
fix compiling error with msvc
  • Loading branch information
endyul authored Jun 6, 2017
2 parents 4352866 + a247098 commit 65b0adc
Show file tree
Hide file tree
Showing 1,375 changed files with 266,309 additions and 68,210 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ build
###############
# output #
###############
#include/
include/
lib/
bin/
tools/train/lgdpj
tools/train/lgsrl
tools/train/srl*
tools/train/otcws
tools/train/otpos
tools/train/otner
Expand All @@ -39,6 +39,7 @@ tools/train/Debug/
###############
new_ltp_data/
ltp_data/
ltp_data

##################
# running folder #
Expand Down
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ if (APPLE)
add_definitions(-DGTEST_HAS_TR1_TUPLE=0)
set(CMAKE_CXX_FLAGS "-std=c++0x -Wno-c++11-narrowing")
elseif(UNIX)
set(CMAKE_CXX_FLAGS "-std=c++0x")
set(CMAKE_CXX_FLAGS "-std=c++0x -fPIC")
elseif(MINGW)
set(CMAKE_CXX_FLAGS "-std=c++0x")
elseif(MSVC)
add_definitions(-D_WINDOWS) # make dynet happy at `dynet/mem.cc(7)`
add_definitions(-DBOOST_ALL_NO_LIB) # disable boost auto-linking on windows
set(CMAKE_CXX_FLAGS "/EHsc")
endif(APPLE)

Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set (srl_DIR ${SOURCE_DIR}/srl/)
set (ltp_DIR ${SOURCE_DIR}/ltp/)
set (server_DIR ${SOURCE_DIR}/server/)


add_subdirectory ("xml4nlp")
add_subdirectory ("splitsnt")
add_subdirectory ("segmentor")
Expand Down
1 change: 0 additions & 1 deletion src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#define LTP_VERSION "3.3.2"
#define LTP_COPYRIGHT "(C) 2012-2016 HIT-SCIR"

#define BOOST_ALL_NO_LIB

namespace ltp {

Expand Down
19 changes: 18 additions & 1 deletion src/console/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ include_directories (./

set (ltp_test_SRC ltp_test.cpp ${THIRDPARTY_DIR}/tinythreadpp/tinythread.cpp)

# look for Boost
#if(DEFINED ENV{BOOST_ROOT})
# set(Boost_NO_SYSTEM_PATHS ON)
#endif()
#set(Boost_REALPATH ON)
#find_package(Boost COMPONENTS program_options serialization REQUIRED)
#include_directories(${Boost_INCLUDE_DIR})
#set(LIBS ${LIBS} ${Boost_LIBRARIES})

link_directories ( ${LIBRARY_OUTPUT_PATH} )
add_executable (ltp_test ${ltp_test_SRC})
target_link_libraries (ltp_test
Expand All @@ -22,7 +31,9 @@ target_link_libraries (ltp_test
xml4nlp
boost_regex_static_lib
boost_program_options_static_lib
jsoncpp)
boost_serialization_static_lib
dynet
jsoncpp_lib_static)

add_executable (cws_cmdline cws_cmdline.cpp
${THIRDPARTY_DIR}/tinythreadpp/tinythread.cpp)
Expand Down Expand Up @@ -53,6 +64,12 @@ target_link_libraries (ner_cmdline ner_static_lib
set_target_properties (ner_cmdline PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${EXECUTABLE_OUTPUT_PATH}/examples/)

add_executable (srl_cmdline srl_cmdline.cpp
${THIRDPARTY_DIR}/tinythreadpp/tinythread.cpp)
target_link_libraries (srl_cmdline srl_static_lib
boost_program_options_static_lib)
set_target_properties (srl_cmdline PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${EXECUTABLE_OUTPUT_PATH}/examples/)

if (NOT MSVC AND NOT MINGW)
target_link_libraries (ltp_test pthread)
Expand Down
18 changes: 18 additions & 0 deletions src/console/dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include "tinythread.h"

using namespace std;

class Dispatcher {
public:
Dispatcher(void* engine, std::istream& is, std::ostream& os):
Expand All @@ -24,6 +27,21 @@ class Dispatcher {
return _max_idx ++;
}

int next_block(vector<std::string>& block) {
block.clear();
tthread::lock_guard<tthread::mutex> guard(_mutex);
std::string line;
while (std::getline(_is, line, '\n')) {
if (line != "") {
block.push_back(line);
} else {
return _max_idx ++;
}
}
if (block.size()) return _max_idx++;
return -1;
}

void output(const size_t& idx, const std::string& result) {
tthread::lock_guard<tthread::mutex> guard(_mutex);
if (idx > _idx) {
Expand Down
180 changes: 180 additions & 0 deletions src/console/srl_cmdline.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
//
// Created by yliu on 2017/5/24.
//

#define EXECUTABLE "srl_cmdline"
#define DESCRIPTION "The console application for Semantic Role Labelling."

#include <iostream>
#include <fstream>
#include <list>
#include "config.h"
#include "srl/SRL_DLL.h"
#include "console/dispatcher.h"
#include "boost/program_options.hpp"
#include "utils/strutils.hpp"
#include "utils/time.hpp"

using boost::program_options::options_description;
using boost::program_options::value;
using boost::program_options::variables_map;
using boost::program_options::store;
using boost::program_options::parse_command_line;
using ltp::utility::WallClockTimer;
using ltp::strutils::split;


inline int findRoot(vector<pair<int, string> > & parse, pair<int, int> edge) {
int begin = edge.first;
int end = edge.second;
for (int j = begin; j <= end; ++j) {
if (parse[j].first < begin || parse[j].first > end) {
return j;
}
}
return begin;
}

void multithreaded_srl(void *args) {

Dispatcher * dispatcher = (Dispatcher *)args;

while (true) {
vector<std::string> buffer;
int ret = dispatcher->next_block(buffer);
if (ret < 0)
break;
if (!buffer.size()){
continue;
}

std::vector<std::string> words;
std::vector<std::string> postags;
vector<pair<int, string> > parse;
vector<pair<int, vector<pair<string, pair<int, int> > > > > vecSRLResult;
for (int j = 0; j < buffer.size(); ++j) {
std::stringstream S(buffer[j]);
string str; int parent;
S >> str; words.push_back(str);
S >> str; postags.push_back(str);
S >> parent; S >> str; parse.push_back(make_pair(parent, str));
}

srl_dosrl(words, postags, parse, vecSRLResult);

vector<vector<string> > arg(words.size(), vector<string>(vecSRLResult.size(), "_"));
vector<bool> is_pred(words.size(), false);
for (int k = 0; k < vecSRLResult.size(); ++k) {
is_pred[vecSRLResult[k].first] = true;
for (int j = 0; j < vecSRLResult[k].second.size(); ++j) {
arg[findRoot(parse, vecSRLResult[k].second[j].second)][k] = vecSRLResult[k].second[k].first;
}
}

std::stringstream S; S.clear(); S.str("");
for (size_t i = 0; i < words.size(); ++ i) {
S << i << "\t" << words[i] << "\t" << postags[i] << "\t" << parse[i].first << "\t" << parse[i].second;
S << "\t" << (is_pred[i] ? "Y" : "_");
for (int j = 0; j < arg[i].size(); ++j) {
S << "\t" << arg[i][j];
}

S << std::endl;
}
dispatcher->output(ret, S.str());
}

return;

}

int main(int argc, char ** argv) {
std::string usage = EXECUTABLE " in LTP " LTP_VERSION " - " LTP_COPYRIGHT "\n";
usage += DESCRIPTION "\n\n";
usage += "usage: ./" EXECUTABLE " <options>\n\n";
usage += "options";

options_description optparser = options_description(usage);
optparser.add_options()
("threads", value<int>(), "The number of threads [default=1].")
("input", value<std::string>(), "The path to the input file. "
"Input data should contain one word each line. "
"Sentence should be separated by a blank line. "
"(e.g. \"中国 ns 2 ATT\").")
("pisrl-model", value<std::string>(),
"The path to the pi-srl joint model [default=ltp_data/pos.model].")
("help,h", "Show help information");

if (argc == 1) {
std::cerr << optparser << std::endl;
return 1;
}

variables_map vm;
store(parse_command_line(argc, argv, optparser), vm);

if (vm.count("help")) {
std::cerr << optparser << std::endl;
return 0;
}

int threads = 1;
if (vm.count("threads")) {
threads = vm["threads"].as<int>();
if (threads < 0) {
std::cerr << "number of threads should not less than 0, reset to 1." << std::endl;
threads = 1;
}
}

std::string input = "";
if (vm.count("input")) { input = vm["input"].as<std::string>(); }

std::string srl_model = "ltp_data/pos.model";
if (vm.count("pisrl-model")) {
srl_model = vm["pisrl-model"].as<std::string>();
}

std::string postagger_lexcion = "";
if (vm.count("postagger-lexicon")) {
postagger_lexcion = vm["postagger-lexicon"].as<std::string>();
}

if (srl_load_resource(srl_model)) {
return 1;
}

std::cerr << "TRACE: Model is loaded" << std::endl;
std::cerr << "TRACE: Running " << threads << " thread(s)" << std::endl;

std::ifstream ifs(input.c_str());
std::istream* is = NULL;

if (!ifs.good()) {
std::cerr << "WARN: Cann't open file! use stdin instead." << std::endl;
is = (&std::cin);
} else {
is = (&ifs);
}

Dispatcher * dispatcher = new Dispatcher( NULL, (*is), std::cout );
WallClockTimer t;
std::list<tthread::thread *> thread_list;
for (int i = 0; i < threads; ++ i) {
tthread::thread * t = new tthread::thread( multithreaded_srl, (void *)dispatcher );
thread_list.push_back( t );
}

for (std::list<tthread::thread *>::iterator i = thread_list.begin();
i != thread_list.end(); ++ i) {
tthread::thread * t = *i;
t->join();
delete t;
}

std::cerr << "TRACE: consume " << t.elapsed() << " seconds." << std::endl;
delete dispatcher;
srl_release_resource();
return 0;
}

Loading

0 comments on commit 65b0adc

Please sign in to comment.