Skip to content

Commit

Permalink
Extract OpenAI API retry handler and unify ADA embeddings calls. (Sig…
Browse files Browse the repository at this point in the history
…nificant-Gravitas#3191)

* Extract retry logic, unify embedding functions

* Add some docstrings

* Remove embedding creation from API manager

* Add test suite for retry handler

* Make api manager fixture

* Fix typing

* Streamline tests
  • Loading branch information
collijk authored Apr 25, 2023
1 parent 940b115 commit 2619740
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 93 deletions.
26 changes: 0 additions & 26 deletions autogpt/api_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,32 +65,6 @@ def create_chat_completion(
self.update_cost(prompt_tokens, completion_tokens, model)
return response

def embedding_create(
self,
text_list: List[str],
model: str = "text-embedding-ada-002",
) -> List[float]:
"""
Create an embedding for the given input text using the specified model.
Args:
text_list (List[str]): Input text for which the embedding is to be created.
model (str, optional): The model to use for generating the embedding.
Returns:
List[float]: The generated embedding as a list of float values.
"""
if cfg.use_azure:
response = openai.Embedding.create(
input=text_list,
engine=cfg.get_azure_deployment_id_for_model(model),
)
else:
response = openai.Embedding.create(input=text_list, model=model)

self.update_cost(response.usage.prompt_tokens, 0, model)
return response["data"][0]["embedding"]

def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Expand Down
119 changes: 93 additions & 26 deletions autogpt/llm_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import time
from typing import List, Optional

Expand All @@ -13,10 +14,62 @@
from autogpt.types.openai import Message

CFG = Config()

openai.api_key = CFG.openai_api_key


def retry_openai_api(
num_retries: int = 10,
backoff_base: float = 2.0,
warn_user: bool = True,
):
"""Retry an OpenAI API call.
Args:
num_retries int: Number of retries. Defaults to 10.
backoff_base float: Base for exponential backoff. Defaults to 2.
warn_user bool: Whether to warn the user. Defaults to True.
"""
retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
api_key_error_msg = (
f"Please double check that you have setup a "
f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
f"read more here: {Fore.CYAN}https://github.com/Significant-Gravitas/Auto-GPT#openai-api-keys-configuration{Fore.RESET}"
)
backoff_msg = (
f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
)

def _wrapper(func):
@functools.wraps(func)
def _wrapped(*args, **kwargs):
user_warned = not warn_user
num_attempts = num_retries + 1 # +1 for the first attempt
for attempt in range(1, num_attempts + 1):
try:
return func(*args, **kwargs)

except RateLimitError:
if attempt == num_attempts:
raise

logger.debug(retry_limit_msg)
if not user_warned:
logger.double_check(api_key_error_msg)
user_warned = True

except APIError as e:
if (e.http_status != 502) or (attempt == num_attempts):
raise

backoff = backoff_base ** (attempt + 2)
logger.debug(backoff_msg.format(backoff=backoff))
time.sleep(backoff)

return _wrapped

return _wrapper


def call_ai_function(
function: str, args: list, description: str, model: str | None = None
) -> str:
Expand Down Expand Up @@ -154,32 +207,46 @@ def create_chat_completion(
return resp


def get_ada_embedding(text):
def get_ada_embedding(text: str) -> List[int]:
"""Get an embedding from the ada model.
Args:
text (str): The text to embed.
Returns:
List[int]: The embedding.
"""
model = "text-embedding-ada-002"
text = text.replace("\n", " ")
return api_manager.embedding_create(
text_list=[text], model="text-embedding-ada-002"

if CFG.use_azure:
kwargs = {"engine": CFG.get_azure_deployment_id_for_model(model)}
else:
kwargs = {"model": model}

embedding = create_embedding(text, **kwargs)
api_manager.update_cost(
prompt_tokens=embedding.usage.prompt_tokens,
completion_tokens=0,
model=model,
)
return embedding["data"][0]["embedding"]


def create_embedding_with_ada(text) -> list:
"""Create an embedding with text-ada-002 using the OpenAI SDK"""
num_retries = 10
for attempt in range(num_retries):
backoff = 2 ** (attempt + 2)
try:
return api_manager.embedding_create(
text_list=[text], model="text-embedding-ada-002"
)
except RateLimitError:
pass
except (APIError, Timeout) as e:
if e.http_status != 502:
raise
if attempt == num_retries - 1:
raise
if CFG.debug_mode:
print(
f"{Fore.RED}Error: ",
f"API Bad gateway. Waiting {backoff} seconds...{Fore.RESET}",
)
time.sleep(backoff)
@retry_openai_api()
def create_embedding(
text: str,
*_,
**kwargs,
) -> openai.Embedding:
"""Create an embedding using the OpenAI API
Args:
text (str): The text to embed.
kwargs: Other arguments to pass to the OpenAI API embedding creation call.
Returns:
openai.Embedding: The embedding object.
"""

return openai.Embedding.create(input=[text], **kwargs)
6 changes: 3 additions & 3 deletions autogpt/memory/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import orjson

from autogpt.llm_utils import create_embedding_with_ada
from autogpt.llm_utils import get_ada_embedding
from autogpt.memory.base import MemoryProviderSingleton

EMBED_DIM = 1536
Expand Down Expand Up @@ -63,7 +63,7 @@ def add(self, text: str):
return ""
self.data.texts.append(text)

embedding = create_embedding_with_ada(text)
embedding = get_ada_embedding(text)

vector = np.array(embedding).astype(np.float32)
vector = vector[np.newaxis, :]
Expand Down Expand Up @@ -111,7 +111,7 @@ def get_relevant(self, text: str, k: int) -> list[Any]:
Returns: List[str]
"""
embedding = create_embedding_with_ada(text)
embedding = get_ada_embedding(text)

scores = np.dot(self.data.embeddings, embedding)

Expand Down
6 changes: 3 additions & 3 deletions autogpt/memory/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pinecone
from colorama import Fore, Style

from autogpt.llm_utils import create_embedding_with_ada
from autogpt.llm_utils import get_ada_embedding
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton

Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(self, cfg):
self.index = pinecone.Index(table_name)

def add(self, data):
vector = create_embedding_with_ada(data)
vector = get_ada_embedding(data)
# no metadata here. We may wish to change that long term.
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
Expand All @@ -64,7 +64,7 @@ def get_relevant(self, data, num_relevant=5):
:param data: The data to compare to.
:param num_relevant: The number of relevant data to return. Defaults to 5
"""
query_embedding = create_embedding_with_ada(data)
query_embedding = get_ada_embedding(data)
results = self.index.query(
query_embedding, top_k=num_relevant, include_metadata=True
)
Expand Down
6 changes: 3 additions & 3 deletions autogpt/memory/redismem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query

from autogpt.llm_utils import create_embedding_with_ada
from autogpt.llm_utils import get_ada_embedding
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton

Expand Down Expand Up @@ -88,7 +88,7 @@ def add(self, data: str) -> str:
"""
if "Command Error:" in data:
return ""
vector = create_embedding_with_ada(data)
vector = get_ada_embedding(data)
vector = np.array(vector).astype(np.float32).tobytes()
data_dict = {b"data": data, "embedding": vector}
pipe = self.redis.pipeline()
Expand Down Expand Up @@ -130,7 +130,7 @@ def get_relevant(self, data: str, num_relevant: int = 5) -> list[Any] | None:
Returns: A list of the most relevant data.
"""
query_embedding = create_embedding_with_ada(data)
query_embedding = get_ada_embedding(data)
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
query = (
Query(base_query)
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest
from dotenv import load_dotenv

from autogpt.api_manager import ApiManager
from autogpt.api_manager import api_manager as api_manager_
from autogpt.config import Config
from autogpt.workspace import Workspace

Expand All @@ -29,3 +31,11 @@ def config(workspace: Workspace) -> Config:
config.workspace_path = workspace.root
yield config
config.workspace_path = old_ws_path


@pytest.fixture()
def api_manager() -> ApiManager:
old_attrs = api_manager_.__dict__.copy()
api_manager_.reset()
yield api_manager_
api_manager_.__dict__.update(old_attrs)
31 changes: 0 additions & 31 deletions tests/test_api_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,6 @@ def test_create_chat_completion_valid_inputs():
assert api_manager.get_total_completion_tokens() == 20
assert api_manager.get_total_cost() == (10 * 0.002 + 20 * 0.002) / 1000

@staticmethod
def test_embedding_create_invalid_model():
"""Test if an invalid model for embedding raises a KeyError."""
text_list = ["Hello, how are you?"]
model = "invalid-model"

with patch("openai.Embedding.create") as mock_create:
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_create.side_effect = KeyError("Invalid model")
with pytest.raises(KeyError):
api_manager.embedding_create(text_list, model=model)

@staticmethod
def test_embedding_create_valid_inputs():
"""Test if valid inputs for embedding result in correct tokens and cost."""
text_list = ["Hello, how are you?"]
model = "text-embedding-ada-002"

with patch("openai.Embedding.create") as mock_create:
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 5
mock_response["data"] = [{"embedding": [0.1, 0.2, 0.3]}]
mock_create.return_value = mock_response

api_manager.embedding_create(text_list, model=model)

assert api_manager.get_total_prompt_tokens() == 5
assert api_manager.get_total_completion_tokens() == 0
assert api_manager.get_total_cost() == (5 * 0.0004) / 1000

def test_getter_methods(self):
"""Test the getter methods for total tokens, cost, and budget."""
api_manager.update_cost(60, 120, "gpt-3.5-turbo")
Expand Down
Loading

0 comments on commit 2619740

Please sign in to comment.