Skip to content

Commit

Permalink
[cudadecoder] Expose API to wait for separate streams (kaldi-asr#4681)
Browse files Browse the repository at this point in the history
  • Loading branch information
nshmyrev authored Jan 15, 2022
1 parent df1e911 commit 7460d99
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/cudadecoder/cuda-online-pipeline-dynamic-batcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ CudaOnlinePipelineDynamicBatcher::CudaOnlinePipelineDynamicBatcher(

batcher_thread_.reset(new std::thread(
&CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop, this));

n_chunks_per_corr_.reserve(num_channels_);
}

CudaOnlinePipelineDynamicBatcher::~CudaOnlinePipelineDynamicBatcher() {
Expand All @@ -64,6 +66,7 @@ void CudaOnlinePipelineDynamicBatcher::Push(
backlog_.push_back(
{corr_id, is_first_chunk, is_last_chunk, std::move(wave_samples)});
}
++n_chunks_per_corr_[corr_id];
n_chunks_not_done_.fetch_add(1, std::memory_order_release);
}

Expand Down Expand Up @@ -144,9 +147,20 @@ void CudaOnlinePipelineDynamicBatcher::BatcherThreadLoop() {
curr_batch_->corr_ids, curr_batch_->h_all_waveform,
curr_batch_->n_samples_valid, curr_batch_->is_first_chunk,
curr_batch_->is_last_chunk);
n_chunks_not_done_.fetch_sub(curr_batch_->Size(),
std::memory_order_release);

{
// Update counts
std::lock_guard<std::mutex> lk(next_batch_and_backlog_m_);
n_chunks_not_done_.fetch_sub(curr_batch_->Size(),
std::memory_order_release);
for (size_t i = 0; i < curr_batch_->corr_ids.size(); ++i) {
CorrelationID corr_id = curr_batch_->corr_ids[i];
--n_chunks_per_corr_[corr_id];
if (curr_batch_->is_last_chunk[i]) {
n_chunks_per_corr_.erase(corr_id);
}
}
}
curr_batch_->Clear();
}

Expand All @@ -167,5 +181,12 @@ void CudaOnlinePipelineDynamicBatcher::WaitForCompletion() {
cuda_pipeline_.WaitForLatticeCallbacks();
}

int CudaOnlinePipelineDynamicBatcher::GetNumPendingChunks(CorrelationID corr_id) {
std::lock_guard<std::mutex> lk(next_batch_and_backlog_m_);
if (n_chunks_per_corr_.find(corr_id) == n_chunks_per_corr_.end())
return 0;
return n_chunks_per_corr_[corr_id];
}

} // namespace cuda_decoder
} // namespace kaldi
6 changes: 6 additions & 0 deletions src/cudadecoder/cuda-online-pipeline-dynamic-batcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ class CudaOnlinePipelineDynamicBatcher {
// return
void Push(CorrelationID corr_id, bool is_first_chunk, bool is_last_chunk,
const SubVector<BaseFloat> &wave_samples);

// Wait for completion of the submitted chunks
void WaitForCompletion();
// Get the number of unprocessed chunks for poll-like processing
int GetNumPendingChunks(CorrelationID corr_id);

private:
// Batches created by this Batcher
Expand Down Expand Up @@ -125,7 +129,9 @@ class CudaOnlinePipelineDynamicBatcher {

std::vector<const std::string *> partial_hypotheses_;
std::vector<bool> end_points_;

std::atomic<std::uint32_t> n_chunks_not_done_;
std::unordered_map<CorrelationID, int> n_chunks_per_corr_;

int max_batch_size_;
int num_channels_;
Expand Down

0 comments on commit 7460d99

Please sign in to comment.