Skip to content

Commit

Permalink
More precise testing OpenAILike (run-llama#9026)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan <[email protected]>
  • Loading branch information
jamesbraza and logan-markewich authored Nov 20, 2023
1 parent 1033d30 commit dee8f8e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features

- More precise testing of `OpenAILike` (#9026)
- Added callback manager to each retriever (#8871)

### Bug Fixes / Nits
Expand Down
58 changes: 40 additions & 18 deletions tests/llms/test_openai_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,34 @@
from unittest.mock import MagicMock, patch

from llama_index.llms import OpenAILike
from llama_index.llms.base import ChatMessage
from llama_index.llms.base import ChatMessage, MessageRole
from llama_index.llms.openai import Tokenizer
from openai.types import Completion, CompletionChoice
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage


class MockTokenizer:
def encode(self, text: str) -> List[str]:
return text.split(" ")
class StubTokenizer(Tokenizer):
def encode(self, text: str) -> List[int]:
return [sum(ord(letter) for letter in word) for word in text.split(" ")]


STUB_MODEL_NAME = "models/stub.gguf"
STUB_API_KEY = "stub_key"


def test_interfaces() -> None:
llm = OpenAILike(model="placeholder")
llm = OpenAILike(model=STUB_MODEL_NAME, api_key=STUB_API_KEY)
assert llm.class_name() == type(llm).__name__
assert llm.model == "placeholder"
assert llm.model == STUB_MODEL_NAME


def mock_chat_completion(text: str) -> ChatCompletion:
return ChatCompletion(
id="chatcmpl-abc123",
object="chat.completion",
created=1677858242,
model="gpt-3.5-turbo-0301",
model=STUB_MODEL_NAME,
usage={"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
choices=[
Choice(
Expand All @@ -38,10 +43,10 @@ def mock_chat_completion(text: str) -> ChatCompletion:

def mock_completion(text: str) -> Completion:
return Completion(
id="chatcmpl-abc123",
id="cmpl-abc123",
object="text_completion",
created=1677858242,
model="gpt-3.5-turbo-0301",
model=STUB_MODEL_NAME,
usage={"prompt_tokens": 13, "completion_tokens": 7, "total_tokens": 20},
choices=[
CompletionChoice(
Expand All @@ -61,14 +66,21 @@ def test_completion(MockSyncOpenAI: MagicMock) -> None:
mock_instance.completions.create.return_value = mock_completion(text)

llm = OpenAILike(
model="placeholder",
model=STUB_MODEL_NAME,
is_chat_model=False,
context_window=1024,
tokenizer=MockTokenizer(),
tokenizer=StubTokenizer(),
)

response = llm.complete("A long time ago in a galaxy far, far away")
assert response.text == text
mock_instance.completions.create.assert_called_once_with(
prompt="A long time ago in a galaxy far, far away",
stream=False,
model=STUB_MODEL_NAME,
temperature=0.1,
max_tokens=1014,
)


@patch("llama_index.llms.openai.SyncOpenAI")
Expand All @@ -79,22 +91,32 @@ def test_chat(MockSyncOpenAI: MagicMock) -> None:
mock_instance.chat.completions.create.return_value = mock_chat_completion(content)

llm = OpenAILike(
model="models/placeholder", is_chat_model=True, tokenizer=MockTokenizer()
model=STUB_MODEL_NAME, is_chat_model=True, tokenizer=StubTokenizer()
)

response = llm.chat([ChatMessage(role="user", content="test message")])
response = llm.chat([ChatMessage(role=MessageRole.USER, content="test message")])
assert response.message.content == content
mock_instance.chat.completions.create.assert_called_once_with(
messages=[{"role": MessageRole.USER, "content": "test message"}],
stream=False,
model=STUB_MODEL_NAME,
temperature=0.1,
)


def test_serialization() -> None:
llm = OpenAILike(
model="placeholder",
model=STUB_MODEL_NAME,
is_chat_model=True,
context_window=42,
tokenizer=MockTokenizer(),
max_tokens=42,
context_window=43,
tokenizer=StubTokenizer(),
)

serialized = llm.to_dict()

# Check OpenAI base class specifics
assert "api_key" not in serialized
assert serialized["max_tokens"] == 42
# Check OpenAILike subclass specifics
assert serialized["context_window"] == 43
assert serialized["is_chat_model"]
assert serialized["context_window"] == 42

0 comments on commit dee8f8e

Please sign in to comment.