-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocab.cc
329 lines (289 loc) · 11.5 KB
/
vocab.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
#include "lm/vocab.hh"
#include "lm/binary_format.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/joint_sort.hh"
#include "util/murmur_hash.hh"
#include "util/probing_hash_table.hh"
#include <cstring>
#include <string>
namespace lm {
namespace ngram {
namespace detail {
// Hashes the len bytes at str with util::MurmurHash64A.
// Murmur proved faster than Boost's hash in speed trials (total load time:
// Murmur 67090000, Boost 72210000).  The 64A variant was chosen over the native
// one so the binary format is portable across 64 and 32 bit.
uint64_t HashForVocab(const char *str, std::size_t len) {
  const uint64_t hashed = util::MurmurHash64A(str, len, 0);
  return hashed;
}
} // namespace detail
namespace {
// Hashes of the two accepted spellings of the unknown word.  Both Insert
// implementations in this file compare against them and return index 0 on a match.
// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
// Reads the null-delimited vocabulary strings stored in fd starting at offset
// and reports each one, with its index, to enumerate.
//
//   fd              file descriptor of the binary file; it is repositioned.
//   enumerate       receiver of (index, word) pairs.  If NULL, only the leading
//                   "<unk>" sanity check is performed.
//   expected_count  number of words that must be present, including <unk>.
//   offset          byte offset in fd where the vocabulary strings begin.
//
// Throws FormatLoadException if the strings do not begin with "<unk>\0" or if
// the number of words read differs from expected_count.
//
// TODO: replace with FilePiece.
void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
  util::SeekOrThrow(fd, offset);
  // Check that we're at the right place by reading <unk> which is always first.
  char check_unk[6];
  util::ReadOrThrow(fd, check_unk, 6);
  UTIL_THROW_IF(
      memcmp(check_unk, "<unk>", 6),
      FormatLoadException,
      "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
  if (!enumerate) return;
  enumerate->Add(0, "<unk>");

  // Read all the words after unk.
  const std::size_t kInitialRead = 16384;
  std::string buf;
  buf.reserve(kInitialRead + 100);
  WordIndex index = 1; // Read <unk> already.
  while (true) {
    // Restore the full read window before every read.  The previous iteration
    // leaves buf at whatever size it ended with; reading kInitialRead bytes
    // into a smaller string writes past size(), and the resize(got) below would
    // then zero-fill the grown tail, wiping out the bytes that were just read.
    buf.resize(kInitialRead);
    std::size_t got = util::ReadOrEOF(fd, &buf[0], kInitialRead);
    if (got == 0) break;
    buf.resize(got);
    // The chunk may end mid-word: pull single bytes until the last word is
    // terminated.
    while (buf[buf.size() - 1]) {
      char next_char;
      util::ReadOrThrow(fd, &next_char, 1);
      buf.push_back(next_char);
    }
    // Ok now we have null terminated strings.
    for (const char *i = buf.data(); i != buf.data() + buf.size();) {
      std::size_t length = strlen(i);
      enumerate->Add(index++, StringPiece(i, length));
      i += length + 1 /* null byte */;
    }
  }

  UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file.");
}
// Constructor ordering madness: positions fd at start and hands the same
// descriptor back, so the seek can be performed inside a member initializer.
int SeekAndReturn(int fd, uint64_t start) {
  const int positioned = fd;
  util::SeekOrThrow(positioned, start);
  return positioned;
}
} // namespace
// Stores inner and constructs stream_ from fd.  SeekAndReturn moves fd to start
// first, so the seek has already happened by the time stream_ is initialized.
ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start)
: inner_(inner), stream_(SeekAndReturn(fd, start)) {}
// Stores the optional inner enumerator; Add forwards to it when it is non-null.
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
// Forwards the word to the wrapped enumerator, if there is one, and appends the
// word plus a terminating null byte to the in-memory buffer.
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
  if (inner_) {
    inner_->Add(index, str);
  }
  buffer_.append(str.data(), str.size());
  // Null delimiter between words.
  buffer_.append(1, '\0');
}
// Writes every buffered word to fd beginning at byte offset start, then
// releases the buffer.
void WriteWordsWrapper::Write(int fd, uint64_t start) {
  util::SeekOrThrow(fd, start);
  util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
  // Swap with an empty temporary so the string's memory is freed rather than
  // merely having its size reset.
  std::string().swap(buffer_);
}
// Starts with no memory attached (begin_ and end_ are NULL) and no enumerator.
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
// Bytes needed to hold entries hashes: one uint64_t per entry plus one leading
// uint64_t that stores the number of entries.
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
  return sizeof(uint64_t) * (entries + 1);
}
// Points the vocabulary at the memory beginning at start and marks it empty.
// allocated, entries and config are only used to assert the region is big enough.
void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
  assert(allocated >= Size(entries, config));
  // The first uint64_t is left free for the number of entries; hashes follow.
  uint64_t *const first_hash = reinterpret_cast<uint64_t*>(start) + 1;
  begin_ = first_hash;
  end_ = first_hash;
  saw_unk_ = false;
}
// Re-bases begin_ and end_ onto new_start, keeping the same number of entries.
void SortedVocabulary::Relocate(void *new_start) {
  const std::size_t count = end_ - begin_;
  // Skip the leading uint64_t that holds the entry count.
  uint64_t *const moved = reinterpret_cast<uint64_t*>(new_start) + 1;
  begin_ = moved;
  end_ = moved + count;
}
// Installs to as the enumerator (NULL disables enumeration).  When one is
// given, <unk> is reported right away as index 0 and room is made to remember
// up to max_entries word strings until loading finishes.
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
  enumerate_ = to;
  if (!enumerate_) return;
  enumerate_->Add(0, "<unk>");
  strings_to_enumerate_.resize(max_entries);
}
// Appends the hash of str to the (not yet sorted) array and returns a
// provisional index for it.  <unk> and <UNK> are never stored: they set
// saw_unk_ and yield index 0.
WordIndex SortedVocabulary::Insert(const StringPiece &str) {
  const uint64_t hashed = detail::HashForVocab(str);
  const bool is_unknown = (hashed == kUnknownHash) || (hashed == kUnknownCapHash);
  if (is_unknown) {
    saw_unk_ = true;
    return 0;
  }

  *end_ = hashed;
  if (enumerate_) {
    // Keep a private copy of the word so it can be reported once loading ends.
    char *const copied = static_cast<char*>(string_backing_.Allocate(str.size()));
    memcpy(copied, str.data(), str.size());
    strings_to_enumerate_[end_ - begin_] = StringPiece(copied, str.size());
  }
  ++end_;

  // This is 1 + the offset where it was inserted to make room for unk.
  return end_ - begin_;
}
// Finishes loading by handing the ProbBackoff array to GenericFinished.
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) {
GenericFinished(reorder);
}
namespace {
// One vocabulary word as seen by ComputeRenumbering.
// Packed to 4 bytes — presumably to avoid tail padding after the WordIndex
// member; confirm before changing.
#pragma pack(push)
#pragma pack(4)
struct RenumberEntry {
// Hash of the word; the sort key.
uint64_t hash;
// Start of the null-terminated word inside the mapped vocabulary file.
const char *str;
// Index the word had before renumbering.
WordIndex old;
// Orders entries by hash only.
bool operator<(const RenumberEntry &other) const {
return hash < other.hash;
}
};
#pragma pack(pop)
} // namespace
// Reorders a vocabulary file so that its words appear in increasing hash order.
//
//   types       number of words in the vocabulary, including <unk>.
//   from_words  file descriptor holding "<unk>\0" followed by the remaining
//               words, each terminated by a null byte, in their current order.
//   to_words    file descriptor that receives the same words in the new order
//               (<unk> first, then by hash).
//   mapping     output; resized to types and filled so that
//               mapping[old index] == new index.  mapping[0] is always 0.
//
// Throws FormatLoadException if from_words is malformed and util::Exception
// (via UTIL_THROW_IF2) if the number of words does not match types.
void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping) {
  mapping.clear();
  // types counts <unk>, so zero is never valid; it would also wrap types - 1
  // below and make mapping[0] an out-of-bounds write.
  UTIL_THROW_IF2(types == 0, "Wrong number of vocab ids. Expected at least 1 for <unk> but the number of types is 0");
  uint64_t file_size = util::SizeOrThrow(from_words);
  // The memcmp below reads 6 bytes, so make sure they exist.
  UTIL_THROW_IF(file_size < 6, FormatLoadException, "Vocab file is too short to begin with <unk> followed by null");
  util::scoped_memory strings;
  util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings);
  const char *const start = static_cast<const char*>(strings.get());
  UTIL_THROW_IF(memcmp(start, "<unk>", 6), FormatLoadException, "Vocab file does not begin with <unk> followed by null");
  // strlen in the loop relies on every word, including the last one, being
  // null terminated; otherwise it would run past the end of the mapping.
  UTIL_THROW_IF(start[file_size - 1] != '\0', FormatLoadException, "Vocab file does not end with a null byte");

  std::vector<RenumberEntry> entries;
  entries.reserve(types - 1);
  RenumberEntry entry;
  entry.old = 1;
  for (entry.str = start + 6 /* skip <unk>\0 */; entry.str < start + file_size; ++entry.old) {
    StringPiece str(entry.str, strlen(entry.str));
    entry.hash = detail::HashForVocab(str);
    entries.push_back(entry);
    entry.str += str.size() + 1;
  }
  UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids. Got " << (entries.size() + 1) << " expected " << types);
  std::sort(entries.begin(), entries.end());

  // Write out new vocab file.
  {
    util::FileStream out(to_words);
    out << "<unk>" << '\0';
    for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
      out << i->str << '\0';
    }
  }
  // From here on only the old indices are needed, so the mapping can go.
  strings.reset();

  mapping.resize(types);
  mapping[0] = 0; // <unk>
  for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
    // New index is the sorted position shifted by one to leave 0 for <unk>.
    mapping[i->old] = i + 1 - entries.begin();
  }
}
// Finalizes a vocabulary whose hashes are already in place: marks <unk> as
// seen, resolves <s> and </s>, and records the entry count in the slot just
// before begin_.
void SortedVocabulary::Populated() {
  saw_unk_ = true;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  const uint64_t count = end_ - begin_;
  // bound_ additionally counts <unk>.
  bound_ = count + 1;
  *(reinterpret_cast<uint64_t*>(begin_) - 1) = count;
}
// Restores state after a binary file has been loaded: recovers end_ from the
// entry count stored just before begin_, resolves <s> and </s>, and, when
// have_words is set, reads the vocabulary strings from fd at offset into to.
void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
  const uint64_t stored_count = *(reinterpret_cast<const uint64_t*>(begin_) - 1);
  end_ = begin_ + stored_count;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  // Includes <unk>.
  bound_ = end_ - begin_ + 1;
  if (!have_words) return;
  ReadWords(fd, to, bound_, offset);
}
// Sorts the inserted hashes and finalizes the vocabulary.
// reorder is permuted together with the hashes, starting at reorder + 1
// (presumably because slot 0 belongs to <unk> — confirm against callers).
// When an enumerator is configured, the saved word strings are permuted the
// same way and then reported with their final indices.
template <class T> void SortedVocabulary::GenericFinished(T *reorder) {
if (enumerate_) {
if (!strings_to_enumerate_.empty()) {
// Sort hashes, carrying both the reorder values and the strings along.
util::PairedIterator<T*, StringPiece*> values(reorder + 1, &*strings_to_enumerate_.begin());
util::JointSort(begin_, end_, values);
}
for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
// <unk> strikes again: +1 here.
enumerate_->Add(i + 1, strings_to_enumerate_[i]);
}
// The strings have been handed off, so release their storage.
strings_to_enumerate_.clear();
string_backing_.FreeAll();
} else {
util::JointSort(begin_, end_, reorder + 1);
}
SetSpecial(Index("<s>"), Index("</s>"), 0);
// Save size. Excludes UNK.
*(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
// Includes UNK.
bound_ = end_ - begin_ + 1;
}
namespace {
// Version written into the probing vocabulary header by InternalFinishedLoading;
// LoadedBinary refuses binary files whose header carries any other value.
const unsigned int kProbingVocabularyVersion = 0;
} // namespace
namespace detail {
// Header placed at the start of the probing vocabulary's memory; the lookup
// table begins ALIGN8(sizeof(ProbingVocabularyHeader)) bytes after it.
struct ProbingVocabularyHeader {
// Format version, checked against kProbingVocabularyVersion when a binary file is loaded.
unsigned int version;
// Lowest unused vocab id. This is also the number of words, including <unk>.
WordIndex bound;
};
} // namespace detail
// Starts with no enumerator attached.
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
// Bytes needed for a probing vocabulary with the given number of entries:
// the header rounded up by ALIGN8, plus whatever Lookup::Size reports for
// entries and probing_multiplier.
uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) {
return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier);
}
// Convenience overload: takes the probing multiplier from config.
uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
  const float multiplier = config.probing_multiplier;
  return Size(entries, multiplier);
}
// Lays the vocabulary out in the memory at start: the header comes first and
// the lookup table begins at the next ALIGN8 boundary after it.  The
// vocabulary starts out with only <unk> (bound_ == 1), not yet seen.
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) {
  header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
  uint8_t *const table_start = static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader));
  // NOTE(review): allocated is forwarded unchanged even though the table starts
  // after the header — confirm that this is what Lookup expects.
  lookup_ = Lookup(table_start, allocated);
  saw_unk_ = false;
  bound_ = 1;
}
// Re-bases the header pointer and the lookup table onto new_start, using the
// same header-then-table layout as SetupMemory.
void ProbingVocabulary::Relocate(void *new_start) {
  uint8_t *const base = static_cast<uint8_t*>(new_start);
  header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
  lookup_.Relocate(base + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
}
// Installs to as the enumerator (NULL disables enumeration) and, when one is
// given, immediately reports <unk> as index 0.  max_entries is unused here.
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
  enumerate_ = to;
  if (!enumerate_) return;
  enumerate_->Add(0, "<unk>");
}
// Adds str to the vocabulary and returns the index assigned to it.
// <unk> and <UNK> are kept out of the table: they set saw_unk_ and yield 0.
WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
  const uint64_t hashed = detail::HashForVocab(str);
  // Prevent unknown from going into the table.
  const bool is_unknown = (hashed == kUnknownHash) || (hashed == kUnknownCapHash);
  if (is_unknown) {
    saw_unk_ = true;
    return 0;
  }
  if (enumerate_) enumerate_->Add(bound_, str);
  lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_));
  // Hand out the current bound and advance it for the next word.
  const WordIndex assigned = bound_;
  ++bound_;
  return assigned;
}
// Closes the lookup table for insertion, stamps the header with the format
// version and final bound, and resolves <s> and </s>.
void ProbingVocabulary::InternalFinishedLoading() {
  lookup_.FinishedInserting();
  header_->version = kProbingVocabularyVersion;
  header_->bound = bound_;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
}
void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code.");
bound_ = header_->bound;
SetSpecial(Index("<s>"), Index("</s>"), 0);
if (have_words) ReadWords(fd, to, bound_, offset);
}
// Applies the configured policy for an ARPA file without <unk>:
//   SILENT    do nothing;
//   COMPLAIN  write a warning to config.messages, if it is set;
//   THROW_UP  throw SpecialWordMissingException.
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
  if (config.unknown_missing == SILENT) return;
  if (config.unknown_missing == THROW_UP) {
    UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
  }
  if (config.unknown_missing == COMPLAIN && config.messages) {
    *config.messages << "The ARPA file is missing <unk>. Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
  }
}
// Applies the configured policy for an ARPA file without the sentence marker
// str (the text is only used in the messages):
//   SILENT    do nothing;
//   COMPLAIN  write a warning to config.messages, if it is set;
//   THROW_UP  throw SpecialWordMissingException.
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
  switch (config.sentence_marker_missing) {
    case SILENT:
      return;
    case COMPLAIN:
      // End the line, as MissingUnknown does, so the warning does not run into
      // whatever is printed next.
      if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>." << std::endl;
      break;
    case THROW_UP:
      UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models. Run build_binary -s to disable this check.");
  }
}
} // namespace ngram
} // namespace lm