Skip to content

Commit

Permalink
Avoid shutting down Ray driver when running on a Ray worker (PrefectH…
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle authored Oct 1, 2024
1 parent 18abe92 commit c068d7e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __enter__(self):
"Using existing local instance."
)
return self
if self.address and self.address != "auto":
elif self.address and self.address != "auto":
self.logger.info(
f"Connecting to an existing Ray instance at {self.address}"
)
Expand All @@ -427,8 +427,12 @@ def __enter__(self):

def __exit__(self, *exc_info):
"""
Shuts down the cluster.
Shuts down the driver/cluster.
"""
self.logger.debug("Shutting down Ray cluster...")
ray.shutdown()
# Check if we are running on the driver. Calling ray.shutdown() when running on a
# worker will crash the worker.
if ray.get_runtime_context().worker.mode == 0:
# Running on the driver. Will shutdown cluster if started by this task runner.
self.logger.debug("Shutting down Ray driver...")
ray.shutdown()
super().__exit__(*exc_info)
30 changes: 30 additions & 0 deletions src/integrations/prefect-ray/tests/test_task_runners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import random
import subprocess
import sys
import time
Expand Down Expand Up @@ -508,3 +509,32 @@ def test_flow():
test_flow()

assert "A future was garbage collected before it resolved" not in caplog.text

def test_can_run_many_tasks_without_crashing(self, task_runner):
"""
Regression test for https://github.com/PrefectHQ/prefect/issues/15539
"""

@task
def random_integer(range_from: int = 0, range_to: int = 100) -> int:
"""Task that returns a random integer."""

random_int = random.randint(range_from, range_to)

return random_int

@flow(task_runner=task_runner)
def add_random_integers(number_tasks: int = 50) -> int:
"""Flow that submits some random_integer tasks and returns the sum of the results."""

futures = []
for _ in range(number_tasks):
futures.append(random_integer.submit())

sum = 0
for future in futures:
sum += future.result()

return sum

assert add_random_integers() > 0

0 comments on commit c068d7e

Please sign in to comment.