Skip to content

Commit

Permalink
Allow special symbols to be mapped to whitespace with a warning
Browse files Browse the repository at this point in the history
  • Loading branch information
kpu committed Feb 11, 2014
1 parent d983a14 commit a13dc51
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 5 deletions.
21 changes: 18 additions & 3 deletions lm/builder/corpus_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,11 @@ std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return VocabHandout::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block)
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) {
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
disallowed_symbol_action_(disallowed_symbol) {
}

void CorpusCount::Run(const util::stream::ChainPosition &position) {
Expand All @@ -244,13 +245,27 @@ void CorpusCount::Run(const util::stream::ChainPosition &position) {
for (const char *i = kDelimiterSet; i < kDelimiterSet + sizeof(kDelimiterSet); ++i) {
delimiters[static_cast<unsigned char>(*i)] = true;
}
bool complained_about_disallowed = false;
try {
while(true) {
StringPiece line(from_.ReadLine());
writer.StartSentence();
for (util::TokenIter<util::BoolCharacter, true> w(line, delimiters); w; ++w) {
WordIndex word = vocab.Lookup(*w);
UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future.");
if (word <= 2) {
switch (disallowed_symbol_action_) {
case SILENT:
continue;
case COMPLAIN:
if (!complained_about_disallowed) {
std::cerr << "Warning: " << *w << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
complained_about_disallowed = true;
}
continue;
case THROW_UP:
UTIL_THROW(FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
}
}
writer.Append(word);
++count;
}
Expand Down
5 changes: 4 additions & 1 deletion lm/builder/corpus_count.hh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LM_BUILDER_CORPUS_COUNT__
#define LM_BUILDER_CORPUS_COUNT__

#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/scoped.hh"

Expand Down Expand Up @@ -28,7 +29,7 @@ class CorpusCount {

// token_count: out.
// type_count: in/out, aka vocabulary size. Initialize it to an estimate; it is set to the exact value.
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block);
CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block, WarningAction disallowed_symbol);

void Run(const util::stream::ChainPosition &position);

Expand All @@ -40,6 +41,8 @@ class CorpusCount {

std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_;

WarningAction disallowed_symbol_action_;
};

} // namespace builder
Expand Down
8 changes: 8 additions & 0 deletions lm/builder/lmplz_main.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "lm/builder/pipeline.hh"
#include "lm/lm_exception.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/usage.hh"
Expand Down Expand Up @@ -43,6 +44,7 @@ int main(int argc, char *argv[]) {
#endif
, "Order of the model")
("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)")
("skip_symbols", po::bool_switch(), "Treat <s>, </s>, and <unk> as whitespace instead of throwing an exception")
("temp_prefix,T", po::value<std::string>(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix")
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
Expand Down Expand Up @@ -101,6 +103,12 @@ int main(int argc, char *argv[]) {
return 1;
}

if (vm["skip_symbols"].as<bool>()) {
pipeline.disallowed_symbol_action = lm::COMPLAIN;
} else {
pipeline.disallowed_symbol_action = lm::THROW_UP;
}

util::NormalizeTempPrefix(pipeline.sort.temp_prefix);

lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
Expand Down
2 changes: 1 addition & 1 deletion lm/builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void CountText(int text_file /* input */, int vocab_file /* output */, Master &m
WordIndex type_count = config.vocab_estimate;
util::FilePiece text(text_file, NULL, &std::cerr);
text_file_name = text.FileName();
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize());
CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize(), config.disallowed_symbol_action);
chain >> boost::ref(counter);

util::stream::Sort<SuffixOrder, AddCombiner> sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner());
Expand Down
7 changes: 7 additions & 0 deletions lm/builder/pipeline.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "lm/builder/initial_probabilities.hh"
#include "lm/builder/header_info.hh"
#include "lm/lm_exception.hh"
#include "lm/word_index.hh"
#include "util/stream/config.hh"
#include "util/file_piece.hh"
Expand Down Expand Up @@ -42,6 +43,12 @@ struct PipelineConfig {
*/
uint64_t vocab_size_for_unk;

/* What to do the first time <s>, </s>, or <unk> appears in the input. If
* this is anything but THROW_UP, then the symbol will always be treated as
* whitespace.
*/
WarningAction disallowed_symbol_action;

/* Temporary file prefix; forwards the value held in sort.temp_prefix. */
const std::string &TempPrefix() const {
  return sort.temp_prefix;
}
/* Sorting memory budget; forwards the value held in sort.total_memory. */
std::size_t TotalMemory() const {
  return sort.total_memory;
}
};
Expand Down

0 comments on commit a13dc51

Please sign in to comment.