Skip to content

Commit

Permalink
Merge branch 'INSTX-1190-enforce-sink-moves' into 'master'
Browse files Browse the repository at this point in the history
Enforce that values passed to a sink are mutable rvalues

See merge request machine-learning/dorado!436
  • Loading branch information
blawrence-ont committed Aug 23, 2023
2 parents 8aa7722 + 6c9e736 commit 829dba3
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 32 deletions.
2 changes: 1 addition & 1 deletion dorado/read_pipeline/BaseSpaceDuplexCallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void BaseSpaceDuplexCallerNode::basespace(const std::string& template_read_id,
duplex_read->read_id = template_read->read_id + ";" + complement_read->read_id;
duplex_read->read_tag = template_read->read_tag;

send_message_to_sink(duplex_read);
send_message_to_sink(std::move(duplex_read));
}
edlibFreeAlignResult(result);
}
Expand Down
10 changes: 5 additions & 5 deletions dorado/read_pipeline/DuplexReadTaggingNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ void DuplexReadTaggingNode::worker_thread() {
// the ones whose duplex offsprings never came. They are retagged to not be
// duplex parents and then sent downstream.
if (!read->is_duplex && !read->is_duplex_parent) {
send_message_to_sink(read);
send_message_to_sink(std::move(read));
} else if (read->is_duplex) {
std::string template_read_id = read->read_id.substr(0, read->read_id.find(';'));
std::string complement_read_id =
read->read_id.substr(read->read_id.find(';') + 1, read->read_id.length());

send_message_to_sink(read);
send_message_to_sink(std::move(read));

for (auto& rid : {template_read_id, complement_read_id}) {
if (m_parents_processed.find(rid) != m_parents_processed.end()) {
Expand All @@ -61,7 +61,7 @@ void DuplexReadTaggingNode::worker_thread() {
if (find_parent != m_duplex_parents.end()) {
// Parent read has been seen. Process it and send it
// downstream.
send_message_to_sink(find_parent->second);
send_message_to_sink(std::move(find_parent->second));
m_parents_processed.insert(rid);
m_duplex_parents.erase(find_parent);
} else {
Expand All @@ -76,8 +76,8 @@ void DuplexReadTaggingNode::worker_thread() {
// If a read is in the parents wanted list, then sent it downstream
// and add it to the set of processed reads. It will also be removed
// from the parent reads being looked for.
send_message_to_sink(read);
m_parents_processed.insert(read->read_id);
send_message_to_sink(std::move(read));
m_parents_wanted.erase(find_read);
} else {
// No duplex offspring is seen so far, so hold it and track
Expand All @@ -89,7 +89,7 @@ void DuplexReadTaggingNode::worker_thread() {

for (auto& [k, v] : m_duplex_parents) {
v->is_duplex_parent = false;
send_message_to_sink(v);
send_message_to_sink(std::move(v));
}
}

Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/FakeDataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ void FakeDataLoader::load_reads(const int num_reads) {
fake_read->raw_data = torch::randint(0, 10000, {read_size}, torch::kInt16);
fake_read->read_id = "Placeholder-read-id";

m_pipeline.push_message(fake_read);
m_pipeline.push_message(std::move(fake_read));
}
}

Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/ModBaseCallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void ModBaseCallerNode::input_worker_thread() {
++m_working_reads_size;
} else {
// No modbases to call, pass directly to next node
send_message_to_sink(read);
send_message_to_sink(std::move(read));
++m_num_non_mod_base_reads_pushed;
}
break;
Expand Down
18 changes: 9 additions & 9 deletions dorado/read_pipeline/PairingNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ void PairingNode::pair_generating_worker_thread(int tid) {
std::unique_lock<std::mutex> lock(m_pairing_mtx);
auto flush_message = std::get<CacheFlushMessage>(message);
auto& read_cache = m_read_caches[flush_message.client_id];
for (const auto& [key, reads_list] : read_cache.channel_mux_read_map) {
for (auto& [key, reads_list] : read_cache.channel_mux_read_map) {
// kv is a std::pair<UniquePoreIdentifierKey, std::list<std::shared_ptr<Read>>>
for (const auto& read_ptr : reads_list) {
for (auto& read_ptr : reads_list) {
// Push each read message
send_message_to_sink(std::move(read_ptr));
}
Expand Down Expand Up @@ -378,8 +378,8 @@ void PairingNode::pair_generating_worker_thread(int tid) {
ok_to_clear = true;
}
if (ok_to_clear) {
send_message_to_sink(std::move(*to_clear_itr));
to_clear_itr = m_reads_to_clear.erase(to_clear_itr);
auto read_handle = m_reads_to_clear.extract(*to_clear_itr++);
send_message_to_sink(std::move(read_handle.value()));
} else {
++to_clear_itr;
}
Expand All @@ -391,14 +391,14 @@ void PairingNode::pair_generating_worker_thread(int tid) {
std::unique_lock<std::mutex> lock(m_pairing_mtx);
// There are still reads in channel_mux_read_map. Push them to the sink.
// Last thread alive is responsible for cleaning up the cache.
for (const auto& [client_id, read_cache] : m_read_caches) {
for (const auto& kv : read_cache.channel_mux_read_map) {
for (auto& [client_id, read_cache] : m_read_caches) {
for (auto& kv : read_cache.channel_mux_read_map) {
// kv is a std::pair<UniquePoreIdentifierKey, std::list<std::shared_ptr<Read>>>
const auto& reads_list = kv.second;
auto& reads_list = kv.second;

for (const auto& read_ptr : reads_list) {
for (auto& read_ptr : reads_list) {
// Push each read message
send_message_to_sink(read_ptr);
send_message_to_sink(std::move(read_ptr));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/ReadFilterNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void ReadFilterNode::worker_thread() {
(m_read_ids_to_filter.find(read->read_id) != m_read_ids_to_filter.end())) {
log_filtering();
} else {
send_message_to_sink(read);
send_message_to_sink(std::move(read));
}
}
}
Expand Down
6 changes: 1 addition & 5 deletions dorado/read_pipeline/ReadPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ float Read::calculate_mean_qscore() const {

MessageSink::MessageSink(size_t max_messages) : m_work_queue(max_messages) {}

void MessageSink::push_message(Message &&message) {
void MessageSink::push_message_internal(Message &&message) {
const auto status = m_work_queue.try_push(std::move(message));
// try_push will fail if the sink has been told to terminate.
// We do not expect to be pushing reads from this source if that is the case.
Expand Down Expand Up @@ -368,10 +368,6 @@ Pipeline::Pipeline(PipelineDescriptor &&descriptor,

void MessageSink::add_sink(MessageSink &sink) { m_sinks.push_back(std::ref(sink)); }

void MessageSink::send_message_to_sink(int sink_index, Message &&message) {
m_sinks.at(sink_index).get().push_message(std::move(message));
}

void Pipeline::push_message(Message &&message) {
assert(!m_nodes.empty());
const auto source_node_index = m_source_to_sink_order.front();
Expand Down
20 changes: 15 additions & 5 deletions dorado/read_pipeline/ReadPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ class MessageSink {
}

// Adds a message to the input queue. This can block if the sink's queue is full.
// Pushed messages must be rvalues: the input queue takes ownership.
void push_message(Message&& message);
template <typename Msg>
void push_message(Msg&& msg) {
static_assert(!std::is_reference_v<Msg> && !std::is_const_v<Msg>,
"Pushed messages must be rvalues: the sink takes ownership");
push_message_internal(Message(std::move(msg)));
}

// Waits until work is finished and shuts down worker threads.
// No work can be done by the node after this returns until
Expand All @@ -201,12 +205,16 @@ class MessageSink {
void restart_input_queue() { m_work_queue.restart(); }

// Sends message to the designated sink.
void send_message_to_sink(int sink_index, Message&& message);
template <typename Msg>
void send_message_to_sink(int sink_index, Msg&& message) {
m_sinks.at(sink_index).get().push_message(std::forward<Msg>(message));
}

// Version for nodes with a single sink that is implicit.
void send_message_to_sink(Message&& message) {
template <typename Msg>
void send_message_to_sink(Msg&& message) {
assert(m_sinks.size() == 1);
send_message_to_sink(0, std::move(message));
send_message_to_sink(0, std::forward<Msg>(message));
}

// Pops the next input message, returning true on success.
Expand All @@ -225,6 +233,8 @@ class MessageSink {

friend class Pipeline;
void add_sink(MessageSink& sink);

void push_message_internal(Message&& message);
};

// Object from which a Pipeline is created.
Expand Down
2 changes: 1 addition & 1 deletion dorado/read_pipeline/ScalerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void ScalerNode::worker_thread() {
read->num_trimmed_samples = trim_start;

// Pass the read to the next node
send_message_to_sink(read);
send_message_to_sink(std::move(read));
}
}

Expand Down
3 changes: 2 additions & 1 deletion dorado/read_pipeline/StereoDuplexEncoderNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ void StereoDuplexEncoderNode::worker_thread() {
read_pair->read_1, read_pair->read_2, read_pair->read_1_start,
read_pair->read_1_end, read_pair->read_2_start, read_pair->read_2_end);

send_message_to_sink(stereo_encoded_read); // Stereo-encoded read created, send it to sink
send_message_to_sink(
std::move(stereo_encoded_read)); // Stereo-encoded read created, send it to sink
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/ReadFilterNodeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ TEST_CASE("ReadFilterNode: Filter read based on read name", TEST_GROUP) {
read_2->attributes.fast5_filename = "batch_0.fast5";

dorado::ReadFilterNode filter(sink, 0 /*min_qscore*/, 0, {"read_2"}, 2 /*threads*/);
filter.push_message(read_1);
filter.push_message(read_2);
filter.push_message(std::move(read_1));
filter.push_message(std::move(read_2));
}

auto messages = sink.get_messages();
Expand Down

0 comments on commit 829dba3

Please sign in to comment.