-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.hh
125 lines (99 loc) · 3.36 KB
/
state.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#ifndef LM_STATE_H
#define LM_STATE_H
#include "lm/max_order.hh"
#include "lm/word_index.hh"
#include "util/murmur_hash.hh"
#include <cstring>
namespace lm {
namespace ngram {
// POD state holding up to KENLM_MAX_ORDER - 1 words and one backoff value per
// word; only the first `length` entries are meaningful.
//
// operator==, Compare and operator< (and hash_value) look ONLY at
// words[0, length); backoff is never compared.  ZeroRemaining() clears the
// unused tail words[length..] and backoff[length..] so that a raw memcmp over
// the struct is deterministic in those slots.  Note that a whole-struct memcmp
// still also compares backoff[0, length) and any padding bytes, so it agrees
// with operator== only when those agree as well (presumably backoff is
// determined by the words -- NOTE(review): confirm, and zero padding if needed).
class State {
  public:
    // Equal when the lengths match and the live words are byte-identical.
    bool operator==(const State &other) const {
      if (length != other.length) return false;
      return !std::memcmp(words, other.words, length * sizeof(WordIndex));
    }

    // Three way comparison function: negative, zero, or positive.
    // Shorter states order first; equal lengths are ordered bytewise on the
    // live words (a consistent total order, not numeric word order).
    int Compare(const State &other) const {
      if (length != other.length) return length < other.length ? -1 : 1;
      return std::memcmp(words, other.words, length * sizeof(WordIndex));
    }

    // Strict weak ordering consistent with Compare.
    bool operator<(const State &other) const {
      if (length != other.length) return length < other.length;
      return std::memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
    }

    // Zero the unused tail of words and backoff; call this before using raw memcmp.
    void ZeroRemaining() {
      for (unsigned char i = length; i < KENLM_MAX_ORDER - 1; ++i) {
        words[i] = 0;
        backoff[i] = 0.0;
      }
    }

    // Number of live entries in words/backoff.
    unsigned char Length() const { return length; }

    // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
    // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
    WordIndex words[KENLM_MAX_ORDER - 1];
    float backoff[KENLM_MAX_ORDER - 1];
    unsigned char length;
};
typedef State Right;
inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}
// Left-side chart state.  Equality, ordering and hashing depend only on
// `length`, `full`, and pointers[length - 1]; earlier pointers are ignored.
// An empty state (length == 0) is expected to be full (see Compare).
struct Left {
  // Equal when lengths match and, for non-empty states, the last pointer and
  // the full flag match.  Empty states compare equal regardless of `full`.
  bool operator==(const Left &other) const {
    return
      length == other.length &&
      (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full));
  }

  // Three-way comparison: by length, then last pointer, then full flag.
  int Compare(const Left &other) const {
    if (length < other.length) return -1;
    if (length > other.length) return 1;
    if (length == 0) return 0; // Must be full.
    if (pointers[length - 1] > other.pointers[length - 1]) return 1;
    if (pointers[length - 1] < other.pointers[length - 1]) return -1;
    return static_cast<int>(full) - static_cast<int>(other.full);
  }

  // Test the sign rather than the exact value -1 so this stays a correct
  // strict weak ordering for any negative result Compare may produce, and so
  // it matches State/ChartState.
  bool operator<(const Left &other) const {
    return Compare(other) < 0;
  }

  // Zero the unused tail pointers[length..] so raw memcmp over them is deterministic.
  void ZeroRemaining() {
    for (uint64_t * i = pointers + length; i < pointers + KENLM_MAX_ORDER - 1; ++i)
      *i = 0;
  }

  uint64_t pointers[KENLM_MAX_ORDER - 1];
  unsigned char length;
  bool full;
};
// Hash of a Left state over exactly the fields operator== inspects: length,
// the full flag, and the last pointer (used as the seed).
// operator== treats every empty state as equal regardless of `full`, so for
// length == 0 a constant 1 is hashed instead of `full`; equal keys therefore
// always hash equally.  For empty states that are full (the invariant noted
// by "Must be full" in Left::Compare) the hash value is unchanged.
inline uint64_t hash_value(const Left &left) {
  unsigned char add[2];
  add[0] = left.length;
  add[1] = left.length ? static_cast<unsigned char>(left.full) : 1;
  return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0);
}
// Chart state: a Left half and a Right (State) half.  Comparison and equality
// are built from the corresponding operations on each half.
struct ChartState {
  // Equal only when both halves are equal; the right half is checked first.
  bool operator==(const ChartState &other) const {
    if (!(right == other.right)) return false;
    return left == other.left;
  }

  // Three-way comparison: order by the left half, breaking ties on the right.
  int Compare(const ChartState &other) const {
    const int by_left = left.Compare(other.left);
    return by_left ? by_left : right.Compare(other.right);
  }

  // Strict ordering derived from Compare.
  bool operator<(const ChartState &other) const {
    return 0 > Compare(other);
  }

  // Clear the unused tails of both halves (see the halves' ZeroRemaining).
  void ZeroRemaining() {
    left.ZeroRemaining();
    right.ZeroRemaining();
  }

  Left left;
  State right;
};
// Hash of a ChartState: the left half's hash seeds the hash of the right half.
inline uint64_t hash_value(const ChartState &state) {
  const uint64_t left_seed = hash_value(state.left);
  return hash_value(state.right, left_seed);
}
} // namespace ngram
} // namespace lm
#endif // LM_STATE_H