Task retry logic (#140)
* implement retry logic for failed tasks.

* retry logic with tests
braceal authored Aug 21, 2024
1 parent 8d45741 commit 972659a
Showing 4 changed files with 127 additions and 1 deletion.
4 changes: 4 additions & 0 deletions colmena/models/results.py
@@ -239,6 +239,10 @@ class Result(BaseModel):
    # Task routing information
    topic: Optional[str] = Field(None, description='Label used to group results in queue between Thinker and Task Server')

    # Fault tolerance
    max_retries: int = Field(0, description='Maximum number of times this task should be retried if it fails')
    retries: int = Field(0, description='Number of times this task has been retried')

    def __init__(self, inputs: Tuple[Tuple[Any], Dict[str, Any]], **kwargs):
        """
        Args:
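The two new fields travel with every Result, so a Thinker can see how many attempts a task consumed. Below is a minimal, hypothetical sketch of that bookkeeping, assuming the pydantic model accepts the fields as keyword arguments; constructing a Result by hand like this is normally the queue's job, and the inputs shown are placeholders.

from colmena.models.results import Result

# Hypothetical direct construction; in practice the queue builds this object.
# Inputs are (positional args, keyword args) for the task, here f(1, 2).
result = Result(inputs=((1, 2), {}), max_retries=3)

assert result.retries == 0                  # no retry has happened yet
result.retries += 1                         # the task server increments this on each retry
print(result.retries, result.max_retries)   # -> 1 3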
5 changes: 4 additions & 1 deletion colmena/queue/base.py
@@ -184,7 +184,8 @@ def send_inputs(self,
                    keep_inputs: Optional[bool] = None,
                    resources: Optional[Union[ResourceRequirements, dict]] = None,
                    topic: str = 'default',
                    task_info: Optional[Dict[str, Any]] = None) -> str:
                    task_info: Optional[Dict[str, Any]] = None,
                    max_retries: int = 0) -> str:
"""Send a task request
Args:
@@ -195,6 +196,7 @@
            topic (str): Topic for the queue, which sets the topic for the result
            resources: Suggestions for how many resources to use for the task
            task_info (dict): Any information used for task tracking
            max_retries (int): Maximum number of times to retry the task if it fails
        Returns:
            Task ID
        """
@@ -236,6 +238,7 @@ def send_inputs(self,
            task_info=task_info,
            resources=resources or ResourceRequirements(),  # Takes either the user specified or a default,
            topic=topic,
            max_retries=max_retries,
            **ps_kwargs
        )

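From the client side, opting in to retries is a single extra keyword argument. A usage sketch mirroring the test file below, where 'retry_task' stands in for any method registered with the task server:

from colmena.queue.python import PipeQueues

queues = PipeQueues()

# Allow the task server to re-run the task up to 2 times before
# reporting a failure back to the Thinker
task_id = queues.send_inputs(1, method='retry_task', max_retries=2)

result = queues.get_result()
print(result.success, result.retries, result.max_retries)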
19 changes: 19 additions & 0 deletions colmena/task_server/base.py
@@ -126,6 +126,25 @@ def perform_callback(self, future: Future, result: Result, topic: str):

        task_exc = future.exception()

        # The task could have failed at the workflow engine level (task_exc)
        # or application level (result.failure_info)
        task_failed = (task_exc is not None) or (result.failure_info is not None)

        # If the task failed and we have retries left, try again
        if task_failed and result.retries < result.max_retries:
            # Increment the retry count and clear the failure information
            result.retries += 1
            result.failure_info, result.success = None, None

            # Log the retry
            logger.warning(f'Task {result.task_id} failed. Retrying with {result.retries}/{result.max_retries} retries.')

            # Provide it to the workflow system to be re-executed
            self.process_queue(topic, result)

            # Do not send the result back to the user
            return

        # If it was, send back a modified copy of the input structure
        if task_exc is not None:
            # Mark it as unsuccessful and capture the exception information
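The decision above does not depend on any particular workflow engine. A framework-free sketch of the same control flow, with Task, resubmit, and send_back as hypothetical stand-ins for Colmena's Result and queue plumbing:

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Task:
    max_retries: int = 0
    retries: int = 0
    failure_info: Optional[Exception] = None
    success: Optional[bool] = None

def handle_completion(task: Task, engine_exc: Optional[Exception],
                      resubmit: Callable[[Task], None],
                      send_back: Callable[[Task], None]) -> None:
    # A failure can surface from the engine (engine_exc) or the application (failure_info)
    failed = engine_exc is not None or task.failure_info is not None

    if failed and task.retries < task.max_retries:
        # Count the attempt, clear stale failure state, and re-queue the
        # task instead of replying to the client
        task.retries += 1
        task.failure_info, task.success = None, None
        resubmit(task)
        return

    # Otherwise hand the (possibly failed) result back to the client
    send_back(task)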
100 changes: 100 additions & 0 deletions colmena/task_server/tests/test_retry.py
@@ -0,0 +1,100 @@
from typing import Tuple, Generator
from parsl import ThreadPoolExecutor
from parsl.config import Config
from colmena.queue.base import ColmenaQueues
from colmena.queue.python import PipeQueues
from colmena.task_server.parsl import ParslTaskServer

from pytest import fixture, mark

# Make global state for the retry task
RETRY_COUNT = 0

def retry_task(success_idx: int) -> bool:
    """Task that fails `success_idx` times before succeeding (returning True)."""
    global RETRY_COUNT

    # If we haven't reached the success index, raise an error.
    if RETRY_COUNT < success_idx:
        RETRY_COUNT += 1
        raise ValueError('Retry')

    # Reset the retry count
    RETRY_COUNT = 0
    return True

@fixture
def reset_retry_count():
    """Reset the retry count before each test."""
    global RETRY_COUNT
    RETRY_COUNT = 0

@fixture()
def config(tmpdir):
    """Make the Parsl configuration."""
    return Config(
        executors=[
            ThreadPoolExecutor(max_threads=2)
        ],
        strategy=None,
        run_dir=str(tmpdir / 'run'),
    )

@fixture
def server_and_queue(config) -> Generator[Tuple[ParslTaskServer, ColmenaQueues], None, None]:
    queues = PipeQueues()
    server = ParslTaskServer([retry_task], queues, config)
    yield server, queues
    if server.is_alive():
        queues.send_kill_signal()
        server.join(timeout=30)


@mark.timeout(10)
def test_retry_policy_max_retries_zero(server_and_queue, reset_retry_count):
    """Test the retry policy with max_retries=0"""

    # Start the server
    server, queue = server_and_queue
    server.start()

    # The task will fail every other time (setting success_idx=1)
    success_idx = 1

    for i in range(4):
        # The task will fail every other time (setting success_idx=1)
        queue.send_inputs(success_idx, method='retry_task', max_retries=0)
        result = queue.get_result()
        assert result.success == (i % 2 == 1)
        if i % 2 == 1:
            assert result.value
            assert result.failure_info is None
        else:
            assert not result.success
            assert 'Retry' in str(result.failure_info.exception)

@mark.timeout(10)
@mark.parametrize(('success_idx', 'max_retries'), [(0, 0), (1, 1), (4, 10)])
def test_retry_policy_max_retries(server_and_queue, reset_retry_count, success_idx: int, max_retries: int):
    """Test the retry policy.

    This test checks the following cases:
    - A task that always succeeds (success_idx=0, max_retries=0)
    - A task that succeeds after one retry (success_idx=1, max_retries=1)
    - A task that succeeds after four retries, with retries to spare (success_idx=4, max_retries=10)
    """

    # Start the server
    server, queue = server_and_queue
    server.start()

    # The task fails `success_idx` times before succeeding, and max_retries
    # is large enough for the task server to retry until it succeeds
    queue.send_inputs(success_idx, method='retry_task', max_retries=max_retries)
    result = queue.get_result()
    assert result is not None
    assert result.success
    assert result.value
    assert result.failure_info is None
    assert result.retries == success_idx
    assert result.max_retries == max_retries

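For completeness, a hypothetical companion test (not part of this commit) sketching the exhaustion path, where the task fails more times than the retry budget allows:

@mark.timeout(10)
def test_retry_exhaustion_sketch(server_and_queue, reset_retry_count):
    """Hypothetical: with success_idx=3 the task fails three times, but
    max_retries=1 permits only one re-execution, so the failure is returned."""
    server, queue = server_and_queue
    server.start()

    queue.send_inputs(3, method='retry_task', max_retries=1)
    result = queue.get_result()

    assert not result.success
    assert result.retries == 1     # the single allowed retry was used
    assert 'Retry' in str(result.failure_info.exception)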