
Commit

Renaming, and avoiding decoding prob when we don't have to
git-svn-id: file:///dev/shm/somefilter.svn@455 e102df66-1e2e-11dd-9b44-c24451a4db5e
kpu committed Nov 1, 2010
1 parent 7b9569d commit 7e003da
Showing 15 changed files with 206 additions and 29 deletions.
2 changes: 1 addition & 1 deletion lm/binary_format.hh
@@ -1,7 +1,7 @@
 #ifndef LM_BINARY_FORMAT__
 #define LM_BINARY_FORMAT__
 
-#include "lm/ngram_config.hh"
+#include "lm/config.hh"
 #include "lm/read_arpa.hh"
 
 #include "util/file_piece.hh"
2 changes: 1 addition & 1 deletion lm/ngram_build_binary.cc → lm/build_binary.cc
@@ -1,4 +1,4 @@
#include "lm/ngram.hh"
#include "lm/model.hh"

#include <iostream>

2 changes: 1 addition & 1 deletion lm/ngram_config.cc → lm/config.cc
@@ -1,4 +1,4 @@
#include "lm/ngram_config.hh"
#include "lm/config.hh"

#include <iostream>

6 changes: 3 additions & 3 deletions lm/ngram_config.hh → lm/config.hh
@@ -1,5 +1,5 @@
-#ifndef LM_NGRAM_CONFIG__
-#define LM_NGRAM_CONFIG__
+#ifndef LM_CONFIG__
+#define LM_CONFIG__
 
 #include <iosfwd>
 
@@ -75,4 +75,4 @@ struct Config {

 } /* namespace ngram */ } /* namespace lm */
 
-#endif // LM_NGRAM_CONFIG__
+#endif // LM_CONFIG__
12 changes: 6 additions & 6 deletions lm/ngram.cc → lm/model.cc
@@ -1,8 +1,8 @@
#include "lm/ngram.hh"
#include "lm/model.hh"

#include "lm/lm_exception.hh"
#include "lm/ngram_hashed.hh"
#include "lm/ngram_trie.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/read_arpa.hh"
#include "util/murmur_hash.hh"

@@ -120,7 +120,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
   float *backoff_out = out_state.backoff_ + 1;
   const WordIndex *i = context_rbegin + 1;
   for (; i < context_rend; ++i, ++backoff_out) {
-    if (!search_.LookupMiddle(search_.middle[i - context_rbegin - 1], *i, ignored_prob, *backoff_out, node)) {
+    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, *backoff_out, node)) {
       out_state.valid_length_ = i - context_rbegin;
       std::copy(context_rbegin, i, out_state.history_);
       return;
@@ -143,10 +143,10 @@ template <class Search, class VocabularyT> float GenericModel<Search, Vocabulary
   if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
     return 0.0;
   }
-  float ignored_prob, backoff;
+  float backoff;
   // i is the order of the backoff we're looking for.
   for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i) {
-    if (!search_.LookupMiddle(search_.middle[i - context_rbegin - 1], *i, ignored_prob, backoff, node)) break;
+    if (!search_.LookupMiddleNoProb(search_.middle[i - context_rbegin - 1], *i, backoff, node)) break;
     ret += backoff;
   }
   return ret;
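
This hunk is the case the commit title is about: when scoring backs off to a shorter match, only the backoff weights of the longer contexts are charged, so the probability field is never needed. A self-contained sketch of the same accumulation pattern, with a hypothetical stub standing in for search_.LookupMiddleNoProb:

#include <cstdio>

// Hypothetical stand-in for LookupMiddleNoProb: reports the backoff weight
// of the context at the given order, or false if that n-gram is absent.
bool LookupBackoffOnly(int order, float &backoff) {
  static const float kBackoffs[] = {-0.30f, -0.15f};
  if (order >= 2) return false; // longer context not in the model
  backoff = kBackoffs[order];
  return true;
}

int main() {
  float ret = 0.0f, backoff;
  // Mirror the loop above: charge backoff penalties order by order until
  // a context is missing, never touching the probability field.
  for (int order = 0; LookupBackoffOnly(order, backoff); ++order) {
    ret += backoff;
  }
  std::printf("total backoff penalty: %.2f\n", ret); // prints -0.45
  return 0;
}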
12 changes: 6 additions & 6 deletions lm/ngram.hh → lm/model.hh
@@ -1,11 +1,11 @@
-#ifndef LM_NGRAM__
-#define LM_NGRAM__
+#ifndef LM_MODEL__
+#define LM_MODEL__
 
 #include "lm/binary_format.hh"
+#include "lm/config.hh"
 #include "lm/facade.hh"
-#include "lm/ngram_config.hh"
-#include "lm/ngram_hashed.hh"
-#include "lm/ngram_trie.hh"
+#include "lm/search_hashed.hh"
+#include "lm/search_trie.hh"
 #include "lm/vocab.hh"
 #include "lm/weights.hh"
 
@@ -123,4 +123,4 @@ typedef detail::GenericModel<trie::TrieSearch, SortedVocabulary> TrieModel;
 } // namespace ngram
 } // namespace lm
 
-#endif // LM_NGRAM__
+#endif // LM_MODEL__
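
After the rename, lm/model.hh is the public entry point, and the TrieModel typedef in the hunk above is one of its concrete instantiations. A hypothetical usage sketch: the ARPA-file constructor and the Score/BeginSentenceState/GetVocabulary calls follow kenlm's later documented interface and may not match this exact revision.

// Hypothetical usage sketch; API names assumed from kenlm's later docs.
#include "lm/model.hh"

#include <iostream>

int main() {
  lm::ngram::TrieModel model("test.arpa");
  // Start from the begin-of-sentence state and score one word.
  lm::ngram::State in_state(model.BeginSentenceState()), out_state;
  float score = model.Score(in_state, model.GetVocabulary().Index("the"), out_state);
  std::cout << "log10 p(the | <s>) = " << score << std::endl;
  return 0;
}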
4 changes: 2 additions & 2 deletions lm/ngram_test.cc → lm/model_test.cc
@@ -1,8 +1,8 @@
#include "lm/ngram.hh"
#include "lm/model.hh"

#include <stdlib.h>

#define BOOST_TEST_MODULE NGramTest
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>

namespace lm {
2 changes: 1 addition & 1 deletion lm/ngram_query.cc
@@ -1,4 +1,4 @@
#include "lm/ngram.hh"
#include "lm/model.hh"

#include <cstdlib>
#include <fstream>
2 changes: 1 addition & 1 deletion lm/ngram_hashed.cc → lm/search_hashed.cc
@@ -1,4 +1,4 @@
#include "lm/ngram_hashed.hh"
#include "lm/search_hashed.hh"

#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
156 changes: 156 additions & 0 deletions lm/search_hashed.hh
@@ -0,0 +1,156 @@
#ifndef LM_SEARCH_HASHED__
#define LM_SEARCH_HASHED__

#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/read_arpa.hh"
#include "lm/weights.hh"

#include "util/key_value_packing.hh"
#include "util/probing_hash_table.hh"
#include "util/sorted_uniform.hh"

#include <algorithm>
#include <vector>

namespace util { class FilePiece; }

namespace lm {
namespace ngram {
namespace detail {

inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
  uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(next) * 17894857484156487943ULL);
  return ret;
}

struct HashedSearch {
  typedef uint64_t Node;

  class Unigram {
    public:
      Unigram() {}

      Unigram(void *start, std::size_t /*allocated*/) : unigram_(static_cast<ProbBackoff*>(start)) {}

      static std::size_t Size(uint64_t count) {
        return (count + 1) * sizeof(ProbBackoff); // +1 for hallucinate <unk>
      }

      const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index]; }

      ProbBackoff &Unknown() { return unigram_[0]; }

      void LoadedBinary() {}

      // For building.
      ProbBackoff *Raw() { return unigram_; }

    private:
      ProbBackoff *unigram_;
  };

  Unigram unigram;

  bool LookupUnigram(WordIndex word, float &prob, float &backoff, Node &next) const {
    const ProbBackoff &entry = unigram.Lookup(word);
    prob = entry.prob;
    backoff = entry.backoff;
    next = static_cast<Node>(word);
    return true;
  }
};

template <class MiddleT, class LongestT> struct TemplateHashedSearch : public HashedSearch {
  typedef MiddleT Middle;
  std::vector<Middle> middle;

  typedef LongestT Longest;
  Longest longest;

  static std::size_t Size(const std::vector<uint64_t> &counts, const Config &config) {
    std::size_t ret = Unigram::Size(counts[0]);
    for (unsigned char n = 1; n < counts.size() - 1; ++n) {
      ret += Middle::Size(counts[n], config.probing_multiplier);
    }
    return ret + Longest::Size(counts.back(), config.probing_multiplier);
  }

  uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
    std::size_t allocated = Unigram::Size(counts[0]);
    unigram = Unigram(start, allocated);
    start += allocated;
    for (unsigned int n = 2; n < counts.size(); ++n) {
      allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
      middle.push_back(Middle(start, allocated));
      start += allocated;
    }
    allocated = Longest::Size(counts.back(), config.probing_multiplier);
    longest = Longest(start, allocated);
    start += allocated;
    return start;
  }

  template <class Voc> void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab);

  bool LookupMiddle(const Middle &middle, WordIndex word, float &prob, float &backoff, Node &node) const {
    node = CombineWordHash(node, word);
    typename Middle::ConstIterator found;
    if (!middle.Find(node, found)) return false;
    prob = found->GetValue().prob;
    backoff = found->GetValue().backoff;
    return true;
  }

  bool LookupMiddleNoProb(const Middle &middle, WordIndex word, float &backoff, Node &node) const {
    node = CombineWordHash(node, word);
    typename Middle::ConstIterator found;
    if (!middle.Find(node, found)) return false;
    backoff = found->GetValue().backoff;
    return true;
  }

  bool LookupLongest(WordIndex word, float &prob, Node &node) const {
    node = CombineWordHash(node, word);
    typename Longest::ConstIterator found;
    if (!longest.Find(node, found)) return false;
    prob = found->GetValue().prob;
    return true;
  }

  // Generate a node without necessarily checking that it actually exists.
  // Optionally return false if it's known to not exist.
  bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
    assert(begin != end);
    node = static_cast<Node>(*begin);
    for (const WordIndex *i = begin + 1; i < end; ++i) {
      node = CombineWordHash(node, *i);
    }
    return true;
  }
};

// std::identity is an SGI extension :-(
struct IdentityHash : public std::unary_function<uint64_t, size_t> {
  size_t operator()(uint64_t arg) const { return static_cast<size_t>(arg); }
};

struct ProbingHashedSearch : public TemplateHashedSearch<
    util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, ProbBackoff>, IdentityHash>,
    util::ProbingHashTable<util::ByteAlignedPacking<uint64_t, Prob>, IdentityHash> > {

  static const ModelType kModelType = HASH_PROBING;
};

struct SortedHashedSearch : public TemplateHashedSearch<
    util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, ProbBackoff> >,
    util::SortedUniformMap<util::ByteAlignedPacking<uint64_t, Prob> > > {

  static const ModelType kModelType = HASH_SORTED;
};

} // namespace detail
} // namespace ngram
} // namespace lm

#endif // LM_SEARCH_HASHED__
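
The hashed search identifies an n-gram by a 64-bit key built incrementally from vocabulary indices, exactly as FastMakeNode does above; LookupMiddle then extends an existing key by one word without rehashing the whole context. A standalone sketch using the same multiply-xor constants (the word indices are made up):

#include <iostream>
#include <stdint.h>

typedef unsigned int WordIndex; // matches lm::WordIndex

// Same mixing function as in the header above.
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
  return (current * 8978948897894561157ULL) ^
         (static_cast<uint64_t>(next) * 17894857484156487943ULL);
}

int main() {
  // Hypothetical vocabulary indices for a trigram.
  const WordIndex ngram[] = {42, 7, 19};
  // As in FastMakeNode: the first word is the key itself and each further
  // context word is mixed in one at a time.
  uint64_t node = static_cast<uint64_t>(ngram[0]);
  for (int i = 1; i < 3; ++i) node = CombineWordHash(node, ngram[i]);
  std::cout << "hash table key: " << node << std::endl;
  return 0;
}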
2 changes: 1 addition & 1 deletion lm/ngram_trie.cc → lm/search_trie.cc
@@ -1,4 +1,4 @@
#include "lm/ngram_trie.hh"
#include "lm/search_trie.hh"

#include "lm/lm_exception.hh"
#include "lm/read_arpa.hh"
12 changes: 8 additions & 4 deletions lm/ngram_trie.hh → lm/search_trie.hh
@@ -1,5 +1,5 @@
-#ifndef LM_NGRAM_TRIE__
-#define LM_NGRAM_TRIE__
+#ifndef LM_SEARCH_TRIE__
+#define LM_SEARCH_TRIE__
 
 #include "lm/binary_format.hh"
 #include "lm/trie.hh"
@@ -56,6 +56,10 @@ struct TrieSearch {
     return mid.Find(word, prob, backoff, node);
   }
 
+  bool LookupMiddleNoProb(const Middle &mid, WordIndex word, float &backoff, Node &node) const {
+    return mid.FindNoProb(word, backoff, node);
+  }
+
   bool LookupLongest(WordIndex word, float &prob, const Node &node) const {
     return longest.Find(word, prob, node);
   }
@@ -66,7 +70,7 @@ struct TrieSearch {
     float ignored_prob, ignored_backoff;
     LookupUnigram(*begin, ignored_prob, ignored_backoff, node);
     for (const WordIndex *i = begin + 1; i < end; ++i) {
-      if (!LookupMiddle(middle[i - begin - 1], *i, ignored_prob, ignored_backoff, node)) return false;
+      if (!LookupMiddleNoProb(middle[i - begin - 1], *i, ignored_backoff, node)) return false;
     }
     return true;
   }
@@ -76,4 +80,4 @@ struct TrieSearch {
 } // namespace ngram
 } // namespace lm
 
-#endif // LM_NGRAM_TRIE__
+#endif // LM_SEARCH_TRIE__
15 changes: 15 additions & 0 deletions lm/trie.cc
@@ -123,6 +123,21 @@ bool BitPackedMiddle::Find(WordIndex word, float &prob, float &backoff, NodeRange
   return true;
 }
 
+bool BitPackedMiddle::FindNoProb(WordIndex word, float &backoff, NodeRange &range) const {
+  uint64_t at_pointer;
+  if (!FindBitPacked(base_, word_mask_, total_bits_, range.begin, range.end, word, at_pointer)) return false;
+  at_pointer *= total_bits_;
+  at_pointer += word_bits_;
+  at_pointer += prob_bits_;
+  backoff = util::ReadFloat32(base_ + (at_pointer >> 3), at_pointer & 7);
+  at_pointer += backoff_bits_;
+  range.begin = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+  // Read the next entry's pointer.
+  at_pointer += total_bits_;
+  range.end = util::ReadInt57(base_ + (at_pointer >> 3), at_pointer & 7, next_mask_);
+  return true;
+}
+
 void BitPackedMiddle::FinishedLoading(uint64_t next_end) {
   assert(next_end <= next_mask_);
   uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - next_bits_;
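
FindNoProb is where avoiding the prob decode actually saves work: trie records are bit-packed, so skipping the probability is a plain at_pointer += prob_bits_ with no ReadFloat32 call. Below is a sketch of the bit-addressing convention those reads use. It assumes, as kenlm's trie does, a little-endian machine and a buffer padded so an 8-byte load inside the last record is safe; util::ReadInt57's real implementation may differ.

#include <cstdio>
#include <cstring>
#include <stdint.h>

// Stand-in for util::ReadInt57: at_pointer counts bits from the start of
// the packed array, so (at_pointer >> 3) is the byte offset and
// (at_pointer & 7) the bit within that byte; mask keeps only the field.
// 57 bits is the widest field that still fits in one 64-bit load after
// up to 7 bits of shifting.
uint64_t ReadBits(const uint8_t *base, uint64_t at_pointer, uint64_t mask) {
  uint64_t word;
  std::memcpy(&word, base + (at_pointer >> 3), sizeof(word)); // unaligned-safe
  return (word >> (at_pointer & 7)) & mask; // little-endian assumption
}

int main() {
  uint8_t buffer[16] = {0};
  buffer[1] = 0xAB; // an 8-bit field starting at bit offset 8
  std::printf("%llx\n", (unsigned long long)ReadBits(buffer, 8, 0xFF)); // ab
  return 0;
}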
2 changes: 2 additions & 0 deletions lm/trie.hh
@@ -95,6 +95,8 @@ class BitPackedMiddle : public BitPacked {

   bool Find(WordIndex word, float &prob, float &backoff, NodeRange &range) const;
 
+  bool FindNoProb(WordIndex word, float &backoff, NodeRange &range) const;
+
   void FinishedLoading(uint64_t next_end);
 
  private:
4 changes: 2 additions & 2 deletions lm/vocab.cc
@@ -2,7 +2,7 @@

#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/ngram_config.hh"
#include "lm/config.hh"
#include "lm/weights.hh"
#include "util/exception.hh"
#include "util/joint_sort.hh"
@@ -121,7 +121,7 @@ void SortedVocabulary::FinishedLoading(ProbBackoff *reorder_vocab) {
   if (enumerate_) {
     util::PairedIterator<ProbBackoff*, std::string*> values(reorder_vocab + 1, &*strings_to_enumerate_.begin());
     util::JointSort(begin_, end_, values);
-    for (WordIndex i = 0; i < end_ - begin_; ++i) {
+    for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
       // <unk> strikes again: +1 here.
       enumerate_->Add(i + 1, strings_to_enumerate_[i]);
     }
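
The vocab.cc change is a warning fix riding along with the renames: end_ - begin_ is a signed ptrdiff_t while WordIndex is unsigned, so the bare comparison trips -Wsign-compare. A minimal reproduction of the pattern and the fix:

typedef unsigned int WordIndex; // an unsigned type, as in kenlm

int main() {
  int storage[4];
  int *begin_ = storage, *end_ = storage + 4;
  unsigned int count = 0;
  // Without the cast, `i < end_ - begin_` compares unsigned to signed and
  // g++ -Wsign-compare warns; the cast makes both sides WordIndex.
  for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) ++count;
  return count == 4 ? 0 : 1;
}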
