This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

feat(llm): add alternative support of OpenAI API (#163)
* refactor(telemetry): update event name

* feat(llm): add full support of OpenAI API

* build(deps): add openai

* fix(docker): update docker var in prod

* feat(docker): add env vars

* test(llm): add tests for OpenAI support

* refactor(schemas): remove legacy schemas

* docs(llm): add mention of alternative LLM providers

* ci(push): update docker update

* test(llm): update error catching
frgfm authored May 15, 2024
1 parent cd1b080 commit 0f2d66b
Showing 17 changed files with 226 additions and 487 deletions.
4 changes: 4 additions & 0 deletions .env.example
@@ -21,6 +21,10 @@ LLM_PROVIDER='ollama'
OLLAMA_MODEL='dolphin-llama3:8b-v2.9-q4_K_M'
# Smaller option
# OLLAMA_MODEL='tinydolphin:1.1b-v2.8-q4_K_M'
GROQ_API_KEY=
GROQ_MODEL='llama3-8b-8192'
OPENAI_API_KEY=
OPENAI_MODEL='gpt-4o-2024-05-13'
LLM_TEMPERATURE=0
JWT_SECRET=
SENTRY_DSN=
6 changes: 3 additions & 3 deletions .github/workflows/push.yml
@@ -72,11 +72,11 @@ jobs:
docker rmi -f $(docker images -f "dangling=true" -q)
docker volume rm -f $(docker volume ls -f "dangling=true" -q)
# Update the service
docker compose pull backend gradio
docker compose stop backend gradio && docker compose up -d --wait
docker compose pull backend chat
docker compose stop backend chat && docker compose up -d --wait
# Check update
docker inspect -f '{{ .Created }}' $(docker compose images -q backend)
docker inspect -f '{{ .Created }}' $(docker compose images -q gradio)
docker inspect -f '{{ .Created }}' $(docker compose images -q chat)
# Clean up
docker rm -fv $(docker ps -aq)
docker rmi -f $(docker images -f "dangling=true" -q)
10 changes: 8 additions & 2 deletions CONTRIBUTING.md
@@ -77,10 +77,12 @@ If you are wondering how to do something with Companion API, or a more general q
- [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git)
- [Docker](https://docs.docker.com/engine/install/)
- [Docker compose](https://docs.docker.com/compose/)
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU (>= 6 Gb VRAM for good performance/latency balance)
- [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU (>= 6 Gb VRAM for good performance/latency balance)*
- [Poetry](https://python-poetry.org/docs/)
- [Make](https://www.gnu.org/software/make/) (optional)

_*If you don't have a GPU, you can use alternative LLM providers (currently supported: Groq, OpenAI)_


### Configure your fork

@@ -130,8 +132,12 @@ This file contains all the information to run the project.
- `SUPERADMIN_PWD`: the password of the initial admin user

#### Other optional values
- `SECRET_KEY`: if set, tokens can be reused between sessions. All instances sharing the same secret key can use the same token.
- `JWT_SECRET`: if set, tokens can be reused between sessions. All instances sharing the same secret key can use the same token.
- `OLLAMA_MODEL`: the model tag in [Ollama library](https://ollama.com/library) that will be used for the API.
- `GROQ_API_KEY`: your [Groq API KEY](https://console.groq.com/keys), required if you select `groq` as `LLM_PROVIDER`.
- `GROQ_MODEL`: the model tag in [Groq supported models](https://console.groq.com/docs/models) that will be used for the API.
- `OPENAI_API_KEY`: your [OpenAI API KEY](https://platform.openai.com/api-keys), required if you select `openai` as `LLM_PROVIDER`.
- `OPENAI_MODEL`: the model tag in [OpenAI supported models](https://platform.openai.com/docs/models) that will be used for the API.
- `SENTRY_DSN`: the DSN for your [Sentry](https://sentry.io/) project, which monitors back-end errors and report them back.
- `SERVER_NAME`: the server tag that will be used to report events to Sentry.
- `POSTHOG_HOST`: the host for PostHog [PostHog](https://eu.posthog.com/settings/project-details).
2 changes: 1 addition & 1 deletion docker-compose.prod.yml
@@ -72,7 +72,7 @@ services:
- POSTGRES_URL=postgresql+asyncpg://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
- SUPERADMIN_LOGIN=${SUPERADMIN_LOGIN}
- SUPERADMIN_PWD=${SUPERADMIN_PWD}
- SECRET_KEY=${SECRET_KEY}
- JWT_SECRET=${JWT_SECRET}
- OLLAMA_ENDPOINT=http://ollama:11434
- OLLAMA_MODEL=${OLLAMA_MODEL}
- OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60}
2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -56,6 +56,8 @@ services:
- OLLAMA_MODEL=${OLLAMA_MODEL}
- GROQ_API_KEY=${GROQ_API_KEY}
- GROQ_MODEL=${GROQ_MODEL}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- OPENAI_MODEL=${OPENAI_MODEL}
- OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-60}
- SUPPORT_EMAIL=${SUPPORT_EMAIL}
- DEBUG=true
2 changes: 1 addition & 1 deletion docs/developers/self-hosting.mdx
@@ -10,7 +10,7 @@ Whatever your installation method, you'll need at least the following to be inst
1. [Docker](https://docs.docker.com/engine/install/) (and [Docker compose](https://docs.docker.com/compose/) if you're using an old version)
2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and a GPU

_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance._
_We recommend min 5Gb of VRAM on your GPU for good performance/latency balance. Please note that by default, this will run your LLM locally (available offline) but if you don't have a GPU, you can use online LLM providers (currently supported: Groq, OpenAI)_

### 60 seconds setup ⏱️

36 changes: 35 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ posthog = "^3.0.0"
prometheus-fastapi-instrumentator = "^6.1.0"
groq = "^0.5.0"
ollama = "^0.1.9"
openai = "^1.29.0"
uvloop = "^0.19.0"
httptools = "^0.6.1"

@@ -40,6 +41,7 @@ optional = true
ruff = "==0.4.4"
mypy = "==1.10.0"
types-requests = ">=2.0.0"
types-urllib3 = ">=1.26.25"
types-passlib = ">=1.7.0"
pre-commit = "^3.6.0"

2 changes: 1 addition & 1 deletion src/app/api/api_v1/endpoints/code.py
@@ -29,7 +29,7 @@ async def chat(
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
token_payload: TokenPayload = Security(get_quack_jwt, scopes=[UserScope.ADMIN, UserScope.USER]),
) -> StreamingResponse:
telemetry_client.capture(token_payload.sub, event="compute-chat")
telemetry_client.capture(token_payload.sub, event="code-chat")
# Validate payload
if len(payload.messages) == 0:
raise HTTPException(
2 changes: 1 addition & 1 deletion src/app/core/config.py
@@ -53,7 +53,7 @@ def sqlachmey_uri(cls, v: str) -> str:
GROQ_API_KEY: Union[str, None] = os.environ.get("GROQ_API_KEY")
GROQ_MODEL: str = os.environ.get("GROQ_MODEL", "llama3-8b-8192")
OPENAI_API_KEY: Union[str, None] = os.environ.get("OPENAI_API_KEY")
OPENAI_MODEL: str = os.environ.get("OPENAI_MODEL", "gpt-4-turbo-2024-04-09")
OPENAI_MODEL: str = os.environ.get("OPENAI_MODEL", "gpt-4o-2024-05-13")

# Error monitoring
SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN")
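
The `.env.example`, compose and `config.py` changes in this commit share one pattern: each provider setting is read from the environment with a fallback default. A minimal standalone sketch of that pattern follows. `load_llm_settings`, `_optional` and the returned dict are hypothetical names used only for illustration (the repository uses a settings class), and the `"ollama"` and `0` fallbacks are assumptions taken from `.env.example`, not from `config.py`.

```python
import os
from typing import Dict, Mapping, Optional, Union

SettingValue = Union[str, float, None]


def _optional(env: Mapping[str, str], key: str) -> Optional[str]:
    """Return the variable, treating an empty string (e.g. `OPENAI_API_KEY=`) as unset."""
    return env.get(key) or None


def load_llm_settings(env: Mapping[str, str] = os.environ) -> Dict[str, SettingValue]:
    """Hypothetical helper: collect the LLM-related settings with their defaults."""
    return {
        # Assumed fallback, mirroring LLM_PROVIDER='ollama' in .env.example
        "LLM_PROVIDER": env.get("LLM_PROVIDER", "ollama"),
        "GROQ_API_KEY": _optional(env, "GROQ_API_KEY"),
        "GROQ_MODEL": env.get("GROQ_MODEL", "llama3-8b-8192"),
        "OPENAI_API_KEY": _optional(env, "OPENAI_API_KEY"),
        "OPENAI_MODEL": env.get("OPENAI_MODEL", "gpt-4o-2024-05-13"),
        # Assumed fallback, mirroring LLM_TEMPERATURE=0 in .env.example
        "LLM_TEMPERATURE": float(env.get("LLM_TEMPERATURE", "0")),
    }
```

An entry left blank in `.env` may reach the container as an empty string instead of being absent, in which case `os.environ.get` returns `""` and not `None`. Normalising it to `None`, or checking with `if not ...` as the provider selection in `llm.py` does, handles both cases.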
56 changes: 5 additions & 51 deletions src/app/schemas/services.py
@@ -4,65 +4,21 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from enum import Enum
from typing import Any, Dict, List, Union
from typing import List, Union

from pydantic import BaseModel, HttpUrl

__all__ = ["ChatCompletion"]


class OpenAIModel(str, Enum):
# https://platform.openai.com/docs/models/overview
GPT3_5_TURBO: str = "gpt-3.5-turbo-0125"
GPT3_5_TURBO_LEGACY: str = "gpt-3.5-turbo-1106"
GPT4_TURBO: str = "gpt-4-0125-preview"
GPT4_TURBO_LEGACY: str = "gpt-4-1106-preview"


class OpenAIChatRole(str, Enum):
class ChatRole(str, Enum):
SYSTEM: str = "system"
USER: str = "user"
ASSISTANT: str = "assistant"


class FieldSchema(BaseModel):
type: str
description: str


class ObjectSchema(BaseModel):
type: str = "object"
properties: Dict[str, Any]
required: List[str]


class ArraySchema(BaseModel):
type: str = "array"
items: ObjectSchema


class OpenAIFunction(BaseModel):
name: str
description: str
parameters: ObjectSchema


class OpenAITool(BaseModel):
type: str = "function"
function: OpenAIFunction


class _FunctionName(BaseModel):
name: str


class _OpenAIToolChoice(BaseModel):
type: str = "function"
function: _FunctionName


class OpenAIMessage(BaseModel):
role: OpenAIChatRole
class ChatMessage(BaseModel):
role: ChatRole
content: str


@@ -72,9 +28,7 @@ class _ResponseFormat(BaseModel):

class ChatCompletion(BaseModel):
model: str
messages: List[OpenAIMessage]
functions: List[OpenAIFunction]
function_call: Dict[str, str]
messages: List[ChatMessage]
temperature: float = 0.0
frequency_penalty: float = 1.0
response_format: _ResponseFormat = _ResponseFormat(type="json_object")
8 changes: 2 additions & 6 deletions src/app/services/llm/groq.py
@@ -11,6 +11,8 @@
from groq import Groq, Stream
from groq.lib.chat_completion_chunk import ChatCompletionChunk

from .utils import CHAT_PROMPT

logger = logging.getLogger("uvicorn.error")


@@ -20,12 +22,6 @@ class GroqModel(str, Enum):
MIXTRAL_8X7b: str = "mixtral-8x7b-32768"


CHAT_PROMPT = (
"You are an AI programming assistant, developed by the company Quack AI, and you only answer questions related to computer science "
"(refuse to answer for the rest)."
)


class GroqClient:
def __init__(
self,
32 changes: 30 additions & 2 deletions src/app/services/llm/llm.py
@@ -3,24 +3,48 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import re
from enum import Enum
from typing import Union
from typing import Dict, Union

from fastapi import HTTPException, status

from app.core.config import settings

from .groq import GroqClient
from .ollama import OllamaClient
from .openai import OpenAIClient

__all__ = ["llm_client"]

EXAMPLE_PROMPT = (
"You are responsible for producing concise illustrations of the company coding guidelines. "
"This will be used to teach new developers our way of engineering software. "
"Make sure your code is in the specified programming language and functional, don't add extra comments or explanations.\n"
# Format
"You should output two code blocks: "
"a minimal code snippet where the instruction was correctly followed, "
"and the same snippet with minimal modifications that invalidates the instruction."
)
# Strangely, this doesn't work when compiled
EXAMPLE_PATTERN = r"```[a-zA-Z]*\n(?P<positive>.*?)```\n.*```[a-zA-Z]*\n(?P<negative>.*?)```"


def validate_example_response(response: str) -> Dict[str, str]:
matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
if matches is None:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed output schema validation")

return matches.groupdict()
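
To illustrate what `validate_example_response` accepts, here is a standalone sketch that applies the same pattern to a hypothetical LLM answer. It is not the repository's code: `parse_example_response` is an illustrative name, `ValueError` stands in for the `HTTPException` so the snippet runs without FastAPI, and the pattern is assembled from a `FENCE` constant only so that the example can sit inside a fenced block. The last lines show one possible explanation for the "doesn't work when compiled" comment. This is an assumption, since the commit does not say what was tried: `re.DOTALL` has to be passed to `re.compile`, otherwise `.` stops at newlines.

```python
import re
from typing import Dict

FENCE = "`" * 3  # three backticks, built programmatically

# Same expression as EXAMPLE_PATTERN in the diff, assembled from FENCE
EXAMPLE_PATTERN = (
    FENCE + r"[a-zA-Z]*\n(?P<positive>.*?)" + FENCE + r"\n.*"
    + FENCE + r"[a-zA-Z]*\n(?P<negative>.*?)" + FENCE
)


def parse_example_response(response: str) -> Dict[str, str]:
    """Extract both code blocks; raises ValueError where the diff raises HTTPException."""
    matches = re.search(EXAMPLE_PATTERN, response.strip(), re.DOTALL)
    if matches is None:
        raise ValueError("Failed output schema validation")
    return matches.groupdict()


# Hypothetical LLM answer: a compliant snippet followed by a non-compliant one
sample = (
    f"{FENCE}python\ndef add(a: int, b: int) -> int:\n    return a + b\n{FENCE}\n"
    f"{FENCE}python\ndef add(a, b):\n    return a + b\n{FENCE}"
)
parsed = parse_example_response(sample)

# Compiling without flags drops DOTALL, so `.` cannot cross line breaks
no_flag_match = re.compile(EXAMPLE_PATTERN).search(sample)
with_flag_match = re.compile(EXAMPLE_PATTERN, re.DOTALL).search(sample)
```

Here `parsed["positive"]` holds the typed function and `parsed["negative"]` the untyped one. `no_flag_match` is `None`, while `with_flag_match` succeeds.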


class LLMProvider(str, Enum):
OLLAMA: str = "ollama"
OPENAI: str = "openai"
GROQ: str = "groq"


llm_client: Union[OllamaClient, GroqClient]
llm_client: Union[OllamaClient, GroqClient, OpenAIClient]
if settings.LLM_PROVIDER == LLMProvider.OLLAMA:
if not settings.OLLAMA_ENDPOINT:
raise ValueError("Please provide a value for `OLLAMA_ENDPOINT`")
@@ -29,5 +53,9 @@ class LLMProvider(str, Enum):
if not settings.GROQ_API_KEY:
raise ValueError("Please provide a value for `GROQ_API_KEY`")
llm_client = GroqClient(settings.GROQ_API_KEY, settings.GROQ_MODEL, settings.LLM_TEMPERATURE) # type: ignore[arg-type]
elif settings.LLM_PROVIDER == LLMProvider.OPENAI:
if not settings.OPENAI_API_KEY:
raise ValueError("Please provide a value for `OPENAI_API_KEY`")
llm_client = OpenAIClient(settings.OPENAI_API_KEY, settings.OPENAI_MODEL, settings.LLM_TEMPERATURE) # type: ignore[arg-type]
else:
raise NotImplementedError("LLM provider is not implemented")
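
The provider selection in `llm.py` can be summarised as: look up the configured provider, check that its required setting is present, then build the client. The sketch below isolates the first two steps so they can run without the Groq, Ollama or OpenAI SDKs. `REQUIRED_SETTING` and `resolve_provider` are hypothetical names that do not exist in the repository, and the settings are passed as a plain mapping instead of the project's settings object.

```python
from enum import Enum
from typing import Mapping, Optional


class LLMProvider(str, Enum):
    OLLAMA = "ollama"
    OPENAI = "openai"
    GROQ = "groq"


# Hypothetical lookup table: the setting each provider cannot run without
REQUIRED_SETTING = {
    LLMProvider.OLLAMA: "OLLAMA_ENDPOINT",
    LLMProvider.GROQ: "GROQ_API_KEY",
    LLMProvider.OPENAI: "OPENAI_API_KEY",
}


def resolve_provider(name: str, settings: Mapping[str, Optional[str]]) -> LLMProvider:
    """Validate the provider name and its required setting, mirroring the if/elif chain."""
    try:
        provider = LLMProvider(name)
    except ValueError:
        # Same error the diff raises in its final `else` branch
        raise NotImplementedError("LLM provider is not implemented") from None
    required = REQUIRED_SETTING[provider]
    if not settings.get(required):  # rejects both a missing key and an empty string
        raise ValueError(f"Please provide a value for `{required}`")
    return provider
```

For example, `resolve_provider("openai", {"OPENAI_API_KEY": "sk-test"})` returns `LLMProvider.OPENAI`, while an empty key raises `ValueError` and an unknown provider name raises `NotImplementedError`. Because `LLMProvider` subclasses `str`, comparing a plain string setting against an enum member, as the diff does, works as expected.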