forked from kpu/kenlm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
adjust_counts.hh
72 lines (58 loc) · 2.21 KB
/
adjust_counts.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#ifndef LM_BUILDER_ADJUST_COUNTS_H
#define LM_BUILDER_ADJUST_COUNTS_H
#include "discount.hh"
#include "../lm_exception.hh"
#include "../../util/exception.hh"
#include <vector>
#include <stdint.h>
namespace util { namespace stream { class ChainPositions; } }
namespace lm {
namespace builder {
// Exception type for discount-related failures; derives from util::Exception.
// NOTE(review): presumably raised when discounts are unusable and the
// configured action is THROW_UP -- confirm against the implementation file.
class BadDiscountException : public util::Exception {
public:
// Both special members are only declared here (no inline body) and carry an
// empty dynamic exception specification, i.e. they promise not to throw.
BadDiscountException() throw();
~BadDiscountException() throw();
};
// Settings that control which discounts are used and what happens when the
// discounts for an order cannot be used.
struct DiscountConfig {
// Overrides discounts for orders [1,overwrite.size()].
std::vector<Discount> overwrite;
// If discounting fails for an order, copy the discounts from here.
Discount fallback;
// What to do when discounts are out of range or would trigger division by
// zero. If it does something other than THROW_UP, use fallback.
WarningAction bad_action;
};
/* Compute adjusted counts.
 * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
 * Output: [1,N]-grams with adjusted counts.
 * [1,N)-grams are in suffix order
 * N-grams are in undefined order (they're going to be sorted anyway).
 *
 * Lifetime note: every constructor argument except discount_config is stored
 * as a reference, so the caller must keep those objects alive for as long as
 * this object is used. discount_config is copied into the object.
 */
class AdjustCounts {
public:
// prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned.
// counts: output
// counts_pruned: output
// prune_words: input (const reference). Meaning is not documented in this
//   header; presumably flags words to prune -- TODO confirm in the implementation.
// discount_config: input. Copied into this object (see DiscountConfig).
// discounts: mostly output. If the input already has entries, they will be kept.
AdjustCounts(
const std::vector<uint64_t> &prune_thresholds,
std::vector<uint64_t> &counts,
std::vector<uint64_t> &counts_pruned,
const std::vector<bool> &prune_words,
const DiscountConfig &discount_config,
std::vector<Discount> &discounts)
: prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned),
prune_words_(prune_words), discount_config_(discount_config), discounts_(discounts)
{}
// Performs the computation described in the class comment. Declared here
// only; the definition lives outside this header.
void Run(const util::stream::ChainPositions &positions);
private:
// Reference members alias the caller's objects; nothing here owns them.
const std::vector<uint64_t> &prune_thresholds_;
std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
const std::vector<bool> &prune_words_;
// Held by value: a private copy taken at construction time.
DiscountConfig discount_config_;
std::vector<Discount> &discounts_;
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_ADJUST_COUNTS_H