From a13dc51751b5d7469ae7253cd775021db6fba781 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Tue, 11 Feb 2014 14:08:15 -0800 Subject: [PATCH] Allow special symbols to be mapped to whitespace with a warning --- lm/builder/corpus_count.cc | 21 ++++++++++++++++++--- lm/builder/corpus_count.hh | 5 ++++- lm/builder/lmplz_main.cc | 8 ++++++++ lm/builder/pipeline.cc | 2 +- lm/builder/pipeline.hh | 7 +++++++ 5 files changed, 38 insertions(+), 5 deletions(-) diff --git a/lm/builder/corpus_count.cc b/lm/builder/corpus_count.cc index ccc06efca..8b9b879fc 100644 --- a/lm/builder/corpus_count.cc +++ b/lm/builder/corpus_count.cc @@ -223,10 +223,11 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) { return VocabHandout::MemUsage(vocab_estimate); } -CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol) : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), - dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)), + disallowed_symbol_action_(disallowed_symbol) { } void CorpusCount::Run(const util::stream::ChainPosition &position) { @@ -244,13 +245,27 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) { for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) { delimiters[static_cast(*i)] = true; } + bool complained_about_disallowed = false; try { while(true) { StringPiece line(from_.ReadLine()); writer.StartSentence(); for (util::TokenIter w(line, delimiters); w; ++w) { WordIndex word = vocab.Lookup(*w); - UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing in the future."); + if (word <= 2) { + switch (disallowed_symbol_action_) { + case SILENT: + continue; + case COMPLAIN: + if (!complained_about_disallowed) { + std::cerr << "Warning: " << *w << " appears in the input. All instances of , , and will be interpreted as whitespace." << std::endl; + complained_about_disallowed = true; + } + continue; + case THROW_UP: + UTIL_THROW(FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing in the future. Pass --skip_symbols to convert these symbols to whitespace."); + } + } writer.Append(word); ++count; } diff --git a/lm/builder/corpus_count.hh b/lm/builder/corpus_count.hh index aa0ed8ede..17fc7dbcc 100644 --- a/lm/builder/corpus_count.hh +++ b/lm/builder/corpus_count.hh @@ -1,6 +1,7 @@ #ifndef LM_BUILDER_CORPUS_COUNT__ #define LM_BUILDER_CORPUS_COUNT__ +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/scoped.hh" @@ -28,7 +29,7 @@ class CorpusCount { // token_count: out. // type_count aka vocabulary size. Initialize to an estimate. It is set to the exact value. - CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol); void Run(const util::stream::ChainPosition &position); @@ -40,6 +41,8 @@ class CorpusCount { std::size_t dedupe_mem_size_; util::scoped_malloc dedupe_mem_; + + WarningAction disallowed_symbol_action_; }; } // namespace builder diff --git a/lm/builder/lmplz_main.cc b/lm/builder/lmplz_main.cc index 98c8d38f0..fe30d1931 100644 --- a/lm/builder/lmplz_main.cc +++ b/lm/builder/lmplz_main.cc @@ -1,4 +1,5 @@ #include "lm/builder/pipeline.hh" +#include "lm/lm_exception.hh" #include "util/file.hh" #include "util/file_piece.hh" #include "util/usage.hh" @@ -43,6 +44,7 @@ int main(int argc, char *argv[]) { #endif , "Order of the model") ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("skip_symbols", po::bool_switch(), "Treat , , and as whitespace instead of throwing an exception") ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") @@ -101,6 +103,12 @@ int main(int argc, char *argv[]) { return 1; } + if (vm["skip_symbols"].as()) { + pipeline.disallowed_symbol_action = lm::COMPLAIN; + } else { + pipeline.disallowed_symbol_action = lm::THROW_UP; + } + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc index 0788335d2..3354add8a 100644 --- a/lm/builder/pipeline.cc +++ b/lm/builder/pipeline.cc @@ -221,7 +221,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m WordIndex type_count = config.vocab_estimate; util::FilePiece text(text_file, NULL, &std::cerr); text_file_name = text.FileName(); - CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action); chain >> boost::ref(counter); util::stream::Sort sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); diff --git a/lm/builder/pipeline.hh b/lm/builder/pipeline.hh index 4f3211e73..05ef8b2bf 100644 --- a/lm/builder/pipeline.hh +++ b/lm/builder/pipeline.hh @@ -3,6 +3,7 @@ #include "lm/builder/initial_probabilities.hh" #include "lm/builder/header_info.hh" +#include "lm/lm_exception.hh" #include "lm/word_index.hh" #include "util/stream/config.hh" #include "util/file_piece.hh" @@ -42,6 +43,12 @@ struct PipelineConfig { */ uint64_t vocab_size_for_unk; + /* What to do the first time , , or appears in the input. If + * this is anything but THROW_UP, then the symbol will always be treated as + * whitespace. + */ + WarningAction disallowed_symbol_action; + const std::string &TempPrefix() const { return sort.temp_prefix; } std::size_t TotalMemory() const { return sort.total_memory; } };