-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.hh
155 lines (126 loc) · 6.92 KB
/
model.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#ifndef LM_MODEL_H
#define LM_MODEL_H
#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/state.hh"
#include "lm/value.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include "util/murmur_hash.hh"
#include <algorithm>
#include <vector>
#include <cstring>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
namespace detail {

// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
//
// Template parameters:
//   Search      - n-gram storage/lookup implementation; the concrete models at the
//                 bottom of this header use detail::HashedSearch<...> and
//                 trie::TrieSearch<...>.  Must provide kVersion, kDifferentRest
//                 and a nested Node type.
//   VocabularyT - vocabulary implementation paired with Search.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;

  public:
    // This is the model type returned by RecognizeBinary.
    static const ModelType kModelType;

    // Version number, taken directly from the Search implementation.
    static const unsigned int kVersion = Search::kVersion;

    /* Get the size of memory that will be mapped given ngram counts.  This
     * does not include small non-mapped control structures, such as this class
     * itself.
     */
    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    /* Load the model from a file.  It may be an ARPA or binary file.  Binary
     * files must have the format expected by this class or you'll get an
     * exception.  So TrieModel can only load ARPA or binary created by
     * TrieModel.  To classify binary files, call RecognizeBinary in
     * lm/binary_format.hh.
     */
    explicit GenericModel(const char *file, const Config &config = Config());

    /* Score p(new_word | in_state) and incorporate new_word into out_state.
     * Note that in_state and out_state must be different references:
     * &in_state != &out_state.
     */
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    /* Slower call without in_state.  Try to remember state, but sometimes it
     * would cost too much memory or your decoder isn't set up properly.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).  The new_word is not part of the context
     * array unless you intend to repeat words.
     */
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    /* Get the state for a context.  Don't use this if you can avoid it.  Use
     * BeginSentenceState or NullContextState and extend from those.  If
     * you're only going to use this state to call FullScore once, use
     * FullScoreForgotState.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).
     */
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

    /* More efficient version of FullScore where a partial n-gram has already
     * been scored.
     * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
     */
    FullScoreReturn ExtendLeft(
        // Additional context in reverse order, passed as the half-open range
        // [add_rbegin, add_rend).  Both pointers are taken by value, so the
        // caller's pointers are not modified.
        const WordIndex *add_rbegin, const WordIndex *add_rend,
        // Backoff weights to use.
        const float *backoff_in,
        // extend_left returned by a previous query.
        uint64_t extend_pointer,
        // Length of n-gram that the pointer corresponds to.
        unsigned char extend_length,
        // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
        float *backoff_out,
        // Amount of additional content that should be considered by the next call.
        unsigned char &next_use) const;

    /* Return probabilities minus rest costs for an array of pointers.  The
     * first length should be the length of the n-gram to which pointers_begin
     * points.
     */
    float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
      // Search::kDifferentRest is presumably a compile-time constant, so the
      // compiler can fold away the untaken branch.  When the Search stores no
      // separate rest costs, the adjustment is simply 0.  Use a float literal
      // so the conditional is not promoted to double and narrowed back.
      return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0f;
    }

  private:
    // Scoring helper used by the public scoring entry points; per its name it
    // excludes backoff (NOTE(review): confirm exact contract in model.cc).
    FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

    // Score bigrams and above.  Do not include backoff.
    void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;

    // Appears after Size in the cc file.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    // Construction path for ARPA input (NOTE(review): fd presumably refers to
    // the already-opened `file`; confirm in model.cc).
    void InitializeFromARPA(int fd, const char *file, const Config &config);

    // Out-of-line implementation behind UnRest, only reached when
    // Search::kDifferentRest is true.
    float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;

    // Binary-format handling state for the loaded model (NOTE(review): assumed
    // to own the mapped/allocated memory described by Size; confirm).
    BinaryFormat backing_;
    // Word <-> WordIndex vocabulary.
    VocabularyT vocab_;
    // N-gram lookup structure.
    Search search_;
};

} // namespace detail
// Instead of typedef, inherit.  This allows the Model etc. to be forward declared.
// Oh the joys of C and C++.

// Expands to a literal comma.  This lets a template-id that itself contains
// commas be passed as the single `from` argument of LM_NAME_MODEL.
#define LM_COMMA() ,

// Defines class `name` deriving from `from` and forwarding the (file, config)
// constructor.  The constructor is explicit, matching GenericModel's, so a
// `const char *` never silently converts into a loaded model.  The macro body
// deliberately omits the trailing ';' so that each invocation supplies it.
// This avoids an empty declaration at namespace scope (ill-formed in C++03,
// flagged by -pedantic / -Wextra-semi).
#define LM_NAME_MODEL(name, from)\
class name : public from {\
  public:\
    explicit name(const char *file, const Config &config = Config()) : from(file, config) {}\
}

LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
// Default aliases: the probing (HashedSearch + ProbingVocabulary) variant.
// It is the default purely by convention, not for any technical reason.
typedef ProbingModel Model;
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
/* Detect what kind of file file_name is, load it, and hand it back through the
 * virtual base::Model interface.  if_arpa presumably picks the model type used
 * when the file turns out to be ARPA text rather than a binary (confirm in
 * model.cc).  NOTE(review): ownership of the returned pointer is not visible
 * here; presumably the caller must delete it.
 * Avoid the virtual base class when you can: prefer passing the concrete model
 * classes above as template arguments to your own virtual feature function.
 */
base::Model *LoadVirtual(const char *file_name,
                         const Config &config = Config(),
                         ModelType if_arpa = PROBING);
} // namespace ngram
} // namespace lm
#endif // LM_MODEL_H