Refactor GrowableVocab to write file, ReadNGrams with iterator
kpu committed Apr 12, 2014
1 parent f0572c1 commit a803f73
Showing 5 changed files with 32 additions and 49 deletions.
47 changes: 8 additions & 39 deletions lm/builder/corpus_count.cc
@@ -38,41 +38,6 @@ struct VocabEntry {
};
#pragma pack(pop)

-const float kProbingMultiplier = 1.5;
-
-class VocabHandout {
-  public:
-    static std::size_t MemUsage(WordIndex initial_guess) {
-      return ngram::GrowableVocab::MemUsage(initial_guess);
-    }
-
-    VocabHandout(int fd, WordIndex initial_guess) :
-        backing_(initial_guess),
-        word_list_(fd) {
-      Lookup("<unk>"); // Force 0
-      Lookup("<s>"); // Force 1
-      Lookup("</s>"); // Force 2
-    }
-
-    WordIndex Lookup(const StringPiece &word) {
-      WordIndex old_size = backing_.Size();
-      WordIndex got = backing_.FindOrInsert(word);
-      if (got == old_size) {
-        word_list_ << word << '\0';
-      }
-      return got;
-    }
-
-    WordIndex Size() const {
-      return backing_.Size();
-    }
-
-  private:
-    ngram::GrowableVocab backing_;
-
-    util::FakeOFStream word_list_;
-};
-
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
@@ -109,6 +74,10 @@ struct DedupeEntry {
}
};


+// TODO: don't have this here, should be with probing hash table defaults?
+const float kProbingMultiplier = 1.5;

typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;

class Writer {
@@ -202,7 +171,7 @@ float CorpusCount::DedupeMultiplier(std::size_t order) {
}

std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
-  return VocabHandout::MemUsage(vocab_estimate);
+  return ngram::GrowableVocab::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
@@ -228,10 +197,10 @@ namespace {
} // namespace

void CorpusCount::Run(const util::stream::ChainPosition &position) {
-  VocabHandout vocab(vocab_write_, type_count_);
+  ngram::GrowableVocab vocab(type_count_, vocab_write_);
token_count_ = 0;
type_count_ = 0;
-  const WordIndex end_sentence = vocab.Lookup("</s>");
+  const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
bool delimiters[256];
@@ -241,7 +210,7 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
-        WordIndex word = vocab.Lookup(*w);
+        WordIndex word = vocab.FindOrInsert(*w);
if (word <= 2) {
ComplainDisallowed(*w, disallowed_symbol_action_);
continue;
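
The deleted VocabHandout wrapper existed only to force the three reserved ids and to mirror each newly inserted word to a file; both jobs now live in ngram::GrowableVocab itself. A minimal sketch of the new call pattern (illustrative, not part of the commit; the fd and size guess are made up):

    // Sketch only: what CorpusCount::Run's vocabulary handling reduces to.
    #include "lm/vocab.hh"
    #include "lm/word_index.hh"

    void VocabSketch(int vocab_fd) {
      // The constructor forces <unk>=0, <s>=1, </s>=2 and appends every newly
      // inserted word to vocab_fd as a '\0'-terminated string.
      lm::ngram::GrowableVocab vocab(4096 /* initial size guess */, vocab_fd);
      lm::WordIndex word = vocab.FindOrInsert("example");
      if (word <= 2) {
        // A reserved symbol appeared literally in the input; Run() reports
        // this through ComplainDisallowed.
      }
    }
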
7 changes: 4 additions & 3 deletions lm/read_arpa.hh
@@ -72,15 +72,16 @@ template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::siz
vocab.FinishedLoading(unigrams);
}

-template <class Voc, class Weights> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, WordIndex *const reverse_indices, Weights &weights, PositiveProbWarn &warn) {
+// Read ngram, write vocab ids to indices_out.
+template <class Voc, class Weights, class Iterator> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, Iterator indices_out, Weights &weights, PositiveProbWarn &warn) {
try {
weights.prob = f.ReadFloat();
if (weights.prob > 0.0) {
warn.Warn(weights.prob);
weights.prob = 0.0;
}
-    for (WordIndex *vocab_out = reverse_indices + n - 1; vocab_out >= reverse_indices; --vocab_out) {
-      *vocab_out = vocab.Index(f.ReadDelimited(kARPASpaces));
+    for (unsigned char i = 0; i < n; ++i, ++indices_out) {
+      *indices_out = vocab.Index(f.ReadDelimited(kARPASpaces));
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {
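
ReadNGram is now templated on the output iterator, so callers choose the storage order instead of being locked to the old reversed-pointer contract. A sketch of recovering the old layout under the new interface (illustrative; the helper name is made up):

    // Sketch only: reproducing the removed reverse_indices behavior.
    #include <iterator>
    #include <vector>
    #include "lm/read_arpa.hh"

    template <class Voc, class Weights> void ReadReversed(
        util::FilePiece &f, unsigned char n, const Voc &vocab,
        Weights &weights, lm::PositiveProbWarn &warn) {
      std::vector<lm::WordIndex> ids(n);
      // A reverse iterator anchored one past the end stores the first word
      // read at ids[n-1], exactly what the removed pointer loop did.
      lm::ReadNGram(f, n, vocab,
                    std::reverse_iterator<lm::WordIndex*>(&ids[0] + n),
                    weights, warn);
    }
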
2 changes: 1 addition & 1 deletion lm/search_hashed.cc
@@ -178,7 +178,7 @@ template <class Build, class Activate, class Store> void ReadNGrams(
typename Store::Entry entry;
std::vector<typename Value::Weights *> between;
for (size_t i = 0; i < count; ++i) {
-    ReadNGram(f, n, vocab, &*vocab_ids.begin(), entry.value, warn);
+    ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn);
build.SetRest(&*vocab_ids.begin(), n, entry.value);

keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);
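
vocab_ids used to be filled back-to-front through a raw pointer; passing rbegin() keeps the stored order identical under the iterator interface. A standalone illustration (not from the repository):

    // Sketch only: filling through rbegin() puts the first value written at
    // the back, so vocab_ids keeps its reversed layout.
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> ids(3);
      std::vector<int>::reverse_iterator out = ids.rbegin();
      for (int word = 1; word <= 3; ++word, ++out) *out = word;  // "reads" 1 2 3
      assert(ids[0] == 3 && ids[1] == 2 && ids[2] == 1);
      return 0;
    }
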
7 changes: 5 additions & 2 deletions lm/trie_sort.cc
@@ -16,6 +16,7 @@
#include <cstdio>
#include <cstdlib>
#include <deque>
+#include <iterator>
#include <limits>
#include <vector>

@@ -248,11 +249,13 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) {
for (; out != out_end; out += entry_size) {
-          ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<Prob*>(out + words_size), warn);
+          std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
+          ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn);
}
} else {
for (; out != out_end; out += entry_size) {
-          ReadNGram(f, order, vocab, reinterpret_cast<WordIndex*>(out), *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
+          std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
+          ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
}
}
// Sort full records by full n-gram.
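
Here the word ids live in raw record memory rather than a vector, so there is no rbegin() to call; the reverse iterator is built by hand from one past the record's last word. A standalone sketch of that construction (types and names are stand-ins):

    // Sketch only: a reverse_iterator built from (base + order) dereferences
    // to base[order - 1] first, so words land back-to-front in the record.
    #include <iterator>

    typedef unsigned int WordIndex;  // stand-in for lm::WordIndex

    void FillReversed(WordIndex *base, unsigned char order) {
      std::reverse_iterator<WordIndex*> it(base + order);
      for (unsigned char i = 0; i < order; ++i, ++it) {
        *it = i;  // stand-in for vocab.Index(...); stored at base[order - 1 - i]
      }
    }
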
18 changes: 14 additions & 4 deletions lm/vocab.hh
@@ -4,6 +4,7 @@
#include "lm/enumerate_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/virtual_interface.hh"
#include "util/fake_ofstream.hh"
#include "util/murmur_hash.hh"
#include "util/pool.hh"
#include "util/probing_hash_table.hh"
@@ -187,19 +188,26 @@ class GrowableVocab {
return Lookup::MemUsage(content > 2 ? content : 2);
}

-    explicit GrowableVocab(WordIndex initial_size) : lookup_(initial_size) {}
+    // Does not take ownership of write_words_fd.
+    explicit GrowableVocab(WordIndex initial_size, int write_words_fd)
+      : lookup_(initial_size), word_list_(write_words_fd) {
+      FindOrInsert("<unk>"); // Force 0
+      FindOrInsert("<s>"); // Force 1
+      FindOrInsert("</s>"); // Force 2
+    }

WordIndex Index(const StringPiece &str) const {
Lookup::ConstIterator i;
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
}

+    // To see if it was just added, compare with Size() before insertion.
WordIndex FindOrInsert(const StringPiece &word) {
ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
Lookup::MutableIterator it;
-      lookup_.FindOrInsert(entry, it);
-      UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
+      if (lookup_.FindOrInsert(entry, it)) {
+        word_list_ << word << '\0';
+        UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh");
+      }
return it->value;
}

@@ -209,6 +217,8 @@ class GrowableVocab {
typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;

Lookup lookup_;

+    util::FakeOFStream word_list_;
};

} // namespace ngram
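
FindOrInsert now streams each genuinely new word to word_list_ as it receives the next free id, so the file holds the vocabulary in id order, '\0'-delimited, starting with <unk>, <s>, </s>. A sketch of reading that list back (the helper is hypothetical and assumes the descriptor is positioned at the start; SizeOrThrow and ReadOrThrow are from util/file.hh):

    // Sketch only: recovering words from the '\0'-delimited list written by
    // GrowableVocab. Word i in file order has vocabulary id i.
    #include <string>
    #include <vector>
    #include "util/file.hh"

    std::vector<std::string> ReadWordList(int fd) {
      std::string buf(util::SizeOrThrow(fd), '\0');
      util::ReadOrThrow(fd, &buf[0], buf.size());
      std::vector<std::string> words;
      for (std::size_t pos = 0; pos < buf.size();) {
        std::size_t end = buf.find('\0', pos);
        words.push_back(buf.substr(pos, end - pos));
        pos = end + 1;
      }
      return words;
    }
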
