forked from kpu/kenlm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththread.hh
167 lines (130 loc) · 4.11 KB
/
thread.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#ifndef LM_FILTER_THREAD_H
#define LM_FILTER_THREAD_H
#include "../../util/thread_pool.hh"
#include <boost/utility/in_place_factory.hpp>
#include <deque>
#include <stack>
namespace lm {
// One unit of work that travels between the threads of Controller: the
// reading thread fills input_, a filter worker runs the filter from input_
// into output_, and the output thread flushes output_ to the real output.
// sequence_ records the order in which the batch was filled so batches can
// be written back in that order.
template <class OutputBuffer> class ThreadBatch {
  public:
    // sequence_ is initialized so Sequence() is well-defined even before the
    // first Fill().
    ThreadBatch() : sequence_(0) {}

    // Pre-size both buffers; `size` is forwarded to InputBuffer::Reserve and
    // OutputBuffer::Reserve.
    void Reserve(size_t size) {
      input_.Reserve(size);
      output_.Reserve(size);
    }

    // File reading thread.  Stamps this batch with `sequence` and returns the
    // cleared input buffer for the caller to fill.
    InputBuffer &Fill(uint64_t sequence) {
      sequence_ = sequence;
      // Clear here rather than right after output so that memory is freed on
      // the same thread that allocated it (the reading thread).
      input_.Clear();
      return input_;
    }

    // Filter worker thread.  Runs `filter` over input_, writing to output_.
    template <class Filter> void CallFilter(Filter &filter) {
      input_.CallFilter(filter, output_);
    }

    // Sequence number assigned by the most recent Fill() (0 before any Fill).
    uint64_t Sequence() const { return sequence_; }

    // File writing thread.  Writes the buffered filter output to `output`.
    template <class RealOutput> void Flush(RealOutput &output) {
      output_.Flush(output);
    }

  private:
    InputBuffer input_;
    OutputBuffer output_;
    uint64_t sequence_;
};
template <class Batch, class Filter> class FilterWorker {
public:
typedef Batch *Request;
FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
void operator()(Request request) {
request->CallFilter(filter_);
done_.Produce(request);
}
private:
Filter filter_;
util::PCQueue<Request> &done_;
};
// There should only be one OutputWorker.
template <class Batch, class Output> class OutputWorker {
public:
typedef Batch *Request;
OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
void operator()(Request request) {
assert(request->Sequence() >= base_sequence_);
// Assemble the output in order.
uint64_t pos = request->Sequence() - base_sequence_;
if (pos >= ordering_.size()) {
ordering_.resize(pos + 1, NULL);
}
ordering_[pos] = request;
while (!ordering_.empty() && ordering_.front()) {
ordering_.front()->Flush(output_);
done_.Produce(ordering_.front());
ordering_.pop_front();
++base_sequence_;
}
}
private:
Output &output_;
util::PCQueue<Request> &done_;
std::deque<Request> ordering_;
uint64_t base_sequence_;
};
// Multithreaded front end for a filter.  The caller feeds n-grams through
// AddNGram from a single thread.  Full batches go to a pool of filter
// workers; a single OutputWorker writes them to the real output in read order.
//
// Batch lifecycle: all `queue` batches start on local_read_, the stack owned
// by the calling thread.  The batch on top of that stack is the one being
// filled (input_ points into it).  A submitted batch travels filter_ ->
// output_ -> to_read_, and MoveRead() pulls it back onto local_read_.
template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
  private:
    typedef ThreadBatch<OutputBuffer> Batch;

  public:
    // batch_size: n-grams per batch.  queue: number of batches in circulation;
    // must be nonzero, because local_read_.top() is used immediately.
    // workers: presumably the filter thread count (second argument to
    // util::ThreadPool) -- confirm against thread_pool.hh.
    // filter is copied into each FilterWorker; output is held by reference
    // in the OutputWorker and so must outlive this Controller.
    Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
      : batch_size_(batch_size), queue_size_(queue),
        batches_(queue),
        to_read_(queue),
        // output_ is declared (and therefore constructed) before filter_, which
        // needs output_.In() as its done queue.
        output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
        filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
        sequence_(0) {
      for (size_t i = 0; i < queue; ++i) {
        batches_[i].Reserve(batch_size);
        local_read_.push(&batches_[i]);
      }
      NewInput();
    }

    // Append one n-gram to the current batch; submit it once it is full.
    void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
      input_->AddNGram(ngram, line, output);
      if (input_->Size() == batch_size_) {
        FlushInput();
        NewInput();
      }
    }

    // Submit any partially filled batch, then block until every batch has
    // come back through to_read_.  At that point all n-grams added so far have
    // been flushed to the output.  Finally, start a fresh batch.
    void Flush() {
      const bool submitted = FlushInput();
      while (local_read_.size() < queue_size_) {
        MoveRead();
      }
      // An empty batch is never submitted, so its sequence number was never
      // consumed.  Reuse it: OutputWorker only writes contiguous sequence
      // numbers, and a gap would leave every later batch stranded.
      if (!submitted) --sequence_;
      // Still re-bind input_: MoveRead() may have pushed other batches above
      // the one input_ pointed into.
      NewInput();
    }

  private:
    // Hand the current batch to the filter workers unless it is empty.
    // Returns true iff a batch was submitted.  If that empties local_read_,
    // block until a processed batch is returned.
    bool FlushInput() {
      if (input_->Empty()) return false;
      filter_.Produce(local_read_.top());
      local_read_.pop();
      if (local_read_.empty()) MoveRead();
      return true;
    }

    // Start filling the batch on top of local_read_ with the next sequence number.
    void NewInput() {
      input_ = &local_read_.top()->Fill(sequence_++);
    }

    // Block until a fully processed batch comes back, and make it available locally.
    void MoveRead() {
      local_read_.push(to_read_.Consume());
    }

    const size_t batch_size_;
    const size_t queue_size_;

    std::vector<Batch> batches_;

    // Batches returned by the output thread, ready to be refilled.
    util::PCQueue<Batch*> to_read_;
    // Batches currently owned by the calling thread; top() is being filled.
    std::stack<Batch*> local_read_;
    util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
    util::ThreadPool<FilterWorker<Batch, Filter> > filter_;

    // Sequence number the next NewInput() will assign.
    uint64_t sequence_;
    // Input buffer of local_read_.top().
    InputBuffer *input_;
};
} // namespace lm
#endif // LM_FILTER_THREAD_H