Skip to content

Commit

Permalink
Change FileBuffer::Sink() to not recycle as well
Browse files Browse the repository at this point in the history
  • Loading branch information
kpu committed Feb 10, 2016
1 parent d34d11c commit ee798b9
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 21 deletions.
4 changes: 2 additions & 2 deletions lm/builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Master {
template <class Compare> void SetupSorts(Sorts<Compare> &sorts, bool exclude_unigrams) {
sorts.Init(config_.order - exclude_unigrams);
// Unigrams don't get sorted because their order is always the same.
if (exclude_unigrams) chains_[0] >> unigrams_.Sink();
if (exclude_unigrams) chains_[0] >> unigrams_.Sink() >> util::stream::kRecycle;
for (std::size_t i = exclude_unigrams; i < config_.order; ++i) {
sorts.push_back(chains_[i], config_.sort, Compare(i + 1));
}
Expand Down Expand Up @@ -255,7 +255,7 @@ void InitialProbabilities(const std::vector<uint64_t> &counts, const std::vector
gammas.Init(config.order - 1);
for (std::size_t i = 1; i < config.order; ++i) {
gammas.push_back(util::MakeTemp(config.TempPrefix()));
gamma_chains[i] >> gammas[i - 1].Sink();
gamma_chains[i] >> gammas[i - 1].Sink() >> util::stream::kRecycle;
}
// Has to be done here due to gamma_chains scope.
master.SetupSorts(primary, true);
Expand Down
2 changes: 2 additions & 0 deletions lm/interpolate/tune_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,14 @@ class ExtensionsFirstIteration {
assert(!backoffs.empty());
uint8_t max_order = backoffs.front().cols();
for (util::stream::Stream stream(position); stream; ++stream) {
std::cerr << "ApplyBackoffs running" << std::endl;
InitialExtension &ini = *reinterpret_cast<InitialExtension*>(stream.Get());
assert(ini.order > 1); // If it's an extension, it should be higher than a unigram.
if (ini.order != max_order) {
ini.ext.ln_prob += backoffs[ini.ext.instance](ini.ext.model, ini.order - 1);
}
}
std::cerr << "ApplyBackoffs finished." << std::endl;
}

private:
Expand Down
33 changes: 32 additions & 1 deletion lm/interpolate/tune_instance_test.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include "lm/interpolate/tune_instance.hh"

#include "util/file_stream.hh"
#include "util/file.hh"
#include "util/file_stream.hh"
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
#include "util/stream/typed_stream.hh"
#include "util/string_piece.hh"

#define BOOST_TEST_MODULE InstanceTest
Expand Down Expand Up @@ -76,6 +79,34 @@ BOOST_AUTO_TEST_CASE(Toy) {
BOOST_CHECK_CLOSE(0.0, inst.LNBackoffs(1)(0), 0.001);
BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, inst.LNBackoffs(1)(1), 0.001);

util::stream::Chain extensions(util::stream::ChainConfig(inst.ReadExtensionsEntrySize(), 2, 300));
inst.ReadExtensions(extensions);
std::cerr << "About to construct stream." << std::endl;
try {
util::stream::TypedStream<Extension> stream(extensions.Add());
std::cerr << "Constructed stream." << std::endl;
extensions >> util::stream::kRecycle;
std::cerr << "Added recycling." << std::endl;

// The extensions are
// <s> a
// <s> b
// <s> c
// c </s>

BOOST_REQUIRE(stream);
BOOST_REQUIRE(++stream);
BOOST_REQUIRE(++stream);
BOOST_REQUIRE(++stream);
BOOST_REQUIRE(++stream);
BOOST_REQUIRE(++stream);
BOOST_CHECK(!++stream);

} catch (const std::exception &e) {
std::cerr << "Fail on adding recycling." << std::endl;
}


/*
// Three extensions: a, b, c
BOOST_REQUIRE_EQUAL(3, instances[0].ln_extensions.rows());
Expand Down
5 changes: 0 additions & 5 deletions util/stream/chain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ Chain &Chain::operator>>(const WriteAndRecycle &writer) {
return *this;
}

Chain &Chain::operator>>(const PWriteAndRecycle &writer) {
threads_.push_back(new Thread(Complete(), writer));
return *this;
}

void Chain::Wait(bool release_memory) {
if (queues_.empty()) {
assert(threads_.empty());
Expand Down
10 changes: 4 additions & 6 deletions util/stream/chain.hh
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class Thread {
* This method is called automatically by this class's @ref Thread() "constructor".
*/
template <class Position, class Worker> void operator()(const Position &position, Worker &worker) {
// try {
try {
worker.Run(position);
// } catch (const std::exception &e) {
// UnhandledException(e);
// }
} catch (const std::exception &e) {
UnhandledException(e);
}
}

private:
Expand All @@ -103,7 +103,6 @@ class Recycler {

extern const Recycler kRecycle;
class WriteAndRecycle;
class PWriteAndRecycle;

/**
* Represents a sequence of workers, through which @ref Block "blocks" can pass.
Expand Down Expand Up @@ -217,7 +216,6 @@ class Chain {
* and runs that worker in a new Thread owned by this chain.
*/
Chain &operator>>(const WriteAndRecycle &writer);
Chain &operator>>(const PWriteAndRecycle &writer);

// Chains are reusable. Call Wait to wait for everything to finish and free memory.
void Wait(bool release_memory = true);
Expand Down
4 changes: 1 addition & 3 deletions util/stream/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,11 @@ void WriteAndRecycle::Run(const ChainPosition &position) {
}
}

void PWriteAndRecycle::Run(const ChainPosition &position) {
const std::size_t block_size = position.GetChain().BlockSize();
void PWrite::Run(const ChainPosition &position) {
uint64_t offset = 0;
for (Link link(position); link; ++link) {
ErsatzPWrite(file_, link->Get(), link->ValidSize(), offset);
offset += link->ValidSize();
link->SetValidSize(block_size);
}
// Trim file to size.
util::ResizeOrThrow(file_, offset);
Expand Down
8 changes: 4 additions & 4 deletions util/stream/io.hh
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class WriteAndRecycle {
int file_;
};

class PWriteAndRecycle {
class PWrite {
public:
explicit PWriteAndRecycle(int fd) : file_(fd) {}
explicit PWrite(int fd) : file_(fd) {}
void Run(const ChainPosition &position);
private:
int file_;
Expand All @@ -65,9 +65,9 @@ class FileBuffer {
public:
explicit FileBuffer(int fd) : file_(fd) {}

PWriteAndRecycle Sink() const {
PWrite Sink() const {
util::SeekOrThrow(file_.get(), 0);
return PWriteAndRecycle(file_.get());
return PWrite(file_.get());
}

PRead Source(bool discard = false) {
Expand Down

0 comments on commit ee798b9

Please sign in to comment.