Skip to content

Commit

Permalink
[Serve] Use serve_instance for test_http_headers to reduce the test t…
Browse files Browse the repository at this point in the history
…ime (ray-project#36988)

rewrite test to use shared instance instead of ray shutdown.
In my local, time is reduced from 40s to 23s.
  • Loading branch information
sihanwang41 authored Jun 30, 2023
1 parent f40d236 commit 0d61731
Showing 1 changed file with 25 additions and 23 deletions.
48 changes: 25 additions & 23 deletions python/ray/serve/tests/test_http_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ray.serve._private.constants import RAY_SERVE_REQUEST_ID_HEADER


def test_request_id_header_by_default(ray_shutdown):
def test_request_id_header_by_default(serve_instance):
"""Test that a request_id is generated by default and returned as a header."""

@serve.deployment
Expand All @@ -24,50 +24,52 @@ def __call__(self):
assert resp.text == resp.headers[RAY_SERVE_REQUEST_ID_HEADER]


@pytest.mark.parametrize("deploy_type", ["basic", "fastapi", "starlette_resp"])
def test_user_provided_request_id_header(ray_shutdown, deploy_type):
"""Test that a user-provided request_id is propagated to the
replica and returned as a header."""

if deploy_type == "fastapi":
app = FastAPI()
class TestUserProvidedRequestIDHeader:
def verify_result(self):
resp = requests.get(
"http://localhost:8000", headers={RAY_SERVE_REQUEST_ID_HEADER: "123-234"}
)
assert resp.status_code == 200
assert resp.json() == 1
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.headers[RAY_SERVE_REQUEST_ID_HEADER] == "123-234"

def test_basic(self, serve_instance):
@serve.deployment
@serve.ingress(app)
class Model:
@app.get("/")
def say_hi(self) -> int:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

elif deploy_type == "basic":
serve.run(Model.bind())
self.verify_result()

def test_fastapi(self, serve_instance):
app = FastAPI()

@serve.deployment
@serve.ingress(app)
class Model:
def __call__(self) -> int:
@app.get("/")
def say_hi(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return 1

else:
serve.run(Model.bind())
self.verify_result()

def test_starlette_resp(self, serve_instance):
@serve.deployment
class Model:
def __call__(self) -> int:
request_id = ray.serve.context._serve_request_context.get().request_id
assert request_id == "123-234"
return starlette.responses.Response("1", media_type="application/json")

serve.run(Model.bind())

resp = requests.get(
"http://localhost:8000", headers={RAY_SERVE_REQUEST_ID_HEADER: "123-234"}
)
assert resp.status_code == 200
assert resp.json() == 1
assert RAY_SERVE_REQUEST_ID_HEADER in resp.headers
assert resp.headers[RAY_SERVE_REQUEST_ID_HEADER] == "123-234"
serve.run(Model.bind())
self.verify_result()


if __name__ == "__main__":
Expand Down

0 comments on commit 0d61731

Please sign in to comment.