Skip to content

Commit

Permalink
Make CorpusCount support pre-defined vocabularies
Browse files Browse the repository at this point in the history
  • Loading branch information
kpu committed Jul 26, 2016
1 parent dc1db45 commit dcad0cd
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 19 deletions.
80 changes: 65 additions & 15 deletions lm/builder/corpus_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,31 +158,81 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
// Counts n-grams from the text in `from`, deduplicating within a block.
// vocab_write: file descriptor for the vocabulary. When dynamic_vocab is true
//   it receives newly discovered words (growable vocab); when false it is read
//   as a precompiled hash table of word ids (see VocabGiven in the .cc).
// token_count: out — total tokens seen.
// type_count: in/out — initial estimate on entry, exact vocabulary size on exit.
// prune_words / prune_vocab_filename: pruning configuration carried through to Run.
// entries_per_block: sizes the per-block dedupe hash table (dedupe_mem_ below).
// disallowed_symbol: what to do when <s>, </s>, or <unk> appears in the input.
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), dynamic_vocab_(dynamic_vocab), token_count_(token_count), type_count_(type_count),
prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
disallowed_symbol_action_(disallowed_symbol) {
}

namespace {
void ComplainDisallowed(StringPiece word, WarningAction &action) {
switch (action) {
case SILENT:
return;
case COMPLAIN:
std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
action = SILENT;
return;
case THROW_UP:
UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
}
// React to a reserved symbol (<s>, </s>, <unk>) appearing in the corpus,
// according to the configured policy. COMPLAIN degrades to SILENT after the
// first warning so the message is printed at most once.
void ComplainDisallowed(StringPiece word, WarningAction &action) {
  if (action == SILENT) return;
  if (action == COMPLAIN) {
    // Warn once, then suppress further warnings for this run.
    std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
    action = SILENT;
  } else if (action == THROW_UP) {
    UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
  }
}

// Vocab ids are given in a precompiled hash table.
// Read-only adapter that presents the same interface as GrowableVocab
// (FindOrInsert / Index / Size / IsSpecial) so CorpusCount::RunWithVocab can
// be templated over either. Unknown words map to id 0 (<unk>).
class VocabGiven {
public:
// Memory-maps the whole table file from fd.
// Layout assumption: a uint64_t entry count, followed by the probing hash
// table itself — TODO confirm against the writer of this file format.
explicit VocabGiven(int fd) {
util::MapRead(util::POPULATE_OR_READ, fd, 0, util::CheckOverflow(util::SizeOrThrow(fd)), table_backing_);
// Leave space for header with size.
table_ = Table(static_cast<char*>(table_backing_.get()) + sizeof(uint64_t), table_backing_.size() - sizeof(uint64_t));
bos_ = FindOrInsert("<s>");
eos_ = FindOrInsert("</s>");
}

// Despite the name (kept for interface compatibility with GrowableVocab),
// this never inserts: lookups miss to 0, the <unk> id.
WordIndex FindOrInsert(const StringPiece &word) const {
Table::ConstIterator it;
if (table_.Find(util::MurmurHash64A(word.data(), word.size()), it)) {
return it->value;
} else {
return 0; // <unk>.
}
}

// Pure lookup; identical to FindOrInsert since the table is fixed.
WordIndex Index(const StringPiece &word) const {
return FindOrInsert(word);
}

// Vocabulary size, read from the uint64_t header at the front of the mapping.
WordIndex Size() const {
return *static_cast<const uint64_t*>(table_backing_.get());
}

// True for the reserved words: <unk> (0), <s>, and </s>.
bool IsSpecial(WordIndex word) const {
return word == 0 || word == bos_ || word == eos_;
}

private:
// Owns the mapped file; table_ points into it and must not outlive it.
util::scoped_memory table_backing_;

typedef util::ProbingHashTable<ngram::ProbingVocabularyEntry, util::IdentityHash> Table;
Table table_;

// Ids of <s> and </s>, cached at construction for IsSpecial.
WordIndex bos_, eos_;
};
} // namespace

// Entry point for the chain: select the vocabulary implementation and hand
// off to the templated worker. (The displayed span interleaved the stale
// pre-diff GrowableVocab construction with the new body; only the new-side
// code belongs here — the duplicate, shadowed construction is removed.)
void CorpusCount::Run(const util::stream::ChainPosition &position) {
  if (dynamic_vocab_) {
    // Growable vocabulary: type_count_ seeds the initial table size and new
    // words are written to vocab_write_ as they are discovered.
    ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
    RunWithVocab(position, vocab);
  } else {
    // Fixed vocabulary: vocab_write_ is read as a precompiled hash table.
    VocabGiven vocab(vocab_write_);
    RunWithVocab(position, vocab);
  }
}

template <class Vocab> void CorpusCount::RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab) {
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Expand All @@ -195,7 +245,7 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
writer.StartSentence();
while (from_.ReadWordSameLine(w, delimiters)) {
WordIndex word = vocab.FindOrInsert(w);
if (word <= 2) {
if (UTIL_UNLIKELY(vocab.IsSpecial(word))) {
ComplainDisallowed(w, disallowed_symbol_action_);
continue;
}
Expand Down
5 changes: 4 additions & 1 deletion lm/builder/corpus_count.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ class CorpusCount {

// token_count: out.
// type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol);
CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab, uint64_t &token_count, WordIndex &type_count, std::vector<bool> &prune_words, const std::string& prune_vocab_filename, std::size_t entries_per_block, WarningAction disallowed_symbol);

void Run(const util::stream::ChainPosition &position);

private:
template <class Vocab> void RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab);

util::FilePiece &from_;
int vocab_write_;
bool dynamic_vocab_;
uint64_t &token_count_;
WordIndex &type_count_;
std::vector<bool>& prune_words_;
Expand Down
2 changes: 1 addition & 1 deletion lm/builder/corpus_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ BOOST_AUTO_TEST_CASE(Short) {
uint64_t token_count;
WordIndex type_count = 10;
std::vector<bool> prune_words;
CorpusCount counter(input_piece, vocab.get(), token_count, type_count, prune_words, "", chain.BlockSize() / chain.EntrySize(), SILENT);
CorpusCount counter(input_piece, vocab.get(), true, token_count, type_count, prune_words, "", chain.BlockSize() / chain.EntrySize(), SILENT);
chain >> boost::ref(counter);
NGramStream<BuildingPayload> stream(chain.Add());
chain >> util::stream::kRecycle;
Expand Down
2 changes: 1 addition & 1 deletion lm/builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ util::stream::Sort<SuffixOrder, CombineCounts> *CountText(int text_file /* input
type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
CorpusCount counter(text, vocab_file, true, token_count, type_count, prune_words, config.prune_vocab_file, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);

util::scoped_ptr<util::stream::Sort<SuffixOrder, CombineCounts> > sorter(new util::stream::Sort<SuffixOrder, CombineCounts>(chain, config.sort, SuffixOrder(config.order), CombineCounts()));
Expand Down
6 changes: 5 additions & 1 deletion lm/vocab.hh
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
return Lookup::MemUsage(content > 2 ? content : 2);
}

// Does not take ownership of write_wordi
// Does not take ownership of new_word_construct
template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
: lookup_(initial_size), new_word_(new_word_construct) {
FindOrInsert("<unk>"); // Force 0
Expand All @@ -265,6 +265,10 @@ template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {

WordIndex Size() const { return lookup_.Size(); }

// True for the reserved low ids. The constructor forces <unk> to id 0;
// ids 1 and 2 are presumably <s> and </s> inserted next — TODO confirm in the
// elided constructor body. Mirrors VocabGiven::IsSpecial in corpus_count.cc.
bool IsSpecial(WordIndex word) const {
return word <= 2;
}

private:
typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;

Expand Down

0 comments on commit dcad0cd

Please sign in to comment.