-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge_vocab.cc
131 lines (109 loc) · 3.53 KB
/
merge_vocab.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include "lm/interpolate/merge_vocab.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/interpolate/universal_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "util/file_piece.hh"
#include <queue>
#include <string>
#include <iostream>
#include <vector>
namespace lm {
namespace interpolate {
namespace {
// Sequential reader over one model's vocabulary file.
//
// The file is a list of words separated by '\0'. Construction consumes the
// leading "<unk>" entry (index 0) and positions the reader on the first real
// word. Each operator++ advances by one word and checks that vocabulary hashes
// never decrease.
class VocabFileReader {
  public:
    // fd: descriptor of the vocabulary file (duplicated, not adopted).
    // model_num: which model this file belongs to; reported back via ModelNum().
    explicit VocabFileReader(const int fd, size_t model_num, uint64_t offset = 0);

    // Move to the next word; sets the EOF state when the file is exhausted.
    VocabFileReader &operator++();

    // True until an advance runs past the last word.
    operator bool() const { return !eof_; }

    // Both accessors yield HashForVocab of the current word.
    uint64_t operator*() const { return hash_value_; }
    uint64_t Value() const { return hash_value_; }

    size_t ModelNum() const { return model_num_; }

    // Position of the current word within this file; "<unk>" is 0.
    WordIndex CurrentIndex() const { return current_index_; }

    // Text of the current word. NOTE(review): presumably a view into
    // file_piece_'s buffer, so it is likely invalidated by the next advance.
    StringPiece Word() const { return word_; }

  private:
    // Declaration order matters: it is the initialization order used by the
    // constructor's member-initializer list.
    uint64_t hash_value_;
    WordIndex current_index_;
    bool eof_;
    size_t model_num_;
    StringPiece word_;
    util::FilePiece file_piece_;
};
// Wraps a duplicate of `fd` and validates that the vocabulary opens with
// "<unk>". It then advances once, so that afterwards the reader sits on the
// first word after "<unk>" (CurrentIndex() == 1), or is already at EOF if
// "<unk>" was the only entry.
// NOTE(review): `offset` is accepted but not used anywhere in this constructor.
VocabFileReader::VocabFileReader(const int fd, const size_t model_num, uint64_t offset) :
  hash_value_(0),
  current_index_(0),
  eof_(false),
  model_num_(model_num),
  file_piece_(util::DupOrThrow(fd)) {
  // Entry 0 is reserved for the unknown word; anything else is a format error.
  word_ = file_piece_.ReadLine('\0');
  UTIL_THROW_IF(word_ != "<unk>",
                FormatLoadException,
                "Vocabulary words are in the wrong place.");
  // Step past <unk> so the reader is primed on its first mergeable word.
  ++(*this);
}
// Reads the next '\0'-terminated word and updates the hash and index.
//
// At end of file, the EOF flag is set and the hash/index of the last word are
// left untouched. Throws FormatLoadException if the new word's HashForVocab is
// smaller than the previous word's.
VocabFileReader &VocabFileReader::operator++() {
  // Snapshot before reading; the try block below only changes word_ or eof_.
  const uint64_t last_hash = hash_value_;
  try {
    word_ = file_piece_.ReadLine('\0');
  } catch (util::EndOfFileException &) {
    eof_ = true;
    return *this;
  }
  hash_value_ = ngram::detail::HashForVocab(word_.data(), word_.size());
  // Merging relies on each file being sorted by hash.
  // NOTE(review): the message says "word index", but the value checked is the hash.
  UTIL_THROW_IF(hash_value_ < last_hash, FormatLoadException,
                ": word index not monotonically increasing."
                << " model_num: " << model_num_
                << " prev hash: " << last_hash
                << " new hash: " << hash_value_);
  current_index_ += 1;
  return *this;
}
// Ordering for the merge heap in MergeVocab.
//
// std::priority_queue keeps on top the element that the comparator ranks
// greatest. Comparing with `>` therefore makes it a min-heap: the reader whose
// current word has the smallest hash is always processed first.
// The call operator is const so the comparator is usable through const
// objects, as the standard Compare requirement expects.
class CompareFiles {
  public:
    bool operator()(const VocabFileReader *x,
                    const VocabFileReader *y) const {
      return x->Value() > y->Value();
    }
};
// Fixed-capacity container holding one VocabFileReader per model.
//
// MergeVocab stores pointers to elements (&back()) on its heap.
// NOTE(review): this assumes util::FixedArray never relocates its elements.
class Readers : public util::FixedArray<VocabFileReader> {
  public:
    // `number` is the capacity (one slot per model). The constructor is
    // explicit so that a bare size cannot silently convert into a Readers.
    explicit Readers(std::size_t number) : util::FixedArray<VocabFileReader>(number) {}

    // Builds the reader for model `i` from descriptor `fd` in the next free
    // slot. Constructed() is reached only if the reader's constructor did not
    // throw, so a failed construction is not recorded as an element.
    void push_back(int fd, std::size_t i) {
      new (end()) VocabFileReader(fd, i);
      Constructed();
    }
};
} // namespace
// Merges several models' vocabularies into one universal vocabulary.
//
// files     - one descriptor per model. Each names a '\0'-separated word list
//             that starts with "<unk>" and is sorted by HashForVocab; both
//             properties are enforced by VocabFileReader.
// vocab     - receives, for every model, local word index -> universal index.
//             Every model's <unk> (local 0) maps to universal 0.
// enumerate - receives each distinct word once with its universal index, in
//             ascending index order, beginning with (0, "<unk>").
// Returns the number of universal words, including <unk>.
//
// This is a k-way merge over a min-heap keyed on word hash. Entries from
// different models with equal hashes are treated as the same word.
WordIndex MergeVocab(util::FixedArray<int> &files, UniversalVocab &vocab, EnumerateVocab &enumerate) {
  typedef std::priority_queue<VocabFileReader*, std::vector<VocabFileReader*>, CompareFiles> HeapType;
  HeapType heap;
  Readers readers(files.size());
  for (size_t i = 0; i < files.size(); ++i) {
    readers.push_back(files[i], i);
    // A vocabulary holding only <unk> is already exhausted here. It has no
    // words left to merge, so keeping it out of the heap avoids advancing an
    // exhausted reader again.
    if (readers.back()) {
      heap.push(&readers.back());
    }
    // Every model's local <unk> (index 0) is the universal <unk> (index 0).
    vocab.InsertUniversalIdx(i, 0, 0);
  }

  // Universal index of the most recently emitted word; <unk> occupies 0.
  WordIndex global_index = 0;
  enumerate.Add(0, "<unk>");

  // Hash of the last word merged. An explicit flag is used instead of a 0
  // sentinel, so that a real word hashing to 0 is not mistaken for <unk>.
  uint64_t prev_hash_value = 0;
  bool have_prev = false;

  while (!heap.empty()) {
    VocabFileReader *top_vocab_file = heap.top();
    heap.pop();
    const uint64_t hash = top_vocab_file->Value();
    // A hash not seen yet (input is sorted, so only the previous one matters)
    // is a new distinct word: give it the next universal index.
    if (!have_prev || hash != prev_hash_value) {
      enumerate.Add(++global_index, top_vocab_file->Word());
      prev_hash_value = hash;
      have_prev = true;
    }
    vocab.InsertUniversalIdx(top_vocab_file->ModelNum(),
                             top_vocab_file->CurrentIndex(),
                             global_index);
    // Re-queue the reader only while it still has words.
    if (++(*top_vocab_file)) {
      heap.push(top_vocab_file);
    }
  }
  // Indices 0..global_index are in use.
  return global_index + 1;
}
} // namespace interpolate
} // namespace lm