Disable pre-fetching when using queue policy (triton-inference-server…#237)

* Disable pre-fetching when using queue policy

* Address review comments

* Fix the line spacing
tanmayv25 authored Aug 4, 2023
1 parent 4215fdc commit 66c61d2
Showing 7 changed files with 202 additions and 17 deletions.
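At a high level, this commit stops the dynamic batcher from pre-fetching payloads whenever the priority queue uses a queue policy (a timeout, a per-request timeout override, or a bounded queue size), and instead builds a payload only once a runner is actually waiting for one. The sketch below is a simplified, standalone restatement of that decision flow; the parameter names mirror the functions changed in this commit, but the function itself is illustrative and not the real RateLimiter implementation.

```cpp
#include <cstddef>

// Illustrative sketch of the new slot-availability decision (not Triton code).
// With pre-fetching enabled, up to 2 * <instance count> payloads may be built
// ahead of time; with pre-fetching disabled, a "slot" only exists while a
// consumer (runner) has registered a pending dequeue.
bool PayloadSlotAvailableSketch(
    std::size_t queued_payloads, std::size_t instance_count,
    bool support_prefetching, std::size_t waiting_consumers,
    bool force_non_blocking)
{
  if (support_prefetching) {
    return queued_payloads < 2 * instance_count;
  }
  if (force_non_blocking) {
    // Enqueue path: report availability without blocking.
    return waiting_consumers > 0;
  }
  // Batcher-thread path: the real code blocks in WaitForConsumer() until a
  // runner registers; here we simply report the condition it waits for.
  return waiting_consumers > 0;
}
```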
10 changes: 7 additions & 3 deletions src/dynamic_batch_scheduler.cc
@@ -253,8 +253,11 @@ DynamicBatchScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
// equal to next preferred batch size, then wake batcher up to service
// this request. We do the actual wake outside of the lock to avoid
// having the woken thread immediately block on the lock
wake_batcher =
model_->Server()->GetRateLimiter()->PayloadSlotAvailable(model_);
// Explicitly force non-blocking to prevent waiting for the slot to
// be available.
wake_batcher = model_->Server()->GetRateLimiter()->PayloadSlotAvailable(
model_, model_instance_, queue_.SupportPrefetching(),
true /*force_non_blocking*/);

// We may wake up runner less often if we don't enforce equal shape
// within a batch, otherwise must always wake up runner to check it
@@ -313,7 +316,8 @@ DynamicBatchScheduler::BatcherThread(const int nice)
}

auto wait_for_slots = [this]() {
return model_->Server()->GetRateLimiter()->PayloadSlotAvailable(model_);
return model_->Server()->GetRateLimiter()->PayloadSlotAvailable(
model_, model_instance_, queue_.SupportPrefetching());
};
const uint64_t default_wait_microseconds = 500 * 1000;

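The two call sites above differ on purpose: Enqueue() runs on the request thread and must not stall, so it forces a non-blocking check, while the batcher thread can afford to wait for an idle runner inside wait_for_slots. Below is a minimal sketch (assumed, not the actual BatcherThread loop) of how such a predicate is typically combined with a timed condition-variable wait.

```cpp
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>

// Assumed stand-in for the batcher loop: wake up when requests arrive or
// after default_wait_microseconds, and only proceed to form a payload once
// the slot predicate (which may itself wait for an idle runner) passes.
void BatcherWaitSketch(
    std::mutex& mu, std::condition_variable& cv, bool& has_requests,
    const std::function<bool()>& wait_for_slots)
{
  const uint64_t default_wait_microseconds = 500 * 1000;
  std::unique_lock<std::mutex> lock(mu);
  cv.wait_for(
      lock, std::chrono::microseconds(default_wait_microseconds),
      [&]() { return has_requests && wait_for_slots(); });
  // ... form and dispatch the next payload here ...
}
```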
38 changes: 37 additions & 1 deletion src/instance_queue.cc
@@ -31,7 +31,8 @@
namespace triton { namespace core {

InstanceQueue::InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
: max_batch_size_(max_batch_size), max_queue_delay_ns_(max_queue_delay_ns)
: max_batch_size_(max_batch_size), max_queue_delay_ns_(max_queue_delay_ns),
waiting_consumer_count_(0)
{
}

@@ -96,4 +97,39 @@ InstanceQueue::Dequeue(
}
}

void
InstanceQueue::IncrementConsumerCount()
{
{
std::lock_guard<std::mutex> lock(waiting_consumer_mu_);
waiting_consumer_count_++;
}
waiting_consumer_cv_.notify_one();
}

void
InstanceQueue::DecrementConsumerCount()
{
{
std::lock_guard<std::mutex> lock(waiting_consumer_mu_);
waiting_consumer_count_--;
}
waiting_consumer_cv_.notify_one();
}

void
InstanceQueue::WaitForConsumer()
{
std::unique_lock<std::mutex> lock(waiting_consumer_mu_);
waiting_consumer_cv_.wait(
lock, [this]() { return waiting_consumer_count_ > 0; });
}

int
InstanceQueue::WaitingConsumerCount()
{
std::lock_guard<std::mutex> lock(waiting_consumer_mu_);
return waiting_consumer_count_;
}

}} // namespace triton::core
9 changes: 9 additions & 0 deletions src/instance_queue.h
@@ -45,13 +45,22 @@ class InstanceQueue {
std::shared_ptr<Payload>* payload,
std::vector<std::shared_ptr<Payload>>* merged_payloads);

void IncrementConsumerCount();
void DecrementConsumerCount();
void WaitForConsumer();
int WaitingConsumerCount();

private:
size_t max_batch_size_;
uint64_t max_queue_delay_ns_;

std::deque<std::shared_ptr<Payload>> payload_queue_;
std::shared_ptr<Payload> staged_payload_;
std::mutex mu_;

int waiting_consumer_count_;
std::mutex waiting_consumer_mu_;
std::condition_variable waiting_consumer_cv_;
};

}} // namespace triton::core
125 changes: 114 additions & 11 deletions src/rate_limiter.cc
@@ -158,19 +158,82 @@ RateLimiter::UnregisterModel(const TritonModel* model)
}
}

bool
RateLimiter::PayloadSlotAvailable(const TritonModel* model)
void
RateLimiter::WaitForConsumer(
const TritonModel* model, const TritonModelInstance* model_instance)
{
bool result;
PayloadQueue* payload_queue = nullptr;
{
std::lock_guard<std::mutex> lk(payload_queues_mu_);
if (payload_queues_.find(model) == payload_queues_.end()) {
LOG_ERROR << "Unable to find the payload queue for the model "
<< model->Name();
return;
}
payload_queue = payload_queues_[model].get();
}

if (model_instance == nullptr) {
payload_queue->queue_->WaitForConsumer();
} else {
payload_queue->specific_queues_[model_instance]->WaitForConsumer();
}
}


int
RateLimiter::WaitingConsumerCount(
const TritonModel* model, const TritonModelInstance* model_instance)
{
PayloadQueue* payload_queue = nullptr;
{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
result = payload_queue->queue_->Size() <
2 * payload_queue->specific_queues_.size();
std::lock_guard<std::mutex> lk(payload_queues_mu_);
if (payload_queues_.find(model) == payload_queues_.end()) {
LOG_ERROR << "Unable to find the payload queue for the model "
<< model->Name();
return 0;
}
payload_queue = payload_queues_[model].get();
}

if (model_instance == nullptr) {
return payload_queue->queue_->WaitingConsumerCount();
} else {
return payload_queue->specific_queues_[model_instance]
->WaitingConsumerCount();
}
}

bool
RateLimiter::PayloadSlotAvailable(
const TritonModel* model, const TritonModelInstance* model_instance,
const bool support_prefetching, const bool force_non_blocking)
{
bool result;
if (support_prefetching) {
PayloadQueue* payload_queue = nullptr;
{
std::lock_guard<std::mutex> lk(payload_queues_mu_);
payload_queue = payload_queues_[model].get();
}
{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
// The logic below sets a cap on the number of payloads that
// can be pre-fetched. For a per-model batcher the cap is
// twice the number of model instances; for a per-instance
// batcher the cap is 2.
size_t multiplier = (model_instance == nullptr)
? payload_queue->specific_queues_.size()
: 1;
result = payload_queue->queue_->Size() < (2 * multiplier);
}
} else {
result = true;
if (force_non_blocking) {
result = (WaitingConsumerCount(model, model_instance) > 0);
} else {
WaitForConsumer(model, model_instance);
}
}
return result;
}
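To make the cap in the comment above concrete, here is a tiny helper with assumed numbers; it is hypothetical and not part of the diff.

```cpp
#include <cstddef>

// Hypothetical helper restating the pre-fetch cap: a per-model batcher may
// pre-fetch up to twice the number of instances, a per-instance batcher up
// to two payloads.
std::size_t PrefetchCap(bool per_instance_batcher, std::size_t instance_count)
{
  const std::size_t multiplier = per_instance_batcher ? 1 : instance_count;
  return 2 * multiplier;
}

// e.g. PrefetchCap(false, 4) == 8 and PrefetchCap(true, 4) == 2; a slot is
// reported available while the payload queue size is below this cap.
```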
@@ -186,10 +249,18 @@ RateLimiter::EnqueuePayload(
if (payload_queues_.find(model) == payload_queues_.end()) {
return Status(
Status::Code::INTERNAL,
"Should not print this! Enqueuing payload with an unknown model.");
"Unable to find the payload queue for the model " + model->Name());
}
payload_queue = payload_queues_[model].get();
}

// Update the pending consumer counts to prevent additional
// requests from getting enqueued.
if (pinstance != nullptr) {
payload_queue->specific_queues_[pinstance]->DecrementConsumerCount();
}
payload_queue->queue_->DecrementConsumerCount();

{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
payload->SetState(Payload::State::REQUESTED);
@@ -230,15 +301,24 @@ RateLimiter::DequeuePayload(
{
payload->reset();
PayloadQueue* payload_queue = nullptr;
auto model = instances[0]->Model();
{
std::lock_guard<std::mutex> lk(payload_queues_mu_);
if (payload_queues_.find(instances[0]->Model()) == payload_queues_.end()) {
LOG_ERROR << "Should not print this! Dequeuing payload with an unknown "
"instance.";
if (payload_queues_.find(model) == payload_queues_.end()) {
LOG_ERROR << "Unable to find the payload queue for the model "
<< model->Name();
return;
}
payload_queue = payload_queues_[instances[0]->Model()].get();
payload_queue = payload_queues_[model].get();
}

// Update the queue to reflect availability of a waiting
// consumer.
payload_queue->queue_->IncrementConsumerCount();
for (const auto instance : instances) {
payload_queue->specific_queues_[instance]->IncrementConsumerCount();
}

std::vector<std::shared_ptr<Payload>> merged_payloads;
size_t instance_index = std::numeric_limits<std::size_t>::max();
{
@@ -274,10 +354,33 @@
(*payload)->Callback();
if ((*payload)->GetInstance() == nullptr) {
(*payload)->SetInstance(instances.front());
// Enqueue did not specify the specific instance to
// run with the payload. Hence, need to explicitly
// decrement the consumer count for the instance
// which got allocated.
payload_queue->specific_queues_[instances.front()]
->DecrementConsumerCount();
instances.pop_front();
} else {
instances.erase(instances.begin() + instance_index);
}

// Decrement the counts from the remaining specific
// instance handling as there will be no consumer for
// these queues.
// FIXME: DLIS-5238 For more accurate handling, the
// consumer count for the instances that were not
// requested should be decremented upon the
// EnqueuePayload too. This will need instance
// association to be derived via instances fed into
// DequeuePayload call.
// However, as multiple instances are provided to
// DequeuePayload call only when using device-blocking
// and a single consumer thread, we are decrementing the
// specific instance consumer count as an approximation.
for (const auto instance : instances) {
payload_queue->specific_queues_[instance]->DecrementConsumerCount();
}
}

std::shared_ptr<Payload>
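Taken together, DequeuePayload() registers the runner as a waiting consumer before it blocks for work, and EnqueuePayload() retires one registration when it hands a payload over, so the waiting-consumer count tracks idle runners. The following is a small self-contained demo of that handshake; it is a simplification that stands in for the RateLimiter/InstanceQueue pair and is not Triton code.

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

// Stand-in for a payload queue plus its waiting-consumer count.
struct Gate {
  std::mutex mu;
  std::condition_variable cv;
  int waiting_consumers = 0;  // incremented by the runner, decremented by the scheduler
  std::queue<int> payloads;
};

// Runner side, mirroring DequeuePayload(): announce the pending dequeue,
// then block until a payload arrives.
void Runner(Gate& g)
{
  std::unique_lock<std::mutex> lock(g.mu);
  g.waiting_consumers++;  // IncrementConsumerCount()
  g.cv.notify_all();
  g.cv.wait(lock, [&]() { return !g.payloads.empty(); });
  int payload = g.payloads.front();
  g.payloads.pop();
  std::cout << "runner executed payload " << payload << "\n";
}

// Scheduler side, mirroring PayloadSlotAvailable() + EnqueuePayload() with
// pre-fetching disabled: wait for an idle runner, then hand the work over
// and retire that runner's registration.
void Scheduler(Gate& g, int payload)
{
  std::unique_lock<std::mutex> lock(g.mu);
  g.cv.wait(lock, [&]() { return g.waiting_consumers > 0; });  // WaitForConsumer()
  g.waiting_consumers--;  // DecrementConsumerCount()
  g.payloads.push(payload);
  g.cv.notify_all();
}

int main()
{
  Gate g;
  std::thread runner(Runner, std::ref(g));
  Scheduler(g, 42);
  runner.join();
  return 0;
}
```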
26 changes: 24 additions & 2 deletions src/rate_limiter.h
@@ -91,9 +91,20 @@ class RateLimiter {
void UnregisterModel(const TritonModel* model);

/// Returns true if there is a payload slot available for the given model.
/// \param model The pointer to TritonModel object to be removed.
/// Note that the function can block when support_prefetching is false. In
/// this case, the function will block until a slot is available to start
/// building the payload. The force_non_blocking option can be set to true
/// to make the function return immediately with the current availability.
/// \param model The pointer to TritonModel object to query for.
/// \param model_instance The pointer to TritonModelInstance object to query
/// for.
/// \param support_prefetching Whether or not pre-fetching of payloads is
/// enabled.
/// \param force_non_blocking When set true, function will not block for
/// the availability of the slot.
/// \return slot availability in boolean.
bool PayloadSlotAvailable(const TritonModel* model);
bool PayloadSlotAvailable(
const TritonModel* model, const TritonModelInstance* model_instance,
const bool support_prefetching, const bool force_non_blocking = false);

/// Enqueues the payload to rate limiter for scheduling on the given model.
/// \param model The pointer to TritonModel object to be removed.
@@ -280,6 +291,17 @@ class RateLimiter {
// Initializes payload queues for the given model instance. The queue
// holds payloads that get scheduled by rate limiter.
void InitializePayloadQueues(const TritonModelInstance* instance);

// Waits until a consumer registers a pending dequeue request for the
// given instance(s) of the model. This implies that the call will wait
// for an idle runner.
void WaitForConsumer(
const TritonModel* model, const TritonModelInstance* model_instance);
// Returns the number of consumers who have a pending dequeue request for
// the given instance(s) of the model.
int WaitingConsumerCount(
const TritonModel* model, const TritonModelInstance* model_instance);

// Defers scheduling of the payload to the future. Rate Limiter will
// schedule the payload execution based upon the resource availability/
// Note that OnSchedule function should only schedule(enqueued in payload
5 changes: 5 additions & 0 deletions src/scheduler_utils.cc
@@ -273,13 +273,18 @@ PriorityQueue::PriorityQueue(
if (priority_levels == 0) {
// Only default policy is instantiated
queues_.emplace(0, PolicyQueue(default_policy_, true));
support_prefetching_ =
(default_policy_.default_timeout_microseconds() == 0) &&
(!default_policy_.allow_timeout_override()) &&
(default_policy_.max_queue_size() == 0);
} else {
// All priorities with user-given policy are instantiated. We do not
// permanently add default PolicyQueue because those will be dynamically
// created and erased to keep memory footprint low
for (const auto& qp : queue_policy_map) {
queues_.emplace(qp.first, PolicyQueue(qp.second, true));
}
support_prefetching_ = false;
}
front_priority_level_ = queues_.empty() ? 0 : queues_.begin()->first;
ResetCursor();
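In user-facing terms, this is where a model's queue policy turns pre-fetching off: any non-zero default timeout, any allowance for per-request timeout overrides, a bounded queue size, or the use of priority levels clears support_prefetching_. Below is a hedged model-config sketch (hypothetical pbtxt, not part of this diff; the field names come from the queue-policy settings referenced in the hunk above) of a policy that would disable pre-fetching after this change.

```
dynamic_batching {
  default_queue_policy {
    # Any one of the following settings disables pre-fetching:
    default_timeout_microseconds: 100000   # non-zero default timeout
    allow_timeout_override: true           # per-request timeout overrides
    max_queue_size: 16                     # bounded queue size
  }
}
```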
6 changes: 6 additions & 0 deletions src/scheduler_utils.h
@@ -148,6 +148,9 @@ class PriorityQueue {
// Return the number of requests in pending batch.
size_t PendingBatchCount() { return pending_cursor_.pending_batch_count_; }

// Whether the queue supports pre-fetching of the requests.
bool SupportPrefetching() { return support_prefetching_; }

private:
class PolicyQueue {
public:
@@ -260,6 +263,9 @@

Cursor pending_cursor_;
Cursor current_mark_;

// Whether requests can be pre-fetched from the queue.
bool support_prefetching_{true};
};

}} // namespace triton::core
