fix ChatOpenAI.agenerate (langchain-ai#1504)
agola11 authored Mar 7, 2023
1 parent 4f41e20 commit 27104d4
Showing 4 changed files with 68 additions and 30 deletions.
9 changes: 3 additions & 6 deletions langchain/chat_models/base.py
@@ -46,17 +46,14 @@ def generate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Top Level call"""
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        results = [self._generate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])
 
     async def agenerate(
         self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
-        results = []
-        for m in messages:
-            results.append(self._generate(m, stop=stop))
+        """Top Level call"""
+        results = [await self._agenerate(m, stop=stop) for m in messages]
         return LLMResult(generations=[res.generations for res in results])
 
     def generate_prompt(
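This hunk is the actual bug fix: `agenerate` previously looped over the batches calling the synchronous `self._generate`, so it never awaited anything and blocked the event loop; it now awaits `self._agenerate` for each batch. A minimal, self-contained sketch of the corrected pattern (the stub classes below are hypothetical stand-ins, not the LangChain types, and are not part of the commit):

import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class BatchResult:
    """Hypothetical stand-in for ChatResult: generations for one message batch."""
    generations: List[str]


class ChatModelSketch:
    """Hypothetical stand-in for BaseChatModel, illustrating the corrected pattern."""

    def _generate(self, messages: List[str]) -> BatchResult:
        # Synchronous per-batch call.
        return BatchResult(generations=[f"echo: {m}" for m in messages])

    async def _agenerate(self, messages: List[str]) -> BatchResult:
        # Asynchronous per-batch call; a real model would await an API client here.
        await asyncio.sleep(0)
        return BatchResult(generations=[f"echo: {m}" for m in messages])

    def generate(self, batches: List[List[str]]) -> List[List[str]]:
        results = [self._generate(b) for b in batches]
        return [res.generations for res in results]

    async def agenerate(self, batches: List[List[str]]) -> List[List[str]]:
        # The fix: await _agenerate here, instead of calling the
        # synchronous _generate as the old code did.
        results = [await self._agenerate(b) for b in batches]
        return [res.generations for res in results]


if __name__ == "__main__":
    print(asyncio.run(ChatModelSketch().agenerate([["Hello"], ["Hi"]])))

Note that `await` inside a list comprehension still processes the batches sequentially; something like `asyncio.gather` would run them concurrently, but the commit keeps the simpler sequential form.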
46 changes: 23 additions & 23 deletions langchain/chat_models/openai.py
@@ -3,7 +3,7 @@
 
 import logging
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
 
 from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
@@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
     return message_dict
 
 
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+    generations = []
+    for res in response["choices"]:
+        message = _convert_dict_to_message(res["message"])
+        gen = ChatGeneration(message=message)
+        generations.append(gen)
+    return ChatResult(generations=generations)
+
+
 class ChatOpenAI(BaseChatModel, BaseModel):
     """Wrapper around OpenAI Chat large language models.
@@ -215,12 +224,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:
     def _generate(
         self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
-        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        message_dicts = [_convert_message_to_dict(m) for m in messages]
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -240,22 +244,23 @@ def _generate(
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         response = self.completion_with_retry(messages=message_dicts, **params)
-        generations = []
-        for res in response["choices"]:
-            message = _convert_dict_to_message(res["message"])
-            gen = ChatGeneration(message=message)
-            generations.append(gen)
-        return ChatResult(generations=generations)
+        return _create_chat_result(response)
 
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         message_dicts = [_convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -281,15 +286,10 @@ async def _agenerate(
             )
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
-            full_response = await acompletion_with_retry(
+            response = await acompletion_with_retry(
                 self, messages=message_dicts, **params
             )
-            generations = []
-            for res in full_response["choices"]:
-                message = _convert_dict_to_message(res["message"])
-                gen = ChatGeneration(message=message)
-                generations.append(gen)
-            return ChatResult(generations=generations)
+            return _create_chat_result(response)
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
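The remaining openai.py changes are deduplication rather than behavior changes: request-parameter assembly moves into `_create_message_dicts`, and the `choices`-parsing loop that appeared verbatim in both `_generate` and `_agenerate` becomes the module-level `_create_chat_result`, with `_agenerate`'s `full_response` renamed to `response` to match. A small usage sketch of the shared helper, assuming this commit's langchain is importable; the `fake_response` dict is hypothetical, shaped like an OpenAI chat-completions payload:

from langchain.chat_models.openai import _create_chat_result

# Hypothetical response mimicking the shape of an OpenAI chat completions payload.
fake_response = {
    "choices": [
        {"message": {"role": "assistant", "content": "Hello there!"}},
        {"message": {"role": "assistant", "content": "Hi there!"}},
    ]
}

result = _create_chat_result(fake_response)
for gen in result.generations:
    print(type(gen).__name__, gen.message.content)
# Expected output:
# ChatGeneration Hello there!
# ChatGeneration Hi there!

Making the helper a module-level function rather than a method works because it touches no instance state; it only maps the response dict onto ChatGeneration objects.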
41 changes: 41 additions & 0 deletions tests/integration_tests/chat_models/test_openai.py
@@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
             temperature=0,
             n=5,
         )
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai() -> None:
+    """Test async generation."""
+    chat = ChatOpenAI(max_tokens=10, n=2)
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 2
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
+@pytest.mark.asyncio
+async def test_async_chat_openai_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatOpenAI(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
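
Both new tests hit the live OpenAI API and use the `@pytest.mark.asyncio` marker from the pytest-asyncio plugin, so running them requires an `OPENAI_API_KEY` and that plugin installed. A rough standalone equivalent of the first test (not part of the commit), assuming the import paths at this commit:

import asyncio

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def main() -> None:
    chat = ChatOpenAI(max_tokens=10, n=2)
    message = HumanMessage(content="Hello")
    # Two batches of one message each -> two lists of generations.
    result = await chat.agenerate([[message], [message]])
    for generations in result.generations:
        for generation in generations:
            print(generation.text)


asyncio.run(main())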
2 changes: 1 addition & 1 deletion tests/integration_tests/llms/test_openai.py
@@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:
 
 
 def test_saving_loading_llm(tmp_path: Path) -> None:
-    """Test saving/loading an OpenAPI LLM."""
+    """Test saving/loading an OpenAI LLM."""
     llm = OpenAI(max_tokens=10)
     llm.save(file_path=tmp_path / "openai.yaml")
     loaded_llm = load_llm(tmp_path / "openai.yaml")
