Implement /v1/chat/completions endpoint for CPU mode
Signed-off-by: Johannes Plötner <[email protected]>
Johannes Plötner authored and manyoso committed Mar 11, 2024
1 parent 61d6765 commit 026ee4e
Showing 2 changed files with 52 additions and 8 deletions.
38 changes: 33 additions & 5 deletions gpt4all-api/gpt4all_api/app/api_v1/routes/chat.py
@@ -2,7 +2,8 @@
 import time
 from typing import List
 from uuid import uuid4
-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
+from gpt4all import GPT4All
 from pydantic import BaseModel, Field
 from api_v1.settings import settings
 from fastapi.responses import StreamingResponse
@@ -18,6 +19,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str = Field(settings.model, description='The model to generate a completion from.')
     messages: List[ChatCompletionMessage] = Field(..., description='Messages for the chat completion.')
+    temperature: float = Field(settings.temp, description='Model temperature')

 class ChatCompletionChoice(BaseModel):
     message: ChatCompletionMessage
@@ -45,15 +47,41 @@ async def chat_completion(request: ChatCompletionRequest):
     '''
     Completes a GPT4All model response based on the last message in the chat.
     '''
-    # Example: Echo the last message content with some modification
+    # GPU inference is not implemented yet
+    if settings.inference_mode == "gpu":
+        raise HTTPException(status_code=400,
+                            detail="Not implemented yet: can only infer in CPU mode.")
+
+    # we only support the configured model
+    if request.model != settings.model:
+        raise HTTPException(status_code=400,
+                            detail=f"The GPT4All inference server is booted to only infer: `{settings.model}`")
+
+    # run only if we have a message
     if request.messages:
-        last_message = request.messages[-1].content
-        response_content = f"Echo: {last_message}"
+        model = GPT4All(model_name=settings.model, model_path=settings.gpt4all_path)
+
+        # format the system message and conversation history correctly
+        formatted_messages = ""
+        for message in request.messages:
+            formatted_messages += f"<|im_start|>{message.role}\n{message.content}<|im_end|>\n"
+
+        # the LLM will complete the response of the assistant
+        formatted_messages += "<|im_start|>assistant\n"
+        response = model.generate(
+            prompt=formatted_messages,
+            temp=request.temperature
+        )
+
+        # the LLM may continue to hallucinate the conversation, but we only want its first
+        # response, so cut off everything after the first <|im_end|> (keep all if absent)
+        index = response.find("<|im_end|>")
+        response_content = response[:index].strip() if index != -1 else response.strip()
     else:
         response_content = "No messages received."

     # Create a chat message for the response
-    response_message = ChatCompletionMessage(role="system", content=response_content)
+    response_message = ChatCompletionMessage(role="assistant", content=response_content)

     # Create a choice object with the response message
     response_choice = ChatCompletionChoice(
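For illustration, here is a standalone sketch (not part of the commit) of the prompt string this handler builds and how the <|im_end|> cut-off behaves. The conversation and the raw model output below are invented examples:

# Sketch: how the handler flattens a conversation into a ChatML-style prompt.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Knock knock."},
]

formatted = ""
for m in messages:
    formatted += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
formatted += "<|im_start|>assistant\n"

# formatted is now:
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Knock knock.<|im_end|>
#   <|im_start|>assistant

# Truncation: keep only the first assistant turn from a hypothetical raw generation.
raw = "Who's there?<|im_end|>\n<|im_start|>user\nOrange."
index = raw.find("<|im_end|>")
reply = raw[:index].strip() if index != -1 else raw.strip()
assert reply == "Who's there?"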
22 changes: 19 additions & 3 deletions gpt4all-api/gpt4all_api/app/tests/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ def test_batched_completion():
     model = model_id  # replace with your specific model ID
     prompt = "Who is Michael Jordan?"
     responses = []
-
+    # Loop to create completions one at a time
     for _ in range(3):
         response = openai.Completion.create(
@@ -62,7 +62,7 @@ def test_batched_completion():
     # Assertions to check the responses
     for response in responses:
         assert len(response['choices'][0]['text']) > len(prompt)
-
+    assert len(responses) == 3

 def test_embedding():
@@ -74,4 +74,20 @@ def test_embedding():

     assert response["model"] == model
     assert isinstance(output, list)
-    assert all(isinstance(x, args) for x in output)
\ No newline at end of file
+    assert all(isinstance(x, args) for x in output)
+
+def test_chat_completion():
+    model = model_id
+
+    response = openai.ChatCompletion.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Knock knock."},
+            {"role": "assistant", "content": "Who's there?"},
+            {"role": "user", "content": "Orange."},
+        ]
+    )
+
+    assert response.choices[0].message.role == "assistant"
+    assert len(response.choices[0].message.content) > 0
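This test assumes the openai client has already been pointed at the local gpt4all-api server elsewhere in the test setup. A minimal sketch of that configuration follows; the base URL, port, and model name are illustrative assumptions, not values taken from this commit:

import openai

# Assumed local gpt4all-api endpoint; adjust host/port to your deployment.
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not-needed-for-local-inference"

response = openai.ChatCompletion.create(
    model="ggml-mpt-7b-chat",  # hypothetical; must match the server's settings.model
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)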
