forked from kpu/kenlm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrie.cc
131 lines (109 loc) · 5.51 KB
/
trie.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "trie.hh"
#include "bhiksha.hh"
#include "../util/bit_packing.hh"
#include "../util/exception.hh"
#include "../util/sorted_uniform.hh"
#include <cassert>
namespace lm {
namespace ngram {
namespace trie {
namespace {
class KeyAccessor {
public:
KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
: base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
typedef uint64_t Key;
Key operator()(uint64_t index) const {
return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_);
}
private:
const uint8_t *const base_;
const WordIndex key_mask_;
const uint8_t key_bits_, total_bits_;
};
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) {
KeyAccessor accessor(base, key_mask, key_bits, total_bits);
if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false;
return true;
}
} // namespace
uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
// +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
// Note that this waste is O(order), not O(number of ngrams).
return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t);
}
void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
util::BitPackingSanity();
word_bits_ = util::RequiredBits(max_vocab);
word_mask_ = (1ULL << word_bits_) - 1ULL;
if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
total_bits_ = word_bits_ + remaining_bits;
base_ = static_cast<uint8_t*>(base);
insert_index_ = 0;
max_vocab_ = max_vocab;
}
template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
BitPacked(),
quant_bits_(quant_bits),
// If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
bhiksha_(base, entries + 1, max_next, config),
next_source_(&next_source) {
if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());
}
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, word);
at_pointer += word_bits_;
util::BitAddress ret(base_, at_pointer);
at_pointer += quant_bits_;
uint64_t next = next_source_->InsertIndex();
bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
++insert_index_;
return ret;
}
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return util::BitAddress(NULL, 0);
}
pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range);
return util::BitAddress(base_, at_pointer);
}
template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
// Write at insert_index. . .
uint64_t last_next_write = insert_index_ * total_bits_ +
// at the offset where the next pointers are stored.
(total_bits_ - bhiksha_.InlineBits());
bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end);
bhiksha_.FinishedLoading(config);
}
util::BitAddress BitPackedLongest::Insert(WordIndex index) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, index);
at_pointer += word_bits_;
++insert_index_;
return util::BitAddress(base_, at_pointer);
}
util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);
at_pointer = at_pointer * total_bits_ + word_bits_;
return util::BitAddress(base_, at_pointer);
}
template class BitPackedMiddle<DontBhiksha>;
template class BitPackedMiddle<ArrayBhiksha>;
} // namespace trie
} // namespace ngram
} // namespace lm