
Commit

trust remote code on tokenizer
l4b4r4b4b4 committed Aug 21, 2024
1 parent bf5e8d5 commit 165de29
Showing 1 changed file with 73 additions and 2 deletions.
75 changes: 73 additions & 2 deletions server_vllm.py
@@ -21,10 +21,12 @@
 import re
 from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union
 import logging
+from pydantic import BaseModel
+import requests
 
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import Request, HTTPException, Header
 from fastapi.middleware.cors import CORSMiddleware
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.protocol import ModelCard, ModelList, ModelPermission
@@ -46,6 +48,73 @@
 app = fastapi.FastAPI()
 
 
+@app.get("/healthz")
+async def get_health_and_readiness():
+    """
+    Indicate service health and readiness.
+
+    Returns:
+        dict: A dictionary containing the vLLM inference service's readiness and health status.
+            - "ready" (bool): Indicates if the service is ready to accept requests.
+            - "health" (bool): Indicates if the service is healthy and operational.
+    """
+    return {"ready": True, "health": True}
+
+
+class EmbeddingData(BaseModel):
+    object: str = "embedding"
+    index: int
+    embedding: List[float]
+
+
+class UsageData(BaseModel):
+    prompt_tokens: int
+    total_tokens: int
+
+
+class EmbeddingResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingData]
+    model: str
+    usage: UsageData
+
+
+class EmbeddingRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingResponse)
+async def get_embedding(request: EmbeddingRequest, authorization: str = Header(None)):
+    if not authorization or not authorization.startswith("Bearer "):
+        raise HTTPException(status_code=401, detail="Unauthorized")
+
+    # Handle both single string and list of strings
+    inputs = [request.input] if isinstance(request.input, str) else request.input
+
+    url = "http://embeddings:8080/embed"
+    headers = {"Content-Type": "application/json"}
+    data = {"inputs": inputs}
+
+    response = requests.post(url, headers=headers, json=data)
+    embeddings = response.json()
+
+    # Construct the response data
+    data = [
+        EmbeddingData(object="embedding", index=i, embedding=embeddings[i])
+        for i in range(len(inputs))
+    ]
+
+    response = EmbeddingResponse(
+        object="list",
+        data=data,
+        model=request.model,
+        usage=UsageData(prompt_tokens=len(inputs) * 5, total_tokens=len(inputs) * 5),
+    )
+
+    return response
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
@@ -137,7 +206,9 @@ async def create_chat_completion(raw_request: Request):
     engine_args = AsyncEngineArgs.from_cli_args(args)
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(
-        engine_args.tokenizer, tokenizer_mode=engine_args.tokenizer_mode
+        engine_args.tokenizer,
+        tokenizer_mode=engine_args.tokenizer_mode,
+        trust_remote_code=engine_args.trust_remote_code,
     )
     # Overwrite vLLM's default ModelConfig.max_logprobs of 5
     engine_args.max_logprobs = len(tokenizer.vocab.keys())
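
The new /healthz route returns a static readiness and health payload, which makes the service easy to wire into container liveness and readiness probes. A quick check, sketched under the assumption that the server listens on port 8000 (vLLM's usual default):

import requests

# Probe the health endpoint added in this commit.
print(requests.get("http://localhost:8000/healthz").json())
# expected: {'ready': True, 'health': True}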
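
The /v1/embeddings route accepts either a single string or a list of strings and forwards them to an internal embedding service at http://embeddings:8080/embed (the host name suggests a companion container, e.g. from a docker-compose setup); note that the usage figures are a rough five-tokens-per-input placeholder rather than a real tokenizer count. A minimal client sketch, assuming the server runs on localhost:8000; the bearer token and model name below are placeholders, since the handler only checks for the "Bearer " prefix and echoes the model field back:

import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    # Only the "Bearer " prefix is checked; the token value is arbitrary here.
    headers={"Authorization": "Bearer dummy-token"},
    json={"input": ["hello world", "how are you?"], "model": "my-embedding-model"},
)
resp.raise_for_status()
# Each item carries its index and the embedding vector returned by the backend.
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))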

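
The tokenizer change itself simply threads vLLM's existing trust_remote_code engine argument through to get_tokenizer, so models whose tokenizers ship custom code can be loaded. A hedged launch sketch (the model name is a placeholder; --model and --trust-remote-code are assumed to come from vLLM's standard engine CLI arguments, which the script builds on via AsyncEngineArgs.from_cli_args):

python server_vllm.py --model some-org/model-with-custom-tokenizer --trust-remote-code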