Skip to content

Commit

Permalink
Vocabulary pad
Browse files Browse the repository at this point in the history
  • Loading branch information
kpu committed Feb 6, 2014
1 parent 99203ab commit 70d48ae
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 6 deletions.
5 changes: 3 additions & 2 deletions lm/builder/interpolate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ class Callback {
};
} // namespace

// Sets uniform_prob_ to 1 / (unigram_count - 1) and keeps the supplied backoff
// chain positions.
Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs)
  : uniform_prob_(1.0 / static_cast<float>(unigram_count - 1)),
    backoffs_(backoffs) {
}
// vocab_size is the denominator of the uniform probability; it includes <unk>
// but excludes <s>.  The backoff chain positions are stored as given.
Interpolate::Interpolate(uint64_t vocab_size, const ChainPositions &backoffs)
  : uniform_prob_(1.0 / static_cast<float>(vocab_size)),
    backoffs_(backoffs) {
}

// perform order-wise interpolation
void Interpolate::Run(const ChainPositions &positions) {
Expand Down
4 changes: 3 additions & 1 deletion lm/builder/interpolate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ namespace lm { namespace builder {
*/
class Interpolate {
public:
explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs);
// Normally the unigram count-1 (since p(<s>) = 0) but might be larger to
// set a consistent vocabulary size.
explicit Interpolate(uint64_t vocab_size, const ChainPositions &backoffs);

void Run(const ChainPositions &positions);

Expand Down
10 changes: 8 additions & 2 deletions lm/builder/lmplz_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ int main(int argc, char *argv[]) {
("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory")
("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow")
("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)")
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("block_count", po::value<std::size_t>(&pipeline.block_count)->default_value(2), "Block count (per order)")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file")
("vocab_estimate", po::value<lm::WordIndex>(&pipeline.vocab_estimate)->default_value(1000000), "Assume this vocabulary size for purposes of calculating memory in step 1 (corpus count) and pre-sizing the hash table")
("vocab_file", po::value<std::string>(&pipeline.vocab_file)->default_value(""), "Location to write a file containing the unique vocabulary strings delimited by null bytes")
/* The bound field is declared uint64_t, so the option's value type must be uint64_t as well: std::size_t is a different type on some platforms and the pointer would not convert. */
("vocab_pad", po::value<uint64_t>(&pipeline.vocab_size_for_unk)->default_value(0), "If the vocabulary is smaller than this value, pad with <unk> to reach this size. Requires --interpolate_unigrams")
("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc.")
("text", po::value<std::string>(&text), "Read text from a file instead of stdin")
("arpa", po::value<std::string>(&arpa), "Write ARPA to a file instead of stdout");
Expand Down Expand Up @@ -95,6 +96,11 @@ int main(int argc, char *argv[]) {
}
#endif

// --vocab_pad is registered with a default value, so vm.count("vocab_pad") is
// non-zero even when the flag was never given.  Only reject the combination
// when the user supplied --vocab_pad explicitly.
if (!vm["vocab_pad"].defaulted() && !pipeline.initial_probs.interpolate_unigrams) {
  std::cerr << "--vocab_pad requires --interpolate_unigrams" << std::endl;
  return 1;
}

util::NormalizeTempPrefix(pipeline.sort.temp_prefix);

lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs;
Expand Down
2 changes: 1 addition & 1 deletion lm/builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ void InterpolateProbabilities(const std::vector<uint64_t> &counts, Master &maste
gamma_chains.push_back(read_backoffs);
gamma_chains.back() >> gammas[i].Source();
}
master >> Interpolate(counts[0], ChainPositions(gamma_chains));
master >> Interpolate(std::max(master.Config().vocab_size_for_unk, counts[0] - 1 /* <s> is not included */), ChainPositions(gamma_chains));
gamma_chains >> util::stream::kRecycle;
master.BufferFinal(counts);
}
Expand Down
12 changes: 12 additions & 0 deletions lm/builder/pipeline.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ struct PipelineConfig {
// Number of blocks to use. This will be overridden to 1 if everything fits.
std::size_t block_count;

/* Computing the perplexity of LMs with different vocabularies is hard. For
* example, the lowest perplexity is attained by a unigram model that
* predicts p(<unk>) = 1 and has no other vocabulary. Also, linearly
* interpolated models will sum to more than 1 because <unk> is duplicated
* (SRI just pretends p(<unk>) = 0 for these purposes, which makes it sum to
* 1 but comes with its own problems). This option will make the vocabulary
* a particular size by replicating <unk> multiple times for purposes of
* computing vocabulary size. It has no effect if the actual vocabulary is
* larger. This parameter serves the same purpose as IRSTLM's "dub".
*/
uint64_t vocab_size_for_unk;

// Shorthand for sort.temp_prefix, returned by const reference.
const std::string &TempPrefix() const {
  return sort.temp_prefix;
}
// Shorthand for sort.total_memory.
std::size_t TotalMemory() const {
  return sort.total_memory;
}
};
Expand Down

0 comments on commit 70d48ae

Please sign in to comment.