Quantization working if a little brittle
git-svn-id: file:///dev/shm/somefilter.svn@568 e102df66-1e2e-11dd-9b44-c24451a4db5e
kpu committed Jun 26, 2011
1 parent d7a7937 commit 7ff43dd
Showing 12 changed files with 151 additions and 48 deletions.
17 changes: 14 additions & 3 deletions lm/binary_format.cc
@@ -80,6 +80,14 @@ void WriteHeader(void *to, const Parameters &params) {

} // namespace

void SeekOrThrow(int fd, off_t off) {
if ((off_t)-1 == lseek(fd, off, SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed");
}

void AdvanceOrThrow(int fd, off_t off) {
if ((off_t)-1 == lseek(fd, off, SEEK_CUR)) UTIL_THROW(util::ErrnoException, "Seek failed");
}

uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing) {
if (config.write_mmap) {
std::size_t total = TotalHeaderSize(order) + memory_size;
@@ -156,7 +164,7 @@ bool IsBinaryFormat(int fd) {
}

void ReadHeader(int fd, Parameters &out) {
if ((off_t)-1 == lseek(fd, sizeof(Sanity), SEEK_SET)) UTIL_THROW(util::ErrnoException, "Seek failed in binary file");
SeekOrThrow(fd, sizeof(Sanity));
ReadLoop(fd, &out.fixed, sizeof(out.fixed));
if (out.fixed.probing_multiplier < 1.0)
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
@@ -173,6 +181,10 @@ void MatchCheck(ModelType model_type, const Parameters &params) {
}
}

void SeekPastHeader(int fd, const Parameters &params) {
SeekOrThrow(fd, TotalHeaderSize(params.counts.size()));
}

uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing) {
const off_t file_size = util::SizeFile(backing.file.get());
// The header is smaller than a page, so we have to map the whole header as well.
@@ -186,8 +198,7 @@ uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t
UTIL_THROW(FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary.");

if (config.enumerate_vocab) {
if ((off_t)-1 == lseek(backing.file.get(), total_map, SEEK_SET))
UTIL_THROW(util::ErrnoException, "Failed to seek in binary file to vocab words");
SeekOrThrow(backing.file.get(), total_map);
}
return reinterpret_cast<uint8_t*>(backing.search.get()) + TotalHeaderSize(params.counts.size());
}
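
SeekOrThrow and AdvanceOrThrow are thin wrappers over lseek(2) with SEEK_SET and SEEK_CUR respectively: they convert the (off_t)-1 error return into an exception so callers such as ReadHeader and SetupBinary no longer repeat the check inline. A minimal self-contained sketch of the same idiom, using std::runtime_error in place of KenLM's util::ErrnoException:

#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <stdexcept>
#include <string>

// Absolute seek: position fd at byte `off` from the start of the file, or throw.
void SeekOrThrowSketch(int fd, off_t off) {
  if ((off_t)-1 == lseek(fd, off, SEEK_SET))
    throw std::runtime_error(std::string("Seek failed: ") + std::strerror(errno));
}

// Relative seek: skip `off` bytes forward from the current position, or throw.
void AdvanceOrThrowSketch(int fd, off_t off) {
  if ((off_t)-1 == lseek(fd, off, SEEK_CUR))
    throw std::runtime_error(std::string("Seek failed: ") + std::strerror(errno));
}
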
10 changes: 9 additions & 1 deletion lm/binary_format.hh
@@ -16,7 +16,7 @@
namespace lm {
namespace ngram {

typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2} ModelType;
typedef enum {HASH_PROBING=0, HASH_SORTED=1, TRIE_SORTED=2, QUANT_TRIE_SORTED=3} ModelType;

/*Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
@@ -48,6 +48,10 @@ struct Backing {
util::scoped_memory search;
};

void SeekOrThrow(int fd, off_t off);
// Seek forward
void AdvanceOrThrow(int fd, off_t off);

// Create just enough of a binary file to write vocabulary to it.
uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_size, Backing &backing);
// Grow the binary file for the search data structure and set backing.search, returning the memory address where the search data structure should begin.
@@ -65,6 +69,8 @@ void ReadHeader(int fd, Parameters &params);

void MatchCheck(ModelType model_type, const Parameters &params);

void SeekPastHeader(int fd, const Parameters &params);

uint8_t *SetupBinary(const Config &config, const Parameters &params, std::size_t memory_size, Backing &backing);

void ComplainAboutARPA(const Config &config, ModelType model_type);
@@ -83,6 +89,8 @@ template <class To> void LoadLM(const char *file, const Config &config, To &to)
// Replace the run-time configured probing_multiplier with the one in the file.
Config new_config(config);
new_config.probing_multiplier = params.fixed.probing_multiplier;
detail::SeekPastHeader(backing.file.get(), params);
To::UpdateConfigFromBinary(backing.file.get(), params.counts, new_config);
std::size_t memory_size = To::Size(params.counts, new_config);
uint8_t *start = detail::SetupBinary(new_config, params, memory_size, backing);
to.InitializeFromBinary(start, params, new_config, backing.file.get());
3 changes: 2 additions & 1 deletion lm/blank.hh
@@ -22,7 +22,8 @@ namespace ngram {
*/
const float kNoExtensionBackoff = -0.0;
const float kExtensionBackoff = 0.0;
const uint32_t kNoExtensionQuant = 0;
const uint64_t kNoExtensionQuant = 0;
const uint64_t kExtensionQuant = 1;

inline void SetExtension(float &backoff) {
if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff;
59 changes: 46 additions & 13 deletions lm/build_binary.cc
@@ -20,17 +20,16 @@ void Usage(const char *name) {
"one.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
"type is one of probing, trie, or sorted:\n\n"
"type is either probing or trie:\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
"-t is the temporary directory prefix. Default is the output file name.\n"
"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n\n"
/*"sorted is like probing but uses a sorted uniform map instead of a hash table.\n"
"It uses more memory than trie and is also slower, so there's no real reason to\n"
"use it.\n\n"*/
"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
"-b sets backoff quantization bits. Requires -q and defaults to that value.\n\n"
"See http://kheafield.com/code/kenlm/benchmark/ for data structure benchmarks.\n"
"Passing only an input file will print memory usage of each data structure.\n"
"If the ARPA file does not have <unk>, -u sets <unk>'s probability; default 0.0.\n";
@@ -51,19 +50,31 @@ unsigned long int ParseUInt(const char *from) {
return ret;
}

uint8_t ParseBitCount(const char *from) {
unsigned long val = ParseUInt(from);
if (val > 25) {
util::ParseNumberException e(from);
e << " bit counts are limited to 25.";
throw e;
}
return val;
}

void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
std::size_t probing_size = ProbingModel::Size(counts, config);
// probing is always largest so use it to determine number of columns.
long int length = std::max<long int>(5, lrint(ceil(log10(probing_size))));
// probing is usually largest.
long int length = std::max<long int>(2, lrint(ceil(log10(probing_size / 1024))));
// but Quant could be bigger on very small models like the test.
length = std::max<long int>(length, lrint(ceil(log10(QuantTrieModel::Size(counts, config) / 1024))));
std::cout << "Memory estimate:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 5; ++i) std::cout << ' ';
std::cout << "bytes\n"
"probing " << std::setw(length) << probing_size << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << TrieModel::Size(counts, config) << "\n";
for (long int i = 0; i < length - 2; ++i) std::cout << ' ';
std::cout << "kB\n"
"probing " << std::setw(length) << (probing_size / 1024) << " assuming -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << (TrieModel::Size(counts, config) / 1024) << " without quantization\n"
"trie " << std::setw(length) << (QuantTrieModel::Size(counts, config) / 1024) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n";
}

} // namespace ngram
@@ -73,11 +84,21 @@
int main(int argc, char *argv[]) {
using namespace lm::ngram;

bool quantize = false, set_backoff_bits = false;
try {
lm::ngram::Config config;
int opt;
while ((opt = getopt(argc, argv, "siu:p:t:m:")) != -1) {
while ((opt = getopt(argc, argv, "siu:p:t:m:q:b:")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
if (!set_backoff_bits) config.backoff_bits = config.prob_bits;
quantize = true;
break;
case 'b':
config.backoff_bits = ParseBitCount(optarg);
set_backoff_bits = true;
break;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
@@ -100,6 +121,10 @@ int main(int argc, char *argv[]) {
Usage(argv[0]);
}
}
if (!quantize && set_backoff_bits) {
std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl;
abort();
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
} else if (optind + 2 == argc) {
@@ -110,9 +135,17 @@
const char *from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
if (!strcmp(model_type, "probing")) {
if (quantize) {
std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
abort();
}
ProbingModel(from_file, config);
} else if (!strcmp(model_type, "trie")) {
TrieModel(from_file, config);
if (quantize) {
QuantTrieModel(from_file, config);
} else {
TrieModel(from_file, config);
}
} else {
Usage(argv[0]);
}
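
With these changes, quantization is requested on the command line as, for example, "build_binary -q 8 -b 8 trie input.arpa output.binary" (with -b defaulting to the -q value). The same thing can be done programmatically through Config; a minimal sketch, assuming only the prob_bits, backoff_bits, and write_mmap fields used in the diff above and a well-formed ARPA file:

#include "lm/model.hh"

int main() {
  lm::ngram::Config config;
  config.prob_bits = 8;                     // equivalent to -q 8
  config.backoff_bits = 8;                  // equivalent to -b 8
  config.write_mmap = "quantized.binary";   // where the binary file is written
  // Constructing a QuantTrieModel from an ARPA file builds and writes the quantized trie.
  lm::ngram::QuantTrieModel model("input.arpa", config);
  return 0;
}
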
6 changes: 3 additions & 3 deletions lm/model.cc
@@ -224,9 +224,9 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret;
}

template class GenericModel<ProbingHashedSearch, ProbingVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>;
template class GenericModel<ProbingHashedSearch, ProbingVocabulary>; // HASH_PROBING
template class GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary>; // TRIE_SORTED
template class GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary>; // TRIE_SORTED_QUANT

} // namespace detail
} // namespace ngram
18 changes: 12 additions & 6 deletions lm/model.hh
@@ -71,9 +71,10 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
private:
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
// Get the size of memory that will be mapped given ngram counts. This
// does not include small non-mapped control structures, such as this class
// itself.
/* Get the size of memory that will be mapped given ngram counts. This
* does not include small non-mapped control structures, such as this class
* itself.
*/
static size_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

/* Load the model from a file. It may be an ARPA or binary file. Binary
@@ -112,6 +113,11 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
private:
friend void LoadLM<>(const char *file, const Config &config, GenericModel<Search, VocabularyT> &to);

static void UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
AdvanceOrThrow(fd, VocabularyT::Size(counts[0], config));
Search::UpdateConfigFromBinary(fd, counts, config);
}

float SlowBackoffLookup(const WordIndex *const context_rbegin, const WordIndex *const context_rend, unsigned char start) const;

FullScoreReturn ScoreExceptBackoff(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
@@ -140,15 +146,15 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod

// These must also be instantiated in the cc file.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel;
typedef detail::GenericModel<detail::ProbingHashedSearch, Vocabulary> ProbingModel; // HASH_PROBING
// Default implementation. No real reason for it to be the default.
typedef ProbingModel Model;

// Smaller implementation.
typedef ::lm::ngram::SortedVocabulary SortedVocabulary;
typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel;
typedef detail::GenericModel<trie::TrieSearch<DontQuantize>, SortedVocabulary> TrieModel; // TRIE_SORTED

typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantModel;
typedef detail::GenericModel<trie::TrieSearch<SeparatelyQuantize>, SortedVocabulary> QuantTrieModel; // QUANT_TRIE_SORTED

} // namespace ngram
} // namespace lm
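
Together with the LoadLM change in binary_format.hh above and SeparatelyQuantize::UpdateConfigFromBinary in quantize.cc below, the quantized binary path now recovers its bit widths from the file before anything is sized. A condensed, non-literal sketch of the call sequence for QuantTrieModel:

// LoadLM (binary_format.hh):
//   ReadHeader(fd, params);
//   Config new_config(config);
//   new_config.probing_multiplier = params.fixed.probing_multiplier;
//   SeekPastHeader(fd, params);                         // fd now sits just past the header
//   GenericModel::UpdateConfigFromBinary(fd, params.counts, new_config):   // model.hh, above
//     AdvanceOrThrow(fd, VocabularyT::Size(params.counts[0], new_config)); // skip the vocab table
//     Search::UpdateConfigFromBinary(fd, params.counts, new_config);
//       // for the quantized trie this presumably lands in
//       // SeparatelyQuantize::UpdateConfigFromBinary (quantize.cc, below),
//       // which reads three bytes: version, prob_bits, backoff_bits
//   memory_size = QuantTrieModel::Size(params.counts, new_config);  // uses the stored -q/-b values
//   start = SetupBinary(new_config, params, memory_size, backing);
//   to.InitializeFromBinary(start, params, new_config, fd);
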
5 changes: 4 additions & 1 deletion lm/model_test.cc
@@ -248,7 +248,7 @@ BOOST_AUTO_TEST_CASE(trie) {
}

BOOST_AUTO_TEST_CASE(quant) {
LoadingTest<QuantModel>();
LoadingTest<QuantTrieModel>();
}

template <class ModelT> void BinaryTest() {
@@ -279,6 +279,9 @@ BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
BinaryTest<QuantTrieModel>();
}

} // namespace
} // namespace ngram
15 changes: 13 additions & 2 deletions lm/quantize.cc
@@ -5,6 +5,8 @@
#include <algorithm>
#include <numeric>

#include <unistd.h>

namespace lm {
namespace ngram {

Expand All @@ -25,15 +27,24 @@ void MakeBins(float *values, float *values_end, float *centers, uint32_t bins) {
finish = values + (((values_end - values) * static_cast<uint64_t>(i + 1)) / bins);
if (finish == start) {
// zero length bucket.
*centers = i ? *(centers - 1) : 0;
*centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity();
} else {
*centers = std::accumulate(start, finish, 0.0) / static_cast<float>(finish - start);
}
}
}

const char kSeparatelyQuantizeVersion = 1;

} // namespace

void SeparatelyQuantize::UpdateConfigFromBinary(int fd, const std::vector<uint64_t> &counts, Config &config) {
char version;
if (read(fd, &version, 1) != 1 || read(fd, &config.prob_bits, 1) != 1 || read(fd, &config.backoff_bits, 1) != 1)
UTIL_THROW(util::ErrnoException, "Failed to read header for quantization.");
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
}

void SeparatelyQuantize::SetupMemory(void *start, const Config &config) {
// Reserve 8 byte header for bit counts.
start_ = reinterpret_cast<float*>(static_cast<uint8_t*>(start) + 8);
@@ -64,7 +75,7 @@ void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {

void SeparatelyQuantize::FinishedLoading(const Config &config) {
uint8_t *actual_base = reinterpret_cast<uint8_t*>(start_) - 8;
*(actual_base++) = 1; // version
*(actual_base++) = kSeparatelyQuantizeVersion; // version
*(actual_base++) = config.prob_bits;
*(actual_base++) = config.backoff_bits;
}
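
MakeBins performs equal-frequency (quantile) binning: with the values in sorted order, each of the `bins` buckets covers the same share of the values and its center is the mean of that bucket; a zero-length bucket inherits the previous center, or negative infinity for the very first bucket. A self-contained sketch of the same computation on a std::vector (illustration only; the real MakeBins works on raw float arrays supplied by the trainer, and this sketch sorts its own copy rather than assuming pre-sorted input):

#include <algorithm>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>

// Equal-frequency binning in the style of MakeBins: returns one center per bucket.
std::vector<float> MakeBinsSketch(std::vector<float> values, uint32_t bins) {
  std::sort(values.begin(), values.end());
  std::vector<float> centers(bins);
  const uint64_t n = values.size();
  const float *start, *finish = values.data();
  for (uint32_t i = 0; i < bins; ++i) {
    start = finish;
    finish = values.data() + (n * static_cast<uint64_t>(i + 1)) / bins;  // equal share per bucket
    if (finish == start) {
      // Zero-length bucket: reuse the previous center (or -inf for the first bucket).
      centers[i] = i ? centers[i - 1] : -std::numeric_limits<float>::infinity();
    } else {
      centers[i] = std::accumulate(start, finish, 0.0) / static_cast<float>(finish - start);
    }
  }
  return centers;
}

Encoding then presumably stores, for each probability or backoff, the index of its bucket in prob_bits or backoff_bits bits, and decoding looks the center back up by index.
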
(Diffs for the remaining 4 changed files are not shown.)
