Skip to content

Commit

Permalink
Get hallucinated probabilities to the quantizer.
Browse files Browse the repository at this point in the history
git-svn-id: file:///dev/shm/somefilter.svn@649 e102df66-1e2e-11dd-9b44-c24451a4db5e
  • Loading branch information
kpu committed Sep 19, 2011
1 parent 701473e commit 36ba938
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 33 deletions.
75 changes: 51 additions & 24 deletions lm/search_trie.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ int Compare(unsigned char order, const void *first_void, const void *second_void
return 0;
}

struct ProbPointer {
unsigned char array;
uint64_t index;
};

// Array of n-grams and float indices.
class BackoffMessages {
public:
Expand All @@ -57,17 +62,17 @@ class BackoffMessages {
entry_size_ = entry_size;
}

void Add(const WordIndex *to, uint64_t index) {
void Add(const WordIndex *to, ProbPointer index) {
while (current_ + entry_size_ > allocated_) {
std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get();
Resize(std::max<std::size_t>(allocated_size * 2, entry_size_));
}
memcpy(current_, to, entry_size_ - sizeof(uint64_t));
*reinterpret_cast<uint64_t*>(current_ + entry_size_ - sizeof(uint64_t)) = index;
memcpy(current_, to, entry_size_ - sizeof(ProbPointer));
*reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index;
current_ += entry_size_;
}

void Apply(float *base, FILE *unigrams) {
void Apply(float *const *const base, FILE *unigrams) {
FinishedAdding();
if (current_ == allocated_) return;
rewind(unigrams);
Expand All @@ -84,14 +89,15 @@ class BackoffMessages {
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
WriteOrThrow(unigrams, &weights, sizeof(weights));
}
base[*reinterpret_cast<const uint64_t*>(current_ + sizeof(WordIndex))] += weights.backoff;
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
}
}

void Apply(float *base, RecordReader &reader) {
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
const unsigned char order = (entry_size_ - sizeof(uint64_t)) / sizeof(WordIndex);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
switch (Compare(order, reader.Data(), current_)) {
case -1:
Expand All @@ -107,7 +113,8 @@ class BackoffMessages {
backoff = kExtensionBackoff;
reader.Overwrite(&backoff, sizeof(float));
} else {
base[*reinterpret_cast<const uint64_t*>(current_ + entry_size_ - sizeof(uint64_t))] += backoff;
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer));
base[write_to.array][write_to.index] += backoff;
}
current_ += entry_size_;
break;
Expand Down Expand Up @@ -141,29 +148,46 @@ class SRISucks {
public:
SRISucks() {
for (BackoffMessages *i = messages_; i != messages_ + kMaxOrder - 1; ++i)
i->Init(sizeof(uint64_t) + sizeof(WordIndex) * (i - messages_ + 1));
i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
}

void Send(unsigned char begin, unsigned char end, const WordIndex *to, float prob_basis) {
void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) {
assert(prob_basis != kBadProb);
for (unsigned char i = begin; i < end; ++i) {
messages_[i - 1].Add(to, values_.size());
ProbPointer pointer;
pointer.array = order - 1;
pointer.index = values_[order - 1].size();
for (unsigned char i = begin; i < order; ++i) {
messages_[i - 1].Add(to, pointer);
}
values_.push_back(prob_basis);
values_[order - 1].push_back(prob_basis);
}

void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
messages_[0].Apply(&*values_.begin(), unigram_file);
for (unsigned char i = 0; i < kMaxOrder - 1; ++i) {
it_[i] = &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
BackoffMessages *messages = messages_ + 1;
const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */;
for (; reader != end; ++messages, ++reader) {
messages->Apply(&*values_.begin(), *reader);
messages->Apply(it_, *reader);
}
}

float GetBlankProb(unsigned char order) {
return *(it_[order - 1]++);
}

const std::vector<float> &Values(unsigned char order) const {
return values_[order - 1];
}

private:
std::vector<float> values_;
// This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[kMaxOrder - 1];
BackoffMessages messages_[kMaxOrder - 1];

float *it_[kMaxOrder - 1];
};

class FindBlanks {
Expand Down Expand Up @@ -208,12 +232,14 @@ class FindBlanks {
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order) :
WriteEntries(RecordReader *contexts, UnigramValue *unigrams, BitPackedMiddle<typename Quant::Middle, Bhiksha> *middle, BitPackedLongest<typename Quant::Longest> &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)), order_(order) {}
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}

float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; }

Expand All @@ -222,7 +248,7 @@ template <class Quant, class Bhiksha> class WriteEntries {
}

void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_base) {
middle_[order - 2].Insert(indices[order - 1], kBlankProb, kBlankBackoff);
middle_[order - 2].Insert(indices[order - 1], sri_.GetBlankProb(order), kBlankBackoff);
}

void Middle(const unsigned char order, const void *data) {
Expand Down Expand Up @@ -250,6 +276,7 @@ template <class Quant, class Bhiksha> class WriteEntries {
BitPackedLongest<typename Quant::Longest> &longest_;
BitPacked &bigram_pack_;
const unsigned char order_;
SRISucks &sri_;
};

struct Gram {
Expand Down Expand Up @@ -354,9 +381,9 @@ void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<u
}
}

template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs, backoffs;
probs.reserve(count);
template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs(additional), backoffs;
probs.reserve(count + additional.size());
backoffs.reserve(count);
for (reader.Rewind(); reader; ++reader) {
const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
Expand Down Expand Up @@ -444,7 +471,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
if (Quant::kTrain) {
util::ErsatzProgress progress(config.messages, "Quantizing", std::accumulate(counts.begin() + 1, counts.end(), 0));
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], inputs[i-2], progress, quant);
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
quant.FinishedLoading(config);
Expand All @@ -459,7 +486,7 @@ template <class Quant, class Bhiksha> void BuildTrie(const std::string &file_pre
}
// Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size());
WriteEntries<Quant, Bhiksha> writer(contexts, unigrams, out.middle_begin_, out.longest, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer);
}

Expand Down
19 changes: 14 additions & 5 deletions lm/trie_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ FILE *OpenOrThrow(const char *name, const char *mode) {
if (!ret) UTIL_THROW(util::ErrnoException, "Could not open " << name << " for " << mode);
return ret;
}

void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}

namespace {

typedef util::SizedIterator NGramIter;
Expand Down Expand Up @@ -72,11 +78,6 @@ class PartialViewProxy {

typedef util::ProxyIterator<PartialViewProxy> PartialIter;

void WriteOrThrow(FILE *to, const void *data, size_t size) {
assert(size);
if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size);
}

std::string DiskFlush(const void *mem_begin, const void *mem_end, const std::string &file_prefix, std::size_t batch, unsigned char order, std::size_t weights_size) {
std::stringstream assembled;
assembled << file_prefix << static_cast<unsigned int>(order) << '_' << batch;
Expand Down Expand Up @@ -216,6 +217,14 @@ void RecordReader::Init(const std::string &name, std::size_t entry_size) {
++*this;
}

void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_.get(), internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
WriteOrThrow(file_.get(), start, amount);
long forward = entry_size_ - internal - amount;
if (forward) UTIL_THROW_IF(fseek(file_.get(), forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
}

void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
PositiveProbWarn warn(config.positive_log_probability);
{
Expand Down
10 changes: 6 additions & 4 deletions lm/trie_sort.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class Config;

namespace trie {

extern const char *kContextSuffix;
FILE *OpenOrThrow(const char *name, const char *mode);
void WriteOrThrow(FILE *to, const void *data, size_t size);

class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}
Expand Down Expand Up @@ -69,6 +73,8 @@ class RecordReader {

std::size_t EntrySize() const { return entry_size_; }

void Overwrite(const void *start, std::size_t amount);

private:
util::scoped_malloc data_;

Expand All @@ -79,10 +85,6 @@ class RecordReader {
util::scoped_FILE file_;
};

extern const char *kContextSuffix;

FILE *OpenOrThrow(const char *name, const char *mode);

void ARPAToSortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);

} // namespace trie
Expand Down

0 comments on commit 36ba938

Please sign in to comment.