Keep models in RAM by default
kpu committed Feb 9, 2012
1 parent 78fa5fb commit 4f2aa45
Showing 9 changed files with 44 additions and 17 deletions.
lm/binary_format.cc (26 additions, 14 deletions)

@@ -103,32 +103,44 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
       throw e;
     }
 
+    if (config.write_method == Config::WRITE_AFTER) {
+      util::MapAnonymous(memory_size, backing.search);
+      return reinterpret_cast<uint8_t*>(backing.search.get());
+    }
     // mmap it now.
     // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down.
     std::size_t page_size = util::SizePage();
     std::size_t alignment_cruft = adjusted_vocab % page_size;
     backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED);
 
     return reinterpret_cast<uint8_t*>(backing.search.get()) + alignment_cruft;
   } else {
     util::MapAnonymous(memory_size, backing.search);
     return reinterpret_cast<uint8_t*>(backing.search.get());
   }
 }
 
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing) {
-  if (config.write_mmap) {
-    util::SyncOrThrow(backing.search.get(), backing.search.size());
-    util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
-    // header and vocab share the same mmap. The header is written here because we know the counts.
-    Parameters params = Parameters();
-    params.counts = counts;
-    params.fixed.order = counts.size();
-    params.fixed.probing_multiplier = config.probing_multiplier;
-    params.fixed.model_type = model_type;
-    params.fixed.has_vocabulary = config.include_vocab;
-    params.fixed.search_version = search_version;
-    WriteHeader(backing.vocab.get(), params);
-  }
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing) {
+  if (!config.write_mmap) return;
+  util::SyncOrThrow(backing.vocab.get(), backing.vocab.size());
+  switch (config.write_method) {
+    case Config::WRITE_MMAP:
+      util::SyncOrThrow(backing.search.get(), backing.search.size());
+      break;
+    case Config::WRITE_AFTER:
+      util::SeekOrThrow(backing.file.get(), backing.vocab.size() + vocab_pad);
+      util::WriteOrThrow(backing.file.get(), backing.search.get(), backing.search.size());
+      util::FSyncOrThrow(backing.file.get());
+      break;
+  }
+  // header and vocab share the same mmap. The header is written here because we know the counts.
+  Parameters params = Parameters();
+  params.counts = counts;
+  params.fixed.order = counts.size();
+  params.fixed.probing_multiplier = config.probing_multiplier;
+  params.fixed.model_type = model_type;
+  params.fixed.has_vocabulary = config.include_vocab;
+  params.fixed.search_version = search_version;
+  WriteHeader(backing.vocab.get(), params);
 }
 
 namespace detail {
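The hunk above carries the commit's headline behavior: under Config::WRITE_AFTER (the new default), GrowForSearch hands back plain anonymous memory, so the search structure is built entirely in RAM, and FinishFile later seeks past the header and vocabulary, writes the finished structure in one pass, and fsyncs. A minimal standalone sketch of that flush step, using the util wrappers the diff relies on (the helper name FlushSearch and its parameters are illustrative, not part of the commit):

#include <cstddef>
#include <cstdint>
#include "util/file.hh"

// Sketch: write an in-RAM search structure into an already-sized binary file.
// search_offset corresponds to backing.vocab.size() + vocab_pad above;
// the header and vocab share that leading region of the file.
void FlushSearch(int fd, uint64_t search_offset, const void *search, std::size_t size) {
  util::SeekOrThrow(fd, search_offset);  // skip over header, vocab, and any <unk> padding
  util::WriteOrThrow(fd, search, size);  // one sequential write instead of dirtying mmapped pages
  util::FSyncOrThrow(fd);                // flush to disk before the header declares the file complete
}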
lm/binary_format.hh (1 addition, 1 deletion)

@@ -58,7 +58,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t
 
 // Write header to binary file. This is done last to prevent incomplete files
 // from loading.
-void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, Backing &backing);
+void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts, std::size_t vocab_pad, Backing &backing);
 
 namespace detail {
lm/config.cc (1 addition, 0 deletions)

@@ -17,6 +17,7 @@ Config::Config() :
   temporary_directory_prefix(NULL),
   arpa_complain(ALL),
   write_mmap(NULL),
+  write_method(WRITE_AFTER),
   include_vocab(true),
   prob_bits(8),
   backoff_bits(8),
lm/config.hh (6 additions, 0 deletions)

@@ -73,6 +73,12 @@ struct Config {
   // Include the vocab in the binary file? Only effective if write_mmap != NULL.
   bool include_vocab;
 
+  typedef enum {
+    WRITE_MMAP, // Map the file directly.
+    WRITE_AFTER // Write after we're done.
+  } WriteMethod;
+  WriteMethod write_method;
+
   // Quantization options. Only effective for QuantTrieModel. One value is
   // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used
   // to quantize (and one of the remaining backoffs will be 0).
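For callers, the practical effect of the new Config field is that binary building no longer happens inside the output file's mapping by default; anyone who wants the old behavior sets WRITE_MMAP explicitly. A minimal usage sketch against the lm::ngram API as of this commit (file names are placeholders):

#include "lm/model.hh"

int main() {
  lm::ngram::Config config;
  config.write_mmap = "model.binary";                   // request a binary file
  config.write_method = lm::ngram::Config::WRITE_MMAP;  // opt back into building inside the file-backed mmap
  lm::ngram::ProbingModel model("model.arpa", config);  // with the default WRITE_AFTER, the model is built in RAM and written at the end
  return 0;
}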
lm/model.cc (1 addition, 1 deletion)

@@ -82,7 +82,7 @@ template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT
       search_.unigram.Unknown().backoff = 0.0;
       search_.unigram.Unknown().prob = config.unknown_missing_logprob;
     }
-    FinishFile(config, kModelType, kVersion, counts, backing_);
+    FinishFile(config, kModelType, kVersion, counts, vocab_.UnkCountChangePadding(), backing_);
   } catch (util::Exception &e) {
     e << " Byte: " << f.Offset();
     throw;
lm/search_hashed.cc (1 addition, 1 deletion)

@@ -149,7 +149,7 @@ template <class MiddleT, class LongestT> uint8_t *TemplateHashedSearch<MiddleT,
 
 template <class MiddleT, class LongestT> template <class Voc> void TemplateHashedSearch<MiddleT, LongestT>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, Voc &vocab, Backing &backing) {
   // TODO: fix sorted.
-  SetupMemory(GrowForSearch(config, 0, Size(counts, config), backing), counts, config);
+  SetupMemory(GrowForSearch(config, vocab.UnkCountChangePadding(), Size(counts, config), backing), counts, config);
 
   PositiveProbWarn warn(config.positive_log_probability);
lm/vocab.hh (2 additions, 0 deletions)

@@ -143,6 +143,8 @@ class ProbingVocabulary : public base::Vocabulary {
 
     void FinishedLoading(ProbBackoff *reorder_vocab);
 
+    std::size_t UnkCountChangePadding() const { return 0; }
+
    bool SawUnk() const { return saw_unk_; }
 
     void LoadedBinary(bool have_words, int fd, EnumerateVocab *to);
util/file.cc (4 additions, 0 deletions)

@@ -99,6 +99,10 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) {
   }
 }
 
+void FSyncOrThrow(int fd) {
+  UTIL_THROW_IF(-1 == fsync(fd), ErrnoException, "Sync of " << fd << " failed.");
+}
+
 namespace {
 void InternalSeek(int fd, off_t off, int whence) {
   UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed");
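FSyncOrThrow follows the same convention as the other wrappers in util/file.cc: call the POSIX primitive and turn a -1 return into an ErrnoException via UTIL_THROW_IF. Purely to illustrate that pattern, here is a hypothetical wrapper (not part of this commit) built the same way:

#include <sys/types.h>
#include <unistd.h>
#include "util/exception.hh"

// Hypothetical example of the util wrapper convention; not in this commit.
void FTruncateOrThrow(int fd, off_t size) {
  UTIL_THROW_IF(-1 == ftruncate(fd, size), util::ErrnoException,
                "Truncate of " << fd << " to " << size << " failed.");
}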
util/file.hh (2 additions, 0 deletions)

@@ -78,6 +78,8 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount);
 
 void WriteOrThrow(int fd, const void *data_void, std::size_t size);
 
+void FSyncOrThrow(int fd);
+
 // Seeking
 void SeekOrThrow(int fd, uint64_t off);
 void AdvanceOrThrow(int fd, int64_t off);
