
Commit

[serve, vllm] add a vLLM example that we can reference (ray-project#36617)

This adds a vLLM example for Ray Serve that we can refer to.
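
For reference, once the example below is running (see the __main__ block in vllm_example.py), the deployment can be queried over HTTP. The snippet below is only a sketch and not part of the committed files; it assumes the example is serving on the default Ray Serve HTTP address, http://localhost:8000/, and follows the request format documented in VLLMPredictDeployment.__call__:

    import requests

    # Non-streaming request: the response body is a JSON object {"text": [...]}.
    resp = requests.post(
        "http://localhost:8000/",
        json={"prompt": "How do I cook fried rice?", "stream": False},
    )
    print(resp.json()["text"])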
---------

Signed-off-by: Chen Shen <[email protected]>
Co-authored-by: shrekris-anyscale <[email protected]>
scv119 and shrekris-anyscale authored Jun 21, 2023
1 parent ee64dbc commit cc983fc
Showing 4 changed files with 142 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.gpu_large.yml
@@ -50,7 +50,7 @@
["NO_WHEELS_REQUIRED", "RAY_CI_PYTHON_AFFECTED", "RAY_CI_TUNE_AFFECTED", "RAY_CI_DOC_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DOC_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/env/install-dependencies.sh
- DOC_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 INSTALL_VLLM=1 ./ci/env/install-dependencies.sh
- pip install -Ur ./python/requirements/ml/requirements_ml_docker.txt
- ./ci/env/env_info.sh
# Test examples with newer version of `transformers`
5 changes: 5 additions & 0 deletions ci/env/install-dependencies.sh
@@ -394,6 +394,11 @@ install_pip_packages() {
requirements_packages+=("holidays==0.24") # holidays 0.25 causes `import prophet` to fail.
fi

# Additional dependency for vllm.
if [ "${INSTALL_VLLM-}" = 1 ]; then
requirements_packages+=("vllm")
fi

# Data processing test dependencies.
if [ "${DATA_PROCESSING_TESTING-}" = 1 ] || [ "${DOC_TESTING-}" = 1 ]; then
requirements_files+=("${WORKSPACE_DIR}/python/requirements/data_processing/requirements.txt")
10 changes: 10 additions & 0 deletions doc/BUILD
@@ -200,6 +200,7 @@ py_test_run_all_subdirectory(
"source/ray-air/doc_code/computer_vision.py",
"source/ray-air/doc_code/hf_trainer.py", # Too large
"source/ray-air/doc_code/predictors.py",
"source/serve/doc_code/vllm_example.py", # Requires GPU
],
extra_srcs = [],
tags = ["exclusive", "team:ml"],
@@ -279,6 +280,15 @@ py_test_run_all_subdirectory(
tags = ["exclusive", "team:ml", "ray_air", "gpu"],
)

py_test(
name = "vllm_example",
size = "large",
include = ["source/serve/doc_code/vllm_example.py"],
exclude = [],
extra_srcs = [],
tags = ["exclusive", "team:serve", "gpu"],
)

py_test(
name = "pytorch_resnet_finetune",
size = "large",
126 changes: 126 additions & 0 deletions doc/source/serve/doc_code/vllm_example.py
@@ -0,0 +1,126 @@
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from ray import serve


@serve.deployment(ray_actor_options={"num_gpus": 1})
class VLLMPredictDeployment:
    def __init__(self, **kwargs):
        """Construct a vLLM deployment.

        Refer to https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py
        for the full list of arguments.

        Args:
            model: name or path of the huggingface model to use.
            download_dir: directory to download and load the weights;
                defaults to the default cache dir of huggingface.
            use_np_weights: save a numpy copy of model weights for
                faster loading. This can increase the disk usage by up to 2x.
            use_dummy_weights: use dummy values for model weights.
            dtype: data type for model weights and activations.
                The "auto" option will use FP16 precision
                for FP32 and FP16 models, and BF16 precision
                for BF16 models.
            seed: random seed.
            worker_use_ray: use Ray for distributed serving; will be
                automatically set when using more than 1 GPU.
            pipeline_parallel_size: number of pipeline stages.
            tensor_parallel_size: number of tensor parallel replicas.
            block_size: token block size.
            swap_space: CPU swap space size (GiB) per GPU.
            gpu_memory_utilization: the percentage of GPU memory to be used
                for the model executor.
            max_num_batched_tokens: maximum number of batched tokens per
                iteration.
            max_num_seqs: maximum number of sequences per iteration.
            disable_log_stats: disable logging statistics.
            engine_use_ray: use Ray to start the LLM engine in a separate
                process from the server process.
            disable_log_requests: disable logging requests.
        """
        args = AsyncEngineArgs(**kwargs)
        self.engine = AsyncLLMEngine.from_engine_args(args)

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            # Only yield the newly generated portion of the text.
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        """Generate a completion for the request.

        The request should be a JSON object with the following fields:
        - prompt: the prompt to use for the generation.
        - stream: whether to stream the results or not.
        - other fields: the sampling parameters (see `SamplingParams` for details).
        """
        request_dict = await request.json()
        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)
        sampling_params = SamplingParams(**request_dict)
        request_id = random_uuid()
        results_generator = self.engine.generate(prompt, sampling_params, request_id)
        if stream:
            background_tasks = BackgroundTasks()
            # Use background_tasks to abort the request
            # if the client disconnects.
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator), background=background_tasks
            )

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        prompt = final_output.prompt
        text_outputs = [prompt + output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        return Response(content=json.dumps(ret))


def send_sample_request():
    import requests

    prompt = "How do I cook fried rice?"
    sample_input = {"prompt": prompt, "stream": True}
    output = requests.post("http://localhost:8000/", json=sample_input)
    for line in output.iter_lines():
        print(line.decode("utf-8"))


if __name__ == "__main__":
    # To run this example, you need to install vllm, which requires:
    #   OS: Linux
    #   Python: 3.8 or higher
    #   CUDA: 11.0 – 11.8
    #   GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
    # See https://vllm.readthedocs.io/en/latest/getting_started/installation.html
    # for more details.
    deployment = VLLMPredictDeployment.bind(model="facebook/opt-125m")
    serve.run(deployment)
    send_sample_request()
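
Because VLLMPredictDeployment.__init__ forwards its keyword arguments directly to AsyncEngineArgs, the other engine options listed in its docstring can also be supplied at bind time. A minimal sketch of such a variation (assuming the option names match the AsyncEngineArgs fields of the installed vLLM version):

    # Hypothetical variation of the __main__ block above.
    deployment = VLLMPredictDeployment.bind(
        model="facebook/opt-125m",
        dtype="auto",  # data type for weights and activations (see docstring above)
        gpu_memory_utilization=0.9,  # share of GPU memory given to the model executor
    )
    serve.run(deployment)
    send_sample_request()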
