[serve] Reduce surface area of create_router interface (ray-project#48467)

Moves more of the initialization logic into `create_router` so
`DeploymentHandle` does not depend on being run inside an initialized
Ray context.

This is required so I can use the same `DeploymentHandle` implementation
for local testing mode and only inject a different `create_router`
implementation.
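
As a rough illustration of that seam (not code from this PR), a local-testing factory only needs to match the `CreateRouterCallable` interface added in `default_impl.py` below, namely `(handle_id, deployment_id, handle_options) -> Router`, and implement the three abstract `Router` methods. The names `InMemoryRouter` and `create_local_testing_router` here are hypothetical:

import concurrent.futures
from typing import Any


class InMemoryRouter:
    """Hypothetical Router implementation for local testing; it never touches Ray."""

    def __init__(self, handle_id: str, deployment_id: Any, handle_options: Any):
        self._handle_id = handle_id
        self._deployment_id = deployment_id
        self._handle_options = handle_options

    def running_replicas_populated(self) -> bool:
        return True

    def assign_request(self, request_meta, *args, **kwargs) -> concurrent.futures.Future:
        # A real local-testing router would invoke the deployment's callable
        # in-process; this sketch just resolves the future immediately.
        fut: concurrent.futures.Future = concurrent.futures.Future()
        fut.set_result(None)
        return fut

    def shutdown(self) -> None:
        pass


def create_local_testing_router(handle_id: str, deployment_id: Any, handle_options: Any) -> InMemoryRouter:
    # Same shape as CreateRouterCallable, but requires no initialized Ray context.
    return InMemoryRouter(handle_id, deployment_id, handle_options)

A `DeploymentHandle` built against either factory would then behave the same from the caller's perspective, which is the decoupling described above.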

---------

Signed-off-by: Edward Oakes <[email protected]>
edoakes authored Nov 1, 2024
1 parent ee86218 commit 676f800
Showing 10 changed files with 184 additions and 99 deletions.
46 changes: 31 additions & 15 deletions python/ray/serve/_private/default_impl.py
@@ -1,9 +1,7 @@
-import asyncio
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Tuple

import ray
from ray._raylet import GcsClient
-from ray.actor import ActorHandle
from ray.serve._private.cluster_node_info_cache import (
ClusterNodeInfoCache,
DefaultClusterNodeInfoCache,
@@ -23,8 +21,9 @@
ActorReplicaWrapper,
PowerOfTwoChoicesReplicaScheduler,
)
-from ray.serve._private.router import Router
+from ray.serve._private.router import Router, SingletonThreadRouter
from ray.serve._private.utils import (
+get_current_actor_id,
get_head_node_id,
inside_ray_client_context,
resolve_request_args,
@@ -67,20 +66,38 @@ def create_init_handle_options(**kwargs):
return _InitHandleOptions.create(**kwargs)


def _get_node_id_and_az() -> Tuple[str, Optional[str]]:
node_id = ray.get_runtime_context().get_node_id()
try:
cluster_node_info_cache = create_cluster_node_info_cache(
GcsClient(address=ray.get_runtime_context().gcs_address)
)
cluster_node_info_cache.update()
az = cluster_node_info_cache.get_node_az(node_id)
except Exception:
az = None

return node_id, az


# Interface definition for create_router.
CreateRouterCallable = Callable[[str, DeploymentID, Any], Router]


def create_router(
-controller_handle: ActorHandle,
-deployment_id: DeploymentID,
handle_id: str,
-node_id: str,
-actor_id: str,
-availability_zone: Optional[str],
-event_loop: asyncio.BaseEventLoop,
-handle_options,
-):
+deployment_id: DeploymentID,
+handle_options: Any,
+) -> Router:
# NOTE(edoakes): this is lazy due to a nasty circular import that should be fixed.
from ray.serve.context import _get_global_client

actor_id = get_current_actor_id()
node_id, availability_zone = _get_node_id_and_az()
controller_handle = _get_global_client()._controller
is_inside_ray_client_context = inside_ray_client_context()

replica_scheduler = PowerOfTwoChoicesReplicaScheduler(
-event_loop,
deployment_id,
handle_options._source,
handle_options._prefer_local_routing,
@@ -98,13 +115,12 @@ def create_router(
create_replica_wrapper_func=lambda r: ActorReplicaWrapper(r),
)

-return Router(
+return SingletonThreadRouter(
controller_handle=controller_handle,
deployment_id=deployment_id,
handle_id=handle_id,
self_actor_id=actor_id,
handle_source=handle_options._source,
-event_loop=event_loop,
replica_scheduler=replica_scheduler,
# Streaming ObjectRefGenerators are not supported in Ray Client
enable_strict_max_ongoing_requests=(
2 changes: 1 addition & 1 deletion python/ray/serve/_private/proxy.py
@@ -1304,7 +1304,7 @@ def _get_logging_config(self) -> Tuple:

def _dump_ingress_replicas_for_testing(self, route: str) -> Set[ReplicaID]:
_, handle, _ = self.http_proxy.proxy_router.match_route(route)
-return handle._router._replica_scheduler._replica_id_set
+return handle._router._asyncio_router._replica_scheduler._replica_id_set

def should_start_grpc_service(self) -> bool:
"""Determine whether gRPC service should be started.
22 changes: 15 additions & 7 deletions python/ray/serve/_private/replica_scheduler/pow_2_scheduler.py
@@ -90,7 +90,6 @@ class PowerOfTwoChoicesReplicaScheduler(ReplicaScheduler):

def __init__(
self,
-event_loop: asyncio.AbstractEventLoop,
deployment_id: DeploymentID,
handle_source: DeploymentHandleSource,
prefer_local_node_routing: bool = False,
@@ -105,7 +104,6 @@ def __init__(
Callable[[RunningReplicaInfo], ReplicaWrapper]
] = None,
):
-self._loop = event_loop
self._deployment_id = deployment_id
self._handle_source = handle_source
self._prefer_local_node_routing = prefer_local_node_routing
@@ -130,6 +128,7 @@ def __init__(
# from a different loop than it uses for scheduling, so we need to construct it
# lazily to avoid an error due to the event being attached to the wrong loop.
self._lazily_constructed_replicas_updated_event: Optional[asyncio.Event] = None
self._lazily_fetched_loop: Optional[asyncio.AbstractEventLoop] = None

# Colocated replicas (e.g. wrt node, AZ)
self._colocated_replica_ids: DefaultDict[
@@ -195,6 +194,13 @@ def __init__(
self.num_scheduling_tasks_in_backoff
)

@property
def _event_loop(self) -> asyncio.AbstractEventLoop:
if self._lazily_fetched_loop is None:
self._lazily_fetched_loop = asyncio.get_running_loop()

return self._lazily_fetched_loop

@property
def _replicas_updated_event(self) -> asyncio.Event:
"""Lazily construct `asyncio.Event`.
@@ -323,7 +329,7 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
active_replica_ids=new_replica_id_set
)
# Populate cache for new replicas
-self._loop.create_task(self._probe_queue_lens(replicas_to_ping, 0))
+self._event_loop.create_task(self._probe_queue_lens(replicas_to_ping, 0))
self._replicas_updated_event.set()
self.maybe_start_scheduling_tasks()

@@ -548,7 +554,7 @@ async def _probe_queue_lens(

get_queue_len_tasks = []
for r in replicas:
-t = self._loop.create_task(
+t = self._event_loop.create_task(
r.get_queue_len(deadline_s=queue_len_response_deadline_s)
)
t.replica = r
@@ -660,7 +666,9 @@ async def select_from_candidate_replicas(
elif len(not_in_cache) > 0:
# If there are replicas without a valid cache entry, probe them in the
# background to populate the cache.
-self._loop.create_task(self._probe_queue_lens(not_in_cache, backoff_index))
+self._event_loop.create_task(
+self._probe_queue_lens(not_in_cache, backoff_index)
+)

# `self._replicas` may have been updated since the candidates were chosen.
# In that case, return `None` so a new one is selected.
@@ -779,7 +787,7 @@ async def fulfill_pending_requests(self):
except Exception:
logger.exception("Unexpected error in fulfill_pending_requests.")
finally:
-self._scheduling_tasks.remove(asyncio.current_task(loop=self._loop))
+self._scheduling_tasks.remove(asyncio.current_task(loop=self._event_loop))
self.num_scheduling_tasks_gauge.set(self.curr_num_scheduling_tasks)

def maybe_start_scheduling_tasks(self):
@@ -797,7 +805,7 @@ def maybe_start_scheduling_tasks(self):
)
for _ in range(tasks_to_start):
self._scheduling_tasks.add(
-self._loop.create_task(self.fulfill_pending_requests())
+self._event_loop.create_task(self.fulfill_pending_requests())
)
if tasks_to_start > 0:
self.num_scheduling_tasks_gauge.set(self.curr_num_scheduling_tasks)
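
The scheduler changes above drop the constructor-injected `event_loop` in favor of a `_event_loop` property that caches `asyncio.get_running_loop()` on first use, for the same reason given in the existing comment about the lazily constructed `asyncio.Event`: loop-bound objects must be created on the loop that uses them. A standalone sketch of that pattern (illustrative only, not Serve code):

import asyncio
from typing import Optional


class LazyLoopWorker:
    """Sketch: defer loop and Event acquisition until first use on the running loop."""

    def __init__(self):
        # The constructor may run on a different thread, or before any loop
        # exists, so nothing loop-bound is created here.
        self._lazily_fetched_loop: Optional[asyncio.AbstractEventLoop] = None
        self._lazily_constructed_event: Optional[asyncio.Event] = None

    @property
    def _event_loop(self) -> asyncio.AbstractEventLoop:
        if self._lazily_fetched_loop is None:
            self._lazily_fetched_loop = asyncio.get_running_loop()
        return self._lazily_fetched_loop

    @property
    def _updated_event(self) -> asyncio.Event:
        if self._lazily_constructed_event is None:
            # Constructed on the loop that awaits it, so it is never attached
            # to the wrong loop.
            self._lazily_constructed_event = asyncio.Event()
        return self._lazily_constructed_event

    async def wait_for_update(self) -> None:
        await self._updated_event.wait()

    def notify_update(self) -> None:
        self._updated_event.set()


async def main() -> None:
    worker = LazyLoopWorker()  # safe to construct anywhere
    waiter = asyncio.create_task(worker.wait_for_update())
    worker.notify_update()
    await waiter


asyncio.run(main())

Because nothing loop-bound is created in `__init__`, the object can be constructed on any thread and only binds to a loop once a coroutine actually runs.
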
91 changes: 87 additions & 4 deletions python/ray/serve/_private/router.py
@@ -1,8 +1,10 @@
import asyncio
import concurrent.futures
import logging
import threading
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
@@ -310,7 +312,26 @@ async def shutdown(self):
self._shutdown = True


-class Router:
+class Router(ABC):
@abstractmethod
def running_replicas_populated(self) -> bool:
pass

@abstractmethod
def assign_request(
self,
request_meta: RequestMetadata,
*request_args,
**request_kwargs,
) -> concurrent.futures.Future[ReplicaResult]:
pass

@abstractmethod
def shutdown(self):
pass


class AsyncioRouter:
def __init__(
self,
controller_handle: ActorHandle,
@@ -338,7 +359,7 @@ def __init__(

# Flipped to `True` once the router has received a non-empty
# replica set at least once.
-self.running_replicas_populated: bool = False
+self._running_replicas_populated: bool = False

# The config for the deployment this router sends requests to will be broadcast
# by the controller. That means it is not available until we get the first
@@ -392,12 +413,15 @@ def __init__(
call_in_event_loop=self._event_loop,
)

def running_replicas_populated(self) -> bool:
return self._running_replicas_populated

def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
self._replica_scheduler.update_running_replicas(running_replicas)
self._metrics_manager.update_running_replicas(running_replicas)

if running_replicas:
-self.running_replicas_populated = True
+self._running_replicas_populated = True

def update_deployment_config(self, deployment_config: DeploymentConfig):
self._metrics_manager.update_deployment_config(
@@ -560,7 +584,66 @@ async def assign_request(

raise

async def shutdown(self):
await self._metrics_manager.shutdown()


class SingletonThreadRouter(Router):
"""Wrapper class that runs an AsyncioRouter on a separate thread.
The motivation for this is to avoid user code blocking the event loop and
preventing the router from making progress.
Maintains a singleton event loop running in a daemon thread that is shared by
all AsyncioRouters.
"""

_asyncio_loop: Optional[asyncio.AbstractEventLoop] = None
_asyncio_loop_creation_lock = threading.Lock()

def __init__(self, **passthrough_kwargs):
assert (
"event_loop" not in passthrough_kwargs
), "SingletonThreadRouter manages the router event loop."

self._asyncio_router = AsyncioRouter(
event_loop=self._get_singleton_asyncio_loop(), **passthrough_kwargs
)

@classmethod
def _get_singleton_asyncio_loop(cls) -> asyncio.AbstractEventLoop:
"""Get singleton asyncio loop running in a daemon thread.
This method is thread safe.
"""
with cls._asyncio_loop_creation_lock:
if cls._asyncio_loop is None:
cls._asyncio_loop = asyncio.new_event_loop()
thread = threading.Thread(
daemon=True,
target=cls._asyncio_loop.run_forever,
)
thread.start()

return cls._asyncio_loop

def running_replicas_populated(self) -> bool:
return self._asyncio_router.running_replicas_populated()

def assign_request(
self,
request_meta: RequestMetadata,
*request_args,
**request_kwargs,
) -> concurrent.futures.Future[ReplicaResult]:
return asyncio.run_coroutine_threadsafe(
self._asyncio_router.assign_request(
request_meta, *request_args, **request_kwargs
),
loop=self._asyncio_loop,
)

def shutdown(self):
asyncio.run_coroutine_threadsafe(
-self._metrics_manager.shutdown(), loop=self._event_loop
+self._asyncio_router.shutdown(), loop=self._asyncio_loop
).result()
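
For readers unfamiliar with the threading pattern `SingletonThreadRouter` introduces, the following self-contained sketch (hypothetical names, not Serve internals) shows the same idea: a single event loop created lazily and run forever in a daemon thread, with synchronous callers submitting coroutines through `asyncio.run_coroutine_threadsafe` and getting back `concurrent.futures.Future` objects:

import asyncio
import concurrent.futures
import threading
from typing import Optional


class BackgroundLoop:
    """Process-wide event loop running in a daemon thread (lazy, thread-safe)."""

    _loop: Optional[asyncio.AbstractEventLoop] = None
    _lock = threading.Lock()

    @classmethod
    def get(cls) -> asyncio.AbstractEventLoop:
        # Thread-safe lazy creation of one shared loop in a daemon thread.
        with cls._lock:
            if cls._loop is None:
                cls._loop = asyncio.new_event_loop()
                threading.Thread(target=cls._loop.run_forever, daemon=True).start()
            return cls._loop


async def _do_work(x: int) -> int:
    # Stand-in for the coroutine the router runs, e.g. assign_request.
    await asyncio.sleep(0.01)
    return x * 2


def submit(x: int) -> concurrent.futures.Future:
    # Synchronous callers on any thread get back a concurrent.futures.Future,
    # mirroring SingletonThreadRouter.assign_request.
    return asyncio.run_coroutine_threadsafe(_do_work(x), loop=BackgroundLoop.get())


if __name__ == "__main__":
    print(submit(21).result())  # prints 42

Keeping one shared loop per process, as the class above does with `_asyncio_loop_creation_lock`, avoids spawning a thread per handle while still isolating the router's event loop from user code that blocks its own.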