Skip to content

Commit

Permalink
batch polling in client (Netflix#3304)
Browse files Browse the repository at this point in the history
* batch polling in client

* Deprecate shared threadpool and evenly split it between workers if that is supplied
  • Loading branch information
jxu-nflx authored Oct 31, 2022
1 parent 9e80c4a commit 03f8d96
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,33 @@ class PollingSemaphore {
semaphore = new Semaphore(numSlots);
}

/**
* Signals if polling is allowed based on whether a permit can be acquired.
*
* @return {@code true} - if permit is acquired {@code false} - if permit could not be acquired
*/
boolean canPoll() {
boolean acquired = semaphore.tryAcquire();
LOGGER.debug("Trying to acquire permit: {}", acquired);
return acquired;
}

/** Signals that processing is complete and the permit can be released. */
void complete() {
/** Signals that processing is complete and the specified number of permits can be released. */
void complete(int numSlots) {
LOGGER.debug("Completed execution; releasing permit");
semaphore.release();
semaphore.release(numSlots);
}

/**
* Gets the number of threads available for processing.
*
* @return number of available permits
*/
int availableThreads() {
int availableSlots() {
int available = semaphore.availablePermits();
LOGGER.debug("Number of available permits: {}", available);
return available;
}

/**
* Signals if processing is allowed based on whether specified number of permits can be
* acquired.
*
* @param numSlots the number of permits to acquire
* @return {@code true} - if permit is acquired {@code false} - if permit could not be acquired
*/
public boolean acquireSlots(int numSlots) {
boolean acquired = semaphore.tryAcquire(numSlots);
LOGGER.debug("Trying to acquire {} permit: {}", numSlots, acquired);
return acquired;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -68,7 +69,6 @@ class TaskPollExecutor {
TaskPollExecutor(
EurekaClient eurekaClient,
TaskClient taskClient,
int threadCount,
int updateRetryCount,
Map<String, String> taskToDomain,
String workerNamePrefix,
Expand All @@ -80,17 +80,11 @@ class TaskPollExecutor {

this.pollingSemaphoreMap = new HashMap<>();
int totalThreadCount = 0;
if (!taskThreadCount.isEmpty()) {
for (Map.Entry<String, Integer> entry : taskThreadCount.entrySet()) {
String taskType = entry.getKey();
int count = entry.getValue();
totalThreadCount += count;
pollingSemaphoreMap.put(taskType, new PollingSemaphore(count));
}
} else {
totalThreadCount = threadCount;
// shared poll for all workers
pollingSemaphoreMap.put(ALL_WORKERS, new PollingSemaphore(threadCount));
for (Map.Entry<String, Integer> entry : taskThreadCount.entrySet()) {
String taskType = entry.getKey();
int count = entry.getValue();
totalThreadCount += count;
pollingSemaphoreMap.put(taskType, new PollingSemaphore(count));
}

LOGGER.info("Initialized the TaskPollExecutor with {} threads", totalThreadCount);
Expand Down Expand Up @@ -139,12 +133,12 @@ void pollAndExecute(Worker worker) {
String taskType = worker.getTaskDefName();
PollingSemaphore pollingSemaphore = getPollingSemaphore(taskType);

Task task;
int slotsToAcquire = pollingSemaphore.availableSlots();
if (slotsToAcquire <= 0 || !pollingSemaphore.acquireSlots(slotsToAcquire)) {
return;
}
int acquiredTasks = 0;
try {
if (!pollingSemaphore.canPoll()) {
return;
}

String domain =
Optional.ofNullable(PropertyFactory.getString(taskType, DOMAIN, null))
.orElseGet(
Expand All @@ -155,52 +149,60 @@ void pollAndExecute(Worker worker) {
.orElse(taskToDomain.get(taskType)));

LOGGER.debug("Polling task of type: {} in domain: '{}'", taskType, domain);
task =

List<Task> tasks =
MetricsContainer.getPollTimer(taskType)
.record(
() ->
taskClient.pollTask(
taskType, worker.getIdentity(), domain));

if (Objects.nonNull(task) && StringUtils.isNotBlank(task.getTaskId())) {
MetricsContainer.incrementTaskPollCount(taskType, 1);
LOGGER.debug(
"Polled task: {} of type: {} in domain: '{}', from worker: {}",
task.getTaskId(),
taskType,
domain,
worker.getIdentity());

CompletableFuture<Task> taskCompletableFuture =
CompletableFuture.supplyAsync(
() -> processTask(task, worker, pollingSemaphore), executorService);

if (task.getResponseTimeoutSeconds() > 0 && worker.leaseExtendEnabled()) {
ScheduledFuture<?> leaseExtendFuture =
leaseExtendExecutorService.scheduleWithFixedDelay(
extendLease(task, taskCompletableFuture),
Math.round(
task.getResponseTimeoutSeconds()
* LEASE_EXTEND_DURATION_FACTOR),
Math.round(
task.getResponseTimeoutSeconds()
* LEASE_EXTEND_DURATION_FACTOR),
TimeUnit.SECONDS);
leaseExtendMap.put(task.getTaskId(), leaseExtendFuture);
taskClient.batchPollTasksInDomain(
taskType,
domain,
worker.getIdentity(),
slotsToAcquire,
worker.getBatchPollTimeoutInMS()));
acquiredTasks = tasks.size();
for (Task task : tasks) {
if (Objects.nonNull(task) && StringUtils.isNotBlank(task.getTaskId())) {
MetricsContainer.incrementTaskPollCount(taskType, 1);
LOGGER.debug(
"Polled task: {} of type: {} in domain: '{}', from worker: {}",
task.getTaskId(),
taskType,
domain,
worker.getIdentity());

CompletableFuture<Task> taskCompletableFuture =
CompletableFuture.supplyAsync(
() -> processTask(task, worker, pollingSemaphore),
executorService);

if (task.getResponseTimeoutSeconds() > 0 && worker.leaseExtendEnabled()) {
ScheduledFuture<?> leaseExtendFuture =
leaseExtendExecutorService.scheduleWithFixedDelay(
extendLease(task, taskCompletableFuture),
Math.round(
task.getResponseTimeoutSeconds()
* LEASE_EXTEND_DURATION_FACTOR),
Math.round(
task.getResponseTimeoutSeconds()
* LEASE_EXTEND_DURATION_FACTOR),
TimeUnit.SECONDS);
leaseExtendMap.put(task.getTaskId(), leaseExtendFuture);
}

taskCompletableFuture.whenComplete(this::finalizeTask);
} else {
// no task was returned in the poll, release the permit
pollingSemaphore.complete(1);
}

taskCompletableFuture.whenComplete(this::finalizeTask);
} else {
// no task was returned in the poll, release the permit
pollingSemaphore.complete();
}
} catch (Exception e) {
// release the permit if exception is thrown during polling, because the thread would
// not be busy
pollingSemaphore.complete();
MetricsContainer.incrementTaskPollErrorCount(worker.getTaskDefName(), e);
LOGGER.error("Error when polling for tasks", e);
}

// immediately release unused permits
pollingSemaphore.complete(slotsToAcquire - acquiredTasks);
}

void shutdown(int timeout) {
Expand Down Expand Up @@ -247,7 +249,7 @@ private Task processTask(Task task, Worker worker, PollingSemaphore pollingSemap
TaskResult result = new TaskResult(task);
handleException(t, result, worker, task);
} finally {
pollingSemaphore.complete();
pollingSemaphore.complete(1);
}
return task;
}
Expand Down Expand Up @@ -391,11 +393,7 @@ private void handleException(Throwable t, TaskResult result, Worker worker, Task
}

private PollingSemaphore getPollingSemaphore(String taskType) {
if (pollingSemaphoreMap.containsKey(taskType)) {
return pollingSemaphoreMap.get(taskType);
} else {
return pollingSemaphoreMap.get(ALL_WORKERS);
}
return pollingSemaphoreMap.get(taskType);
}

private Runnable extendLease(Task task, CompletableFuture<Task> taskCompletableFuture) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
package com.netflix.conductor.client.automator;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
Expand All @@ -44,7 +47,7 @@ public class TaskRunnerConfigurer {
private final List<Worker> workers = new LinkedList<>();
private final int sleepWhenRetry;
private final int updateRetryCount;
private final int threadCount;
@Deprecated private final int threadCount;
private final int shutdownGracePeriodSeconds;
private final String workerNamePrefix;
private final Map<String /*taskType*/, String /*domain*/> taskToDomain;
Expand All @@ -64,19 +67,26 @@ private TaskRunnerConfigurer(Builder builder) {
} else if (!builder.taskThreadCount.isEmpty()) {
for (Worker worker : builder.workers) {
if (!builder.taskThreadCount.containsKey(worker.getTaskDefName())) {
String message =
String.format(MISSING_TASK_THREAD_COUNT, worker.getTaskDefName());
LOGGER.error(message);
throw new ConductorClientException(message);
LOGGER.info(
"No thread count specified for task type {}, default to 1 thread",
worker.getTaskDefName());
builder.taskThreadCount.put(worker.getTaskDefName(), 1);
}
workers.add(worker);
}
this.taskThreadCount = builder.taskThreadCount;
this.threadCount = -1;
} else {
builder.workers.forEach(workers::add);
this.taskThreadCount = builder.taskThreadCount;
Set<String> taskTypes = new HashSet<>();
for (Worker worker : builder.workers) {
taskTypes.add(worker.getTaskDefName());
workers.add(worker);
}
this.threadCount = (builder.threadCount == -1) ? workers.size() : builder.threadCount;
// shared thread pool will be evenly split between task types
int splitThreadCount = threadCount / taskTypes.size();
this.taskThreadCount =
taskTypes.stream().collect(Collectors.toMap(v -> v, v -> splitThreadCount));
}

this.eurekaClient = builder.eurekaClient;
Expand All @@ -94,7 +104,7 @@ public static class Builder {
private String workerNamePrefix = "workflow-worker-%d";
private int sleepWhenRetry = 500;
private int updateRetryCount = 3;
private int threadCount = -1;
@Deprecated private int threadCount = -1;
private int shutdownGracePeriodSeconds = 10;
private final Iterable<Worker> workers;
private EurekaClient eurekaClient;
Expand Down Expand Up @@ -143,7 +153,9 @@ public Builder withUpdateRetryCount(int updateRetryCount) {
* @param threadCount # of threads assigned to the workers. Should be at-least the size of
* taskWorkers to avoid starvation in a busy system.
* @return Builder instance
* @deprecated Use {@link TaskRunnerConfigurer.Builder#withTaskThreadCount(Map)} instead.
*/
@Deprecated
public Builder withThreadCount(int threadCount) {
if (threadCount < 1) {
throw new IllegalArgumentException("No. of threads cannot be less than 1");
Expand Down Expand Up @@ -200,6 +212,7 @@ public TaskRunnerConfigurer build() {
/**
* @return Thread Count for the shared executor pool
*/
@Deprecated
public int getThreadCount() {
return threadCount;
}
Expand Down Expand Up @@ -249,7 +262,6 @@ public synchronized void init() {
new TaskPollExecutor(
eurekaClient,
taskClient,
threadCount,
updateRetryCount,
taskToDomain,
workerNamePrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ default boolean leaseExtendEnabled() {
return PropertyFactory.getBoolean(getTaskDefName(), "leaseExtendEnabled", false);
}

default int getBatchPollTimeoutInMS() {
return PropertyFactory.getInteger(getTaskDefName(), "batchPollTimeoutInMS", 1000);
}

static Worker create(String taskType, Function<Task, TaskResult> executor) {
return new Worker() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ public void testBlockAfterAvailablePermitsExhausted() throws Exception {
t ->
futuresList.add(
CompletableFuture.runAsync(
pollingSemaphore::canPoll, executorService)));
() -> pollingSemaphore.acquireSlots(1),
executorService)));

CompletableFuture<Void> allFutures =
CompletableFuture.allOf(
futuresList.toArray(new CompletableFuture[futuresList.size()]));

allFutures.get();

assertEquals(0, pollingSemaphore.availableThreads());
assertFalse(pollingSemaphore.canPoll());
assertEquals(0, pollingSemaphore.availableSlots());
assertFalse(pollingSemaphore.acquireSlots(1));

executorService.shutdown();
}
Expand All @@ -65,18 +66,19 @@ public void testAllowsPollingWhenPermitBecomesAvailable() throws Exception {
t ->
futuresList.add(
CompletableFuture.runAsync(
pollingSemaphore::canPoll, executorService)));
() -> pollingSemaphore.acquireSlots(1),
executorService)));

CompletableFuture<Void> allFutures =
CompletableFuture.allOf(
futuresList.toArray(new CompletableFuture[futuresList.size()]));
allFutures.get();

assertEquals(0, pollingSemaphore.availableThreads());
pollingSemaphore.complete();
assertEquals(0, pollingSemaphore.availableSlots());
pollingSemaphore.complete(1);

assertTrue(pollingSemaphore.availableThreads() > 0);
assertTrue(pollingSemaphore.canPoll());
assertTrue(pollingSemaphore.availableSlots() > 0);
assertTrue(pollingSemaphore.acquireSlots(1));

executorService.shutdown();
}
Expand Down
Loading

0 comments on commit 03f8d96

Please sign in to comment.