From f40d236d6de2d878cc8983f1601448956b2086ca Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 30 Jun 2023 16:55:35 -0500 Subject: [PATCH] [serve] Implement timeouts for streaming & enable all tests with experimental streaming flag turned on (#36261) This is the penultimate PR for turning RAY_SERVE_EXPERIMENTAL_STREAMING on by default. In this change, we are now running all Serve tests under three conditions: - RAY_SERVE_EXPERIMENTAL_STREAMING=0 && RAY_SERVE_ENABLE_NEW_ROUTING=0 - RAY_SERVE_EXPERIMENTAL_STREAMING=0 && RAY_SERVE_ENABLE_NEW_ROUTING=1 - RAY_SERVE_EXPERIMENTAL_STREAMING=1 && RAY_SERVE_ENABLE_NEW_ROUTING=1 (Note that RAY_SERVE_EXPERIMENTAL_STREAMING=1 implies RAY_SERVE_ENABLE_NEW_ROUTING so the latter is not explicitly set). This required making two functionality fixes: - Never using the streaming path for Java even if the flag is on. - Implementing timeouts in the streaming path. I could split these into separate PRs, but they can't really be fully tested independently, so opting to combine them here. --- .buildkite/pipeline.build.yml | 37 +++- .buildkite/pipeline.build_py37.yml | 39 +++- python/ray/serve/BUILD | 28 +-- .../ray/serve/_private/application_state.py | 7 +- python/ray/serve/_private/client.py | 9 +- python/ray/serve/_private/common.py | 1 + python/ray/serve/_private/http_proxy.py | 179 +++++++++++++++--- python/ray/serve/_private/http_util.py | 14 +- python/ray/serve/_private/replica.py | 12 +- python/ray/serve/_private/utils.py | 20 ++ python/ray/serve/config.py | 2 +- python/ray/serve/controller.py | 9 +- python/ray/serve/tests/conftest.py | 32 ++++ python/ray/serve/tests/test_cli.py | 4 +- .../tests/test_experimental_streaming.py | 21 -- python/ray/serve/tests/test_handle.py | 6 +- .../serve/tests/test_http_prefix_matching.py | 70 ++++--- python/ray/serve/tests/test_http_routes.py | 1 - python/ray/serve/tests/test_http_util.py | 7 +- python/ray/serve/tests/test_regression.py | 30 ++- .../ray/serve/tests/test_request_timeout.py | 150 +++++++++++++++ python/ray/serve/tests/test_standalone2.py | 117 +----------- python/ray/serve/tests/test_standalone3.py | 32 ---- python/ray/serve/tests/test_util.py | 55 +++++- 24 files changed, 600 insertions(+), 282 deletions(-) delete mode 100644 python/ray/serve/tests/test_experimental_streaming.py create mode 100644 python/ray/serve/tests/test_request_timeout.py diff --git a/.buildkite/pipeline.build.yml b/.buildkite/pipeline.build.yml index 21ce7a9e2fba..d7b807c4a17d 100644 --- a/.buildkite/pipeline.build.yml +++ b/.buildkite/pipeline.build.yml @@ -10,11 +10,11 @@ commands: - ./java/test.sh -- label: ":java: Java (RAY_SERVE_ENABLE_NEW_ROUTING=1)" +- label: ":java: Java (RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1)" conditions: ["RAY_CI_JAVA_AFFECTED"] instance_size: medium commands: - - export RAY_SERVE_ENABLE_NEW_ROUTING=1 && ./java/test.sh + - export RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 && ./java/test.sh - label: ":serverless: Dashboard Tests" conditions: @@ -110,6 +110,39 @@ --test_env=RAY_SERVE_ENABLE_NEW_ROUTING=1 $(cat test_shard.txt) +- label: ":serverless: Serve Tests (RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1)" + parallelism: 3 + conditions: + [ + "RAY_CI_SERVE_AFFECTED", + "RAY_CI_PYTHON_AFFECTED", + "RAY_CI_ML_AFFECTED", + ] + instance_size: large + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT + - TORCH_VERSION=1.9.0 ./ci/env/install-dependencies.sh + - bash ./ci/ci.sh prepare_docker + - 'git clone 
https://github.com/wg/wrk.git /tmp/wrk && pushd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin && popd' + - ./ci/env/env_info.sh + - >- + set -x; + python ./ci/run/bazel_sharding/bazel_sharding.py + --exclude_manual + --index "\${BUILDKITE_PARALLEL_JOB}" --count "\${BUILDKITE_PARALLEL_JOB_COUNT}" + --tag_filters=-post_wheel_build,-gpu + python/ray/serve/... + > test_shard.txt + - cat test_shard.txt + - bazel test --config=ci $(./ci/run/bazel_export_options) + --test_tag_filters=-post_wheel_build,-gpu + --test_env=DOCKER_HOST=tcp://docker:2376 + --test_env=DOCKER_TLS_VERIFY=1 + --test_env=DOCKER_CERT_PATH=/certs/client + --test_env=DOCKER_TLS_CERTDIR=/certs + --test_env=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 + $(cat test_shard.txt) + - label: ":python: Minimal install Python {{matrix}}" conditions: ["RAY_CI_PYTHON_AFFECTED"] instance_size: medium diff --git a/.buildkite/pipeline.build_py37.yml b/.buildkite/pipeline.build_py37.yml index a661442540e9..ae7d86368705 100644 --- a/.buildkite/pipeline.build_py37.yml +++ b/.buildkite/pipeline.build_py37.yml @@ -142,7 +142,7 @@ $(cat test_shard.txt) -- label: ":cold_face: :python: :serverless: Python 3.7 Serve Tests (RAY_SERVE_USE_NEW_ROUTING=1)" +- label: ":cold_face: :python: :serverless: Python 3.7 Serve Tests (RAY_SERVE_ENABLE_NEW_ROUTING=1)" parallelism: 3 conditions: [ @@ -173,5 +173,40 @@ --test_env=DOCKER_TLS_VERIFY=1 --test_env=DOCKER_CERT_PATH=/certs/client --test_env=DOCKER_TLS_CERTDIR=/certs - --test_env=RAY_SERVE_USE_NEW_ROUTING=1 + --test_env=RAY_SERVE_ENABLE_NEW_ROUTING=1 + $(cat test_shard.txt) + + +- label: ":cold_face: :python: :serverless: Python 3.7 Serve Tests (RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1)" + parallelism: 3 + conditions: + [ + "RAY_CI_SERVE_AFFECTED", + "RAY_CI_PYTHON_AFFECTED", + "RAY_CI_ML_AFFECTED", + ] + instance_size: large + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT + - ./ci/env/install-minimal.sh 3.7 + - PYTHON=3.7 TORCH_VERSION=1.9.0 ./ci/env/install-dependencies.sh + - bash ./ci/ci.sh prepare_docker + - 'git clone https://github.com/wg/wrk.git /tmp/wrk && pushd /tmp/wrk && make -j && sudo cp wrk /usr/local/bin && popd' + - ./ci/env/env_info.sh + - >- + set -x; + python ./ci/run/bazel_sharding/bazel_sharding.py + --exclude_manual + --index "\${BUILDKITE_PARALLEL_JOB}" --count "\${BUILDKITE_PARALLEL_JOB_COUNT}" + --tag_filters=-post_wheel_build,-gpu + python/ray/serve/... + > test_shard.txt + - cat test_shard.txt + - bazel test --config=ci $(./ci/run/bazel_export_options) + --test_tag_filters=-post_wheel_build,-gpu + --test_env=DOCKER_HOST=tcp://docker:2376 + --test_env=DOCKER_TLS_VERIFY=1 + --test_env=DOCKER_CERT_PATH=/certs/client + --test_env=DOCKER_TLS_CERTDIR=/certs + --test_env=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 $(cat test_shard.txt) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index b2b4e2558966..91f8cd8302c1 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -312,6 +312,14 @@ py_test( deps = [":serve_lib"], ) +py_test( + name = "test_request_timeout", + size = "medium", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) + py_test( name = "test_standalone", size = "large", @@ -494,29 +502,12 @@ py_test( deps = [":serve_lib"], ) -# Runs a subset of the tests with experimental streaming turned on. 
-py_test( - name = "test_experimental_streaming", - size = "large", - srcs = glob(["tests/test_experimental_streaming.py", - "tests/test_api.py", - "tests/test_failure.py", - "tests/test_fastapi.py", - "tests/test_http_adapters.py", - "tests/test_http_headers.py", - "**/conftest.py"]), - tags = ["exclusive", "team:serve"], - deps = [":serve_lib"], - env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, -) - py_test( name = "test_streaming_response", size = "large", srcs = serve_tests_srcs, tags = ["exclusive", "team:serve"], deps = [":serve_lib"], - env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) py_test( @@ -525,7 +516,6 @@ py_test( srcs = serve_tests_srcs, tags = ["exclusive", "team:serve"], deps = [":serve_lib"], - env = {"RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING": "1"}, ) py_test( @@ -613,7 +603,7 @@ py_test( py_test( name = "test_gradio", - size = "small", + size = "medium", srcs = serve_tests_srcs, tags = ["exclusive", "team:serve"], deps = [":serve_lib"], diff --git a/python/ray/serve/_private/application_state.py b/python/ray/serve/_private/application_state.py index 30698184846c..89f3d5ef7f65 100644 --- a/python/ray/serve/_private/application_state.py +++ b/python/ray/serve/_private/application_state.py @@ -193,9 +193,14 @@ def apply_deployment_info( self._deployment_state_manager.deploy(deployment_name, deployment_info) if deployment_info.route_prefix is not None: + config = deployment_info.deployment_config self._endpoint_state.update_endpoint( deployment_name, - EndpointInfo(route=deployment_info.route_prefix, app_name=self._name), + EndpointInfo( + route=deployment_info.route_prefix, + app_name=self._name, + app_is_cross_language=config.is_cross_language, + ), ) else: self._endpoint_state.delete_endpoint(deployment_name) diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index e36f6ef55977..00f67fe4dd44 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -308,7 +308,14 @@ def deploy( route_prefix=route_prefix, ) - updating = ray.get(self._controller.deploy.remote(**controller_deploy_args)) + updating = ray.get( + self._controller.deploy.remote( + # TODO(edoakes): this is a hack because the deployment_language + # doesn't seem to get set properly from Java. 
+ is_deployed_from_python=True, + **controller_deploy_args, + ) + ) tag = self.log_deployment_update_status(name, version, updating) diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 144ad9f22695..69d13c0210bb 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -28,6 +28,7 @@ class EndpointInfo: route: str app_name: str + app_is_cross_language: bool = False # Keep in sync with ServeReplicaState in dashboard/client/src/type/serve.ts diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index 3b50bc4a276e..bb8b51797fc4 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -19,6 +19,7 @@ from ray.exceptions import RayActorError, RayTaskError from ray.util import metrics from ray._private.utils import get_or_create_event_loop +from ray._raylet import StreamingObjectRefGenerator from ray import serve from ray.serve.handle import RayServeHandle @@ -48,7 +49,11 @@ get_component_logger_file_path, ) -from ray.serve._private.utils import get_random_letters, call_function_from_import_path +from ray.serve._private.utils import ( + calculate_remaining_timeout, + call_function_from_import_path, + get_random_letters, +) logger = logging.getLogger(SERVE_LOGGER_NAME) @@ -59,6 +64,7 @@ "RAY_SERVE_HTTP_REQUEST_MAX_RETRIES cannot be negative." ) +TIMEOUT_ERROR_CODE = "timeout" DISCONNECT_ERROR_CODE = "disconnection" SOCKET_REUSE_PORT_ENABLED = ( os.environ.get("SERVE_SOCKET_REUSE_PORT_ENABLED", "1") == "1" @@ -94,6 +100,8 @@ def __init__(self, get_handle: Callable): self.route_info: Dict[str, Tuple[EndpointTag, ApplicationName]] = dict() # Contains a ServeHandle for each endpoint. self.handles: Dict[str, RayServeHandle] = dict() + # Map of application name to is_cross_language. + self.app_to_is_cross_language: Dict[ApplicationName, bool] = dict() def endpoint_exists(self, endpoint: EndpointTag) -> bool: return endpoint in self.handles @@ -106,13 +114,22 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None: existing_handles = set(self.handles.keys()) routes = [] route_info = {} + app_to_is_cross_language = {} for endpoint, info in endpoints.items(): routes.append(info.route) route_info[info.route] = (endpoint, info.app_name) + app_to_is_cross_language[info.app_name] = info.app_is_cross_language if endpoint in self.handles: existing_handles.remove(endpoint) else: - self.handles[endpoint] = self._get_handle(endpoint) + self.handles[endpoint] = self._get_handle( + endpoint, + # Streaming codepath isn't supported for Java. + stream=( + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + and not info.app_is_cross_language + ), + ) # Clean up any handles that are no longer used. if len(existing_handles) > 0: @@ -127,18 +144,18 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None: # prefix matching. self.sorted_routes = sorted(routes, key=lambda x: len(x), reverse=True) self.route_info = route_info + self.app_to_is_cross_language = app_to_is_cross_language def match_route( self, target_route: str - ) -> Tuple[Optional[str], Optional[RayServeHandle]]: + ) -> Optional[Tuple[str, RayServeHandle, str, bool]]: """Return the longest prefix match among existing routes for the route. Args: target_route: route to match against. Returns: - (matched_route (str), serve_handle (RayServeHandle)) if found, - else (None, None). + (route, handle, app_name, is_cross_language) if found, else None. 
""" for route in self.sorted_routes: @@ -159,9 +176,14 @@ def match_route( if matched: endpoint, app_name = self.route_info[route] - return route, self.handles[endpoint], app_name + return ( + route, + self.handles[endpoint], + app_name, + self.app_to_is_cross_language[app_name], + ) - return None, None, None + return None class HTTPProxy: @@ -179,6 +201,9 @@ def __init__( request_timeout_s: Optional[float] = None, ): self.request_timeout_s = request_timeout_s + if self.request_timeout_s is not None and self.request_timeout_s < 0: + self.request_timeout_s = None + self._node_id = node_id # Set the controller name so that serve will connect to the @@ -199,13 +224,13 @@ def __init__( extra={"log_to_stderr": False}, ) - def get_handle(name): + def get_handle(name, stream: bool = False): return serve.context.get_global_client().get_handle( name, sync=False, missing_ok=True, _is_for_http_requests=True, - _stream=RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + _stream=stream, ) self.prefix_router = LongestPrefixRouter(get_handle) @@ -424,8 +449,8 @@ async def __call__(self, scope, receive, send): try: self._ongoing_requests_start() - route_prefix, handle, app_name = self.prefix_router.match_route(route_path) - if route_prefix is None: + matched_route = self.prefix_router.match_route(route_path) + if matched_route is None: self.request_error_counter.inc( tags={ "route": route_path, @@ -443,6 +468,8 @@ async def __call__(self, scope, receive, send): ) return await self._not_found(scope, receive, send) + route_prefix, handle, app_name, app_is_cross_language = matched_route + # Modify the path and root path so that reverse lookups and redirection # work as expected. We do this here instead of in replicas so it can be # changed without restarting the replicas. @@ -467,7 +494,8 @@ async def __call__(self, scope, receive, send): ray.serve.context.RequestContext(**request_context_info) ) - if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING: + # Streaming codepath isn't supported for Java. + if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING and not app_is_cross_language: status_code = await self.send_request_to_replica_streaming( request_context_info["request_id"], handle, scope, receive, send ) @@ -654,6 +682,81 @@ async def proxy_asgi_receive( if msg["type"] == "websocket.disconnect": return msg["code"] + async def _assign_request_with_timeout( + self, + handle: RayServeHandle, + scope: Scope, + disconnected_task: asyncio.Task, + timeout_s: Optional[float] = None, + ) -> Optional[StreamingObjectRefGenerator]: + """Attempt to send a request on the handle within the timeout. + + If `timeout_s` is exceeded while trying to assign a replica, `TimeoutError` + will be raised. + + `disconnected_task` is expected to be done if the client disconnects; in this + case, we will abort assigning a replica and return `None`. + """ + assignment_task = handle.remote(pickle.dumps(scope), self.self_actor_handle) + done, _ = await asyncio.wait( + [assignment_task, disconnected_task], + return_when=FIRST_COMPLETED, + timeout=timeout_s, + ) + if assignment_task in done: + return assignment_task.result() + elif disconnected_task in done: + assignment_task.cancel() + return None + else: + assignment_task.cancel() + raise TimeoutError() + + async def _consume_and_send_asgi_message_generator( + self, + obj_ref_generator: StreamingObjectRefGenerator, + send: Send, + timeout_s: Optional[float] = None, + ) -> Optional[str]: + """Consumes an obj ref generator that yields ASGI messages. + + The messages are sent over the `send` interface. 
+ + If timeout_s is `None`, there's no timeout. If it's not `None`, a timeout error + will be raised if the full generator isn't consumed within the timeout. + + Returns the status code for HTTP responses. + """ + status_code = "" + start = time.time() + while True: + try: + obj_ref = await obj_ref_generator._next_async( + timeout_s=calculate_remaining_timeout( + timeout_s=timeout_s, + start_time_s=start, + curr_time_s=time.time(), + ) + ) + if obj_ref.is_nil(): + raise TimeoutError + + asgi_messages: List[Message] = pickle.loads(await obj_ref) + for asgi_message in asgi_messages: + if asgi_message["type"] == "http.response.start": + # HTTP responses begin with exactly one + # "http.response.start" message containing the "status" + # field Other response types (e.g., WebSockets) may not. + status_code = str(asgi_message["status"]) + elif asgi_message["type"] == "websocket.disconnect": + status_code = str(asgi_message["code"]) + + await send(asgi_message) + except StopAsyncIteration: + break + + return status_code + async def send_request_to_replica_streaming( self, request_id: str, @@ -672,22 +775,46 @@ async def send_request_to_replica_streaming( ) status_code = "" + start = time.time() try: - object_ref_generator = await handle.remote( - pickle.dumps(scope), self.self_actor_handle - ) - async for obj_ref in object_ref_generator: - asgi_messages: List[Message] = pickle.loads(await obj_ref) - for asgi_message in asgi_messages: - if asgi_message["type"] == "http.response.start": - # HTTP responses begin with exactly one "http.response.start" - # message containing the "status" field. Other response types - # (e.g., WebSockets) may not. - status_code = str(asgi_message["status"]) - elif asgi_message["type"] == "websocket.disconnect": - status_code = str(asgi_message["code"]) + try: + obj_ref_generator = await self._assign_request_with_timeout( + handle, + scope, + proxy_asgi_receive_task, + timeout_s=self.request_timeout_s, + ) + if obj_ref_generator is None: + logger.info( + f"Client from {scope['client']} disconnected, cancelling the " + "request.", + extra={"log_to_stderr": False}, + ) + return DISCONNECT_ERROR_CODE + except TimeoutError: + logger.warning( + f"Request {request_id} timed out after " + f"{self.request_timeout_s}s while waiting for assignment." + ) + return TIMEOUT_ERROR_CODE + + try: + status_code = await self._consume_and_send_asgi_message_generator( + obj_ref_generator, + send, + timeout_s=calculate_remaining_timeout( + timeout_s=self.request_timeout_s, + start_time_s=start, + curr_time_s=time.time(), + ), + ) + except TimeoutError: + logger.warning( + f"Request {request_id} timed out after " + f"{self.request_timeout_s}s while executing." 
+ ) + return TIMEOUT_ERROR_CODE - await send(asgi_message) except Exception as e: logger.exception(e) status_code = "500" diff --git a/python/ray/serve/_private/http_util.py b/python/ray/serve/_private/http_util.py index 8033f6075a6f..9c0cecba33a4 100644 --- a/python/ray/serve/_private/http_util.py +++ b/python/ray/serve/_private/http_util.py @@ -207,25 +207,15 @@ class ASGIReceiveProxy: def __init__( self, - event_loop: asyncio.AbstractEventLoop, request_id: str, actor_handle: ActorHandle, ): - self._task = None self._queue = asyncio.Queue() - self._event_loop = event_loop self._request_id = request_id self._actor_handle = actor_handle self._disconnect_message = None - def start(self): - self._task = self._event_loop.create_task(self._fetch_until_disconnect()) - - def stop(self): - if self._task is not None and not self._task.done(): - self._task.cancel() - - async def _fetch_until_disconnect(self): + async def fetch_until_disconnect(self): """Fetch messages repeatedly until a disconnect message is received. If a disconnect message is received, this function exits and returns it. @@ -255,8 +245,6 @@ async def __call__(self) -> Message: This will repeatedly return a disconnect message once it's been received. """ - assert self._task is not None, "Must call `start` before receiving messages." - if self._queue.empty() and self._disconnect_message is not None: return self._disconnect_message diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 730d8060a8c2..224a63eba010 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -266,14 +266,16 @@ async def handle_request_streaming( "Only HTTP requests are currently supported over streaming." ) - receiver = None + receiver_task = None handle_request_task = None wait_for_message_task = None try: receiver = ASGIReceiveProxy( - self._event_loop, request_metadata.request_id, http_proxy_handle + request_metadata.request_id, http_proxy_handle + ) + receiver_task = self._event_loop.create_task( + receiver.fetch_until_disconnect() ) - receiver.start() scope = pickle.loads(pickled_asgi_scope) asgi_queue_send = ASGIMessageQueue() @@ -313,8 +315,8 @@ async def handle_request_streaming( if e is not None: raise e from None finally: - if receiver is not None: - receiver.stop() + if receiver_task is not None: + receiver_task.cancel() if handle_request_task is not None and not handle_request_task.done(): handle_request_task.cancel() diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index eeb169ef1ead..ccf11eac599b 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -703,3 +703,23 @@ def get_head_node_id() -> str: assert head_node_id is not None, "Cannot find alive head node." return head_node_id + + +def calculate_remaining_timeout( + *, + timeout_s: Optional[float], + start_time_s: float, + curr_time_s: float, +) -> Optional[float]: + """Get the timeout remaining given an overall timeout, start time, and curr time. + + If the timeout passed in was `None` or negative, will always return that timeout + directly. + + If the timeout is >= 0, the returned remaining timeout always be >= 0. 
+ """ + if timeout_s is None or timeout_s < 0: + return timeout_s + + time_since_start_s = curr_time_s - start_time_s + return max(0, timeout_s - time_since_start_s) diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 655b0ee93713..15b6926851b4 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -175,7 +175,7 @@ class DeploymentConfig(BaseModel): default=None, update_type=DeploymentOptionUpdateType.LightWeight ) - # This flag is used to let replica know they are deplyed from + # This flag is used to let replica know they are deployed from # a different language. is_cross_language: bool = False diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index af9c7cf32213..1d1b224470c8 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -458,6 +458,9 @@ def deploy( docs_path: Optional[str] = None, is_driver_deployment: Optional[bool] = False, app_name: str = None, + # TODO(edoakes): this is a hack because the deployment_language doesn't seem + # to get set properly from Java. + is_deployed_from_python: bool = False, ) -> bool: """Deploys a deployment.""" if route_prefix is not None: @@ -486,7 +489,11 @@ def deploy( updating = self.deployment_state_manager.deploy(name, deployment_info) if route_prefix is not None: - endpoint_info = EndpointInfo(route=route_prefix, app_name=app_name) + endpoint_info = EndpointInfo( + route=route_prefix, + app_name=app_name, + app_is_cross_language=not is_deployed_from_python, + ) self.endpoint_state.update_endpoint(name, endpoint_info) else: self.endpoint_state.delete_endpoint(name) diff --git a/python/ray/serve/tests/conftest.py b/python/ray/serve/tests/conftest.py index b1e6812fa96c..142f7606bb84 100644 --- a/python/ray/serve/tests/conftest.py +++ b/python/ray/serve/tests/conftest.py @@ -113,3 +113,35 @@ def ray_start_stop(): check_ray_stop, timeout=15, ) + + +@pytest.fixture +def ray_instance(request): + """Starts and stops a Ray instance for this test. + + Args: + request: request.param should contain a dictionary of env vars and + their values. The Ray instance will be started with these env vars. 
+ """ + + original_env_vars = os.environ.copy() + + try: + requested_env_vars = request.param + except AttributeError: + requested_env_vars = {} + + os.environ.update(requested_env_vars) + + yield ray.init( + _metrics_export_port=9999, + _system_config={ + "metrics_report_interval_ms": 1000, + "task_retry_delay_ms": 50, + }, + ) + + ray.shutdown() + + os.environ.clear() + os.environ.update(original_env_vars) diff --git a/python/ray/serve/tests/test_cli.py b/python/ray/serve/tests/test_cli.py index b52a787367c9..21a3a633592e 100644 --- a/python/ray/serve/tests/test_cli.py +++ b/python/ray/serve/tests/test_cli.py @@ -1455,8 +1455,8 @@ def test_run_config_request_timeout(): # Ensure the http request is killed and failed when the deployment runs longer than # the 0.1 request_timeout_s set in in the config yaml wait_for_condition( - lambda: requests.get("http://localhost:8000/app1?sleep_s=0.11").text - == "Task failed with 1 retries.", + lambda: requests.get("http://localhost:8000/app1?sleep_s=0.11").status_code + == 500, ) # Ensure the http request returned the correct response when the deployment runs diff --git a/python/ray/serve/tests/test_experimental_streaming.py b/python/ray/serve/tests/test_experimental_streaming.py deleted file mode 100644 index 2558426752e8..000000000000 --- a/python/ray/serve/tests/test_experimental_streaming.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest -from pathlib import Path -import sys - -from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING - -if __name__ == "__main__": - curr_dir = Path(__file__).parent - test_paths = curr_dir.rglob("test_*.py") - sorted_path = sorted(map(lambda path: str(path.absolute()), test_paths)) - serve_tests_files = list(sorted_path) - - print("Testing the following files") - for test_file in serve_tests_files: - print("->", test_file.split("/")[-1]) - - assert ( - RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING - ), "RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING=1 must be set." - - sys.exit(pytest.main(["-v", "-s"] + serve_tests_files)) diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index 1762ccf20438..779d4898bd32 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -10,8 +10,9 @@ from ray import serve from ray.serve.exceptions import RayServeException from ray.serve._private.constants import ( - SERVE_DEFAULT_APP_NAME, DEPLOYMENT_NAME_PREFIX_SEPARATOR, + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + SERVE_DEFAULT_APP_NAME, ) from ray.serve.context import get_global_client @@ -120,6 +121,9 @@ async def __call__(self, _): assert requests.get("http://127.0.0.1:8000/Endpoint2").text == "hello" +@pytest.mark.skipif( + RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, reason="Not supported w/ streaming." 
+) def test_handle_inject_starlette_request(serve_instance): @serve.deployment(name="echo") def echo_request_type(request): diff --git a/python/ray/serve/tests/test_http_prefix_matching.py b/python/ray/serve/tests/test_http_prefix_matching.py index a0655d0851e0..a4f16c88f901 100644 --- a/python/ray/serve/tests/test_http_prefix_matching.py +++ b/python/ray/serve/tests/test_http_prefix_matching.py @@ -15,19 +15,24 @@ def mock_get_handle(name, *args, **kwargs): def test_no_match(mock_longest_prefix_router): router = mock_longest_prefix_router router.update_routes({"endpoint": EndpointInfo(route="/hello", app_name="")}) - route, handle, app_name = router.match_route("/nonexistent") - assert route is None and handle is None and app_name is None + assert router.match_route("/nonexistent") is None def test_default_route(mock_longest_prefix_router): router = mock_longest_prefix_router router.update_routes({"endpoint": EndpointInfo(route="/endpoint", app_name="")}) - route, handle, app_name = router.match_route("/nonexistent") - assert route is None and handle is None and app_name is None + assert router.match_route("/nonexistent") is None - route, handle, app_name = router.match_route("/endpoint") - assert route == "/endpoint" and handle == "endpoint" and app_name == "" + route, handle, app_name, app_is_cross_language = router.match_route("/endpoint") + assert all( + [ + route == "/endpoint", + handle == "endpoint", + app_name == "", + not app_is_cross_language, + ] + ) def test_trailing_slash(mock_longest_prefix_router): @@ -38,7 +43,7 @@ def test_trailing_slash(mock_longest_prefix_router): } ) - route, handle, _ = router.match_route("/test/") + route, handle, _, _ = router.match_route("/test/") assert route == "/test" and handle == "endpoint" router.update_routes( @@ -47,8 +52,7 @@ def test_trailing_slash(mock_longest_prefix_router): } ) - route, handle, app_name = router.match_route("/test") - assert route is None and handle is None and app_name is None + assert router.match_route("/test") is None def test_prefix_match(mock_longest_prefix_router): @@ -61,23 +65,23 @@ def test_prefix_match(mock_longest_prefix_router): } ) - route, handle, _ = router.match_route("/test/test2/subpath") + route, handle, _, _ = router.match_route("/test/test2/subpath") assert route == "/test/test2" and handle == "endpoint1" - route, handle, _ = router.match_route("/test/test2/") + route, handle, _, _ = router.match_route("/test/test2/") assert route == "/test/test2" and handle == "endpoint1" - route, handle, _ = router.match_route("/test/test2") + route, handle, _, _ = router.match_route("/test/test2") assert route == "/test/test2" and handle == "endpoint1" - route, handle, _ = router.match_route("/test/subpath") + route, handle, _, _ = router.match_route("/test/subpath") assert route == "/test" and handle == "endpoint2" - route, handle, _ = router.match_route("/test/") + route, handle, _, _ = router.match_route("/test/") assert route == "/test" and handle == "endpoint2" - route, handle, _ = router.match_route("/test") + route, handle, _, _ = router.match_route("/test") assert route == "/test" and handle == "endpoint2" - route, handle, _ = router.match_route("/test2") + route, handle, _, _ = router.match_route("/test2") assert route == "/" and handle == "endpoint3" - route, handle, _ = router.match_route("/") + route, handle, _, _ = router.match_route("/") assert route == "/" and handle == "endpoint3" @@ -85,18 +89,38 @@ def test_update_routes(mock_longest_prefix_router): router = mock_longest_prefix_router 
router.update_routes({"endpoint": EndpointInfo(route="/endpoint", app_name="app1")}) - route, handle, app_name = router.match_route("/endpoint") - assert route == "/endpoint" and handle == "endpoint" and app_name == "app1" + route, handle, app_name, app_is_cross_language = router.match_route("/endpoint") + assert all( + [ + route == "/endpoint", + handle == "endpoint", + app_name == "app1", + not app_is_cross_language, + ] + ) router.update_routes( - {"endpoint2": EndpointInfo(route="/endpoint2", app_name="app2")} + { + "endpoint2": EndpointInfo( + route="/endpoint2", + app_name="app2", + app_is_cross_language=True, + ) + } ) - route, handle, app_name = router.match_route("/endpoint") - assert route is None and handle is None and app_name is None + assert router.match_route("/endpoint") is None - route, handle, app_name = router.match_route("/endpoint2") + route, handle, app_name, app_is_cross_language = router.match_route("/endpoint2") assert route == "/endpoint2" and handle == "endpoint2" and app_name == "app2" + assert all( + [ + route == "/endpoint2", + handle == "endpoint2", + app_name == "app2", + app_is_cross_language, + ] + ) if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_http_routes.py b/python/ray/serve/tests/test_http_routes.py index 69c76e221b35..0803a4d52558 100644 --- a/python/ray/serve/tests/test_http_routes.py +++ b/python/ray/serve/tests/test_http_routes.py @@ -305,7 +305,6 @@ def h(): serve.run(h.bind()) r = requests.get("http://localhost:8000/h") assert r.status_code == 500 - assert "retries" in r.text, r.text if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_http_util.py b/python/ray/serve/tests/test_http_util.py index f666b468e59a..dc1719e3e42a 100644 --- a/python/ray/serve/tests/test_http_util.py +++ b/python/ray/serve/tests/test_http_util.py @@ -62,6 +62,7 @@ async def test_asgi_message_queue(): @pytest.fixture +@pytest.mark.asyncio def setup_receive_proxy( shared_ray_instance, ) -> Generator[Tuple[ASGIReceiveProxy, ActorHandle], None, None]: @@ -88,12 +89,12 @@ async def receive_asgi_messages(self, request_id: str) -> bytes: actor = ASGIReceive.remote() ray.get(actor.ready.remote()) loop = get_or_create_event_loop() - asgi_receive_proxy = ASGIReceiveProxy(loop, "", actor) - asgi_receive_proxy.start() + asgi_receive_proxy = ASGIReceiveProxy("", actor) + receiver_task = loop.create_task(asgi_receive_proxy.fetch_until_disconnect()) try: yield asgi_receive_proxy, actor except Exception: - asgi_receive_proxy.stop() + receiver_task.cancel() @pytest.mark.asyncio diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index ac0b4b24607b..4f5b002e5e26 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -1,5 +1,6 @@ -import gc import asyncio +import gc +import sys import numpy as np import requests @@ -78,28 +79,37 @@ async def __call__(self, *args): assert result.json() == 100.0 +@pytest.mark.skipif( + sys.version_info.major >= 3 and sys.version_info.minor <= 7, + reason="Failing on Python 3.7 due to different GC behavior.", +) def test_replica_memory_growth(serve_instance): # https://github.com/ray-project/ray/issues/12395 - @serve.deployment(name="model") + @serve.deployment def gc_unreachable_objects(*args): gc.set_debug(gc.DEBUG_SAVEALL) gc.collect() - return len(gc.garbage) + gc_garbage_len = len(gc.garbage) + if gc_garbage_len > 0: + print(gc.garbage) + return gc_garbage_len handle = serve.run(gc_unreachable_objects.bind()) + def 
get_gc_garbage_len_http(): + result = requests.get("http://127.0.0.1:8000") + assert result.status_code == 200 + return result.json() + # We are checking that there's constant number of object in gc. - known_num_objects = ray.get(handle.remote()) + known_num_objects_from_http = get_gc_garbage_len_http() for _ in range(10): - result = requests.get("http://127.0.0.1:8000/model") - assert result.status_code == 200 - num_unreachable_objects = result.json() - assert num_unreachable_objects == known_num_objects + assert get_gc_garbage_len_http() == known_num_objects_from_http + known_num_objects_from_handle = ray.get(handle.remote()) for _ in range(10): - num_unreachable_objects = ray.get(handle.remote()) - assert num_unreachable_objects == known_num_objects + assert ray.get(handle.remote()) == known_num_objects_from_handle def test_ref_in_handle_input(serve_instance): diff --git a/python/ray/serve/tests/test_request_timeout.py b/python/ray/serve/tests/test_request_timeout.py new file mode 100644 index 000000000000..5791881ad98e --- /dev/null +++ b/python/ray/serve/tests/test_request_timeout.py @@ -0,0 +1,150 @@ +import asyncio +import os +import sys +from typing import Set + +import pytest +import requests + +import ray +from ray._private.test_utils import SignalActor + +from ray import serve +from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING + + +@ray.remote +def do_request(): + return requests.get("http://localhost:8000") + + +@pytest.fixture +def shutdown_serve(): + yield + serve.shutdown() + + +@pytest.mark.parametrize( + "ray_instance", + [ + {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"}, + ], + indirect=True, +) +def test_normal_operation(ray_instance, shutdown_serve): + """ + Verify that a moderate timeout doesn't affect normal operation. + """ + + @serve.deployment(num_replicas=2) + def f(*args): + return "Success!" + + serve.run(f.bind()) + + assert all( + response.text == "Success!" + for response in ray.get([do_request.remote() for _ in range(10)]) + ) + + +@pytest.mark.parametrize( + "ray_instance", + [ + {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"}, + ], + indirect=True, +) +def test_request_hangs_in_execution(ray_instance, shutdown_serve): + """ + Verify that requests are timed out if they take longer than the timeout to execute. + """ + + @ray.remote + class PidTracker: + def __init__(self): + self.pids = set() + + def add_pid(self, pid: int) -> None: + self.pids.add(pid) + + def get_pids(self) -> Set[int]: + return self.pids + + pid_tracker = PidTracker.remote() + signal_actor = SignalActor.remote() + + @serve.deployment(num_replicas=2, graceful_shutdown_timeout_s=0) + class HangsOnFirstRequest: + def __init__(self): + self._saw_first_request = False + + async def __call__(self): + ray.get(pid_tracker.add_pid.remote(os.getpid())) + if not self._saw_first_request: + self._saw_first_request = True + await asyncio.sleep(10) + + return "Success!" + + serve.run(HangsOnFirstRequest.bind()) + + response = requests.get("http://localhost:8000") + if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING: + assert response.status_code == 500 + else: + assert response.status_code == 200 + assert response.text == "Success!" + + # Hanging request should have been retried on a different replica. 
+ assert len(ray.get(pid_tracker.get_pids.remote())) == 2 + + ray.get(signal_actor.send.remote()) + + +@pytest.mark.parametrize( + "ray_instance", + [ + {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"}, + ], + indirect=True, +) +def test_request_hangs_in_assignment(ray_instance, shutdown_serve): + """ + Verify that requests are timed out if they take longer than the timeout while + pending assignment (queued in the handle). + """ + signal_actor = SignalActor.remote() + + @serve.deployment(graceful_shutdown_timeout_s=0, max_concurrent_queries=1) + class HangsOnFirstRequest: + def __init__(self): + self._saw_first_request = False + + async def __call__(self): + await signal_actor.wait.remote() + return "Success!" + + serve.run(HangsOnFirstRequest.bind()) + + # First request will hang executing, second pending assignment. + response_ref1 = do_request.remote() + response_ref2 = do_request.remote() + + if RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING: + # Streaming path does not retry on timeouts, so the requests should be failed. + assert ray.get(response_ref1).status_code == 500 + assert ray.get(response_ref2).status_code == 500 + ray.get(signal_actor.send.remote()) + assert ray.get(do_request.remote()).status_code == 200 + else: + # Legacy path retries on timeouts, so the requests should succeed. + ray.get(signal_actor.send.remote()) + assert ray.get(response_ref1).status_code == 200 + assert ray.get(response_ref1).text == "Success!" + assert ray.get(response_ref2).status_code == 200 + assert ray.get(response_ref2).text == "Success!" + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py index 0ccb9d482329..7a9953a8adee 100644 --- a/python/ray/serve/tests/test_standalone2.py +++ b/python/ray/serve/tests/test_standalone2.py @@ -3,8 +3,7 @@ import sys import time from contextlib import contextmanager -from typing import Dict, Set -from concurrent.futures.thread import ThreadPoolExecutor +from typing import Dict from functools import partial from tempfile import NamedTemporaryFile @@ -48,38 +47,6 @@ def shutdown_ray(): ray.shutdown() -@pytest.fixture() -def ray_instance(request): - """Starts and stops a Ray instance for this test. - - Args: - request: request.param should contain a dictionary of env vars and - their values. The Ray instance will be started with these env vars. 
- """ - - original_env_vars = os.environ.copy() - - try: - requested_env_vars = request.param - except AttributeError: - requested_env_vars = {} - - os.environ.update(requested_env_vars) - - yield ray.init( - _metrics_export_port=9999, - _system_config={ - "metrics_report_interval_ms": 1000, - "task_retry_delay_ms": 50, - }, - ) - - ray.shutdown() - - os.environ.clear() - os.environ.update(original_env_vars) - - @contextmanager def start_and_shutdown_ray_cli(): subprocess.check_output(["ray", "stop", "--force"]) @@ -1662,87 +1629,5 @@ def test_deployments_not_listed_in_config(self, client: ServeControllerClient): assert all(pid == pid1 for pid in pids) -class TestServeRequestProcessingTimeoutS: - @pytest.mark.parametrize( - "ray_instance", - [ - {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"}, - {"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5"}, - { - "RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "5", - "SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0", - }, - ], - indirect=True, - ) - def test_normal_operation(self, ray_instance): - """Checks that a moderate timeout doesn't affect normal operation.""" - - @serve.deployment(num_replicas=2) - def f(*args): - return "Success!" - - serve.run(f.bind()) - - for _ in range(20): - requests.get("http://localhost:8000").text == "Success!" - - serve.shutdown() - - @pytest.mark.parametrize( - "ray_instance", - [ - {"RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"}, - {"SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1"}, - { - "RAY_SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0.1", - "SERVE_REQUEST_PROCESSING_TIMEOUT_S": "0", - }, - ], - indirect=True, - ) - def test_hanging_request(self, ray_instance): - """Checks that the env var mitigates the hang.""" - - @ray.remote - class PidTracker: - def __init__(self): - self.pids = set() - - def add_pid(self, pid: int) -> None: - self.pids.add(pid) - - def get_pids(self) -> Set[int]: - return self.pids - - pid_tracker = PidTracker.remote() - signal_actor = SignalActor.remote() - - @serve.deployment(num_replicas=2) - async def waiter(*args): - import os - - ray.get(pid_tracker.add_pid.remote(os.getpid())) - await signal_actor.wait.remote() - return "Success!" - - serve.run(waiter.bind()) - - with ThreadPoolExecutor() as pool: - response_fut = pool.submit(requests.get, "http://localhost:8000") - - # Force request to hang - time.sleep(0.5) - ray.get(signal_actor.send.remote()) - - wait_for_condition(lambda: response_fut.done()) - assert response_fut.result().text == "Success!" - - # Hanging request should have been retried - assert len(ray.get(pid_tracker.get_pids.remote())) == 2 - - serve.shutdown() - - if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_standalone3.py b/python/ray/serve/tests/test_standalone3.py index 444d1ee9d020..0fbfdf457ea4 100644 --- a/python/ray/serve/tests/test_standalone3.py +++ b/python/ray/serve/tests/test_standalone3.py @@ -37,38 +37,6 @@ def shutdown_ray(): ray.shutdown() -@pytest.fixture() -def ray_instance(request): - """Starts and stops a Ray instance for this test. - - Args: - request: request.param should contain a dictionary of env vars and - their values. The Ray instance will be started with these env vars. 
- """ - - original_env_vars = os.environ.copy() - - try: - requested_env_vars = request.param - except AttributeError: - requested_env_vars = {} - - os.environ.update(requested_env_vars) - - yield ray.init( - _metrics_export_port=9999, - _system_config={ - "metrics_report_interval_ms": 1000, - "task_retry_delay_ms": 50, - }, - ) - - ray.shutdown() - - os.environ.clear() - os.environ.update(original_env_vars) - - @contextmanager def start_and_shutdown_ray_cli(): subprocess.check_output( diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index a5e477476b96..fb95492bb752 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -13,12 +13,13 @@ import ray from ray import serve from ray.serve._private.utils import ( + calculate_remaining_timeout, get_deployment_import_path, - override_runtime_envs_except_env_vars, - serve_encoders, merge_dict, msgpack_serialize, msgpack_deserialize, + override_runtime_envs_except_env_vars, + serve_encoders, snake_to_camel_case, dict_keys_snake_to_camel_case, get_head_node_id, @@ -567,6 +568,56 @@ def test_get_head_node_id(): get_head_node_id() +def test_calculate_remaining_timeout(): + # Always return `None` or negative value. + assert ( + calculate_remaining_timeout( + timeout_s=None, + start_time_s=100, + curr_time_s=101, + ) + is None + ) + + assert ( + calculate_remaining_timeout( + timeout_s=-1, + start_time_s=100, + curr_time_s=101, + ) + == -1 + ) + + # Return delta from start. + assert ( + calculate_remaining_timeout( + timeout_s=10, + start_time_s=100, + curr_time_s=101, + ) + == 9 + ) + + assert ( + calculate_remaining_timeout( + timeout_s=100, + start_time_s=100, + curr_time_s=101.1, + ) + == 98.9 + ) + + # Never return a negative timeout once it has elapsed. + assert ( + calculate_remaining_timeout( + timeout_s=10, + start_time_s=100, + curr_time_s=111, + ) + == 0 + ) + + if __name__ == "__main__": import sys