google-genai[patch]: added parsing of function call / response (langc…
lkuligin authored Feb 8, 2024
1 parent a210a8b commit 1862900
Showing 1 changed file with 52 additions and 107 deletions.
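In short, this commit wires LangChain's OpenAI-style function-calling messages into Gemini: an AIMessage carrying additional_kwargs["function_call"] is now sent as a glm.Part holding a glm.FunctionCall, a FunctionMessage is sent back as a glm.Part holding a glm.FunctionResponse, and function calls returned by the model are parsed into AIMessage.additional_kwargs. As a rough, self-contained sketch (not part of the commit) mirroring the two new _parse_chat_history branches; the message_to_parts helper, the get_weather tool, and its arguments are invented for illustration:

import json
from typing import List

import google.ai.generativelanguage as glm
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage


def message_to_parts(message: BaseMessage) -> List[glm.Part]:
    # AIMessage branch: an OpenAI-style function_call dict (arguments as a
    # JSON string) becomes a single Part carrying a glm.FunctionCall.
    if isinstance(message, AIMessage):
        raw_function_call = message.additional_kwargs.get("function_call")
        if raw_function_call:
            function_call = glm.FunctionCall(
                {
                    "name": raw_function_call["name"],
                    "args": json.loads(raw_function_call["arguments"]),
                }
            )
            return [glm.Part(function_call=function_call)]
        raise ValueError("plain-text AIMessages are not handled in this sketch")
    # FunctionMessage branch: the tool's output is wrapped in a
    # glm.FunctionResponse Part, passing message.content through verbatim,
    # exactly as the commit's code does.
    if isinstance(message, FunctionMessage):
        return [
            glm.Part(
                function_response=glm.FunctionResponse(
                    name=message.name,
                    response=message.content,
                )
            )
        ]
    raise ValueError(f"Unsupported message type: {type(message)}")


ai_message = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {
            "name": "get_weather",  # invented example tool
            "arguments": json.dumps({"city": "Berlin"}),
        }
    },
)
print(message_to_parts(ai_message))  # [Part(function_call=...)]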
libs/partners/google-genai/langchain_google_genai/chat_models.py (159 changes: 52 additions & 107 deletions)
@@ -16,32 +16,28 @@
     Optional,
     Sequence,
     Tuple,
-    Type,
     Union,
     cast,
 )
 from urllib.parse import urlparse

+import google.ai.generativelanguage as glm
 import google.api_core

 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
-from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
-    AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
     FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -327,15 +323,30 @@ def _parse_chat_history(
             continue
         elif isinstance(message, AIMessage):
             role = "model"
-            # TODO: Handle AImessage with function call
-            parts = _convert_to_parts(message.content)
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
-            # TODO: Handle FunctionMessage
-            parts = _convert_to_parts(message.content)
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=message.content,
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
@@ -353,100 +364,44 @@ def _response_to_result(
     return messages


-def _retrieve_function_call_response(
-    parts: List[genai.types.PartType],
-) -> Optional[Dict]:
-    for idx, part in enumerate(parts):
-        if part.function_call and part.function_call.name:
-            fc: FunctionCall = part.function_call
-            return {
-                "function_call": {
-                    "name": fc.name,
-                    "arguments": json.dumps(
-                        dict(fc.args.items())
-                    ),  # dump to match other function calling llms for now
-                }
-            }
-    return None
-
-
-def _parts_to_content(
-    parts: List[genai.types.PartType],
-) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
-    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-    function_call_resp = _retrieve_function_call_response(parts)
-
-    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-        # Simple text response. The typical response
-        return parts[0].text, function_call_resp
-    elif not parts:
-        logger.warning("Gemini produced an empty response.")
-        return "", function_call_resp
-    messages: List[Union[Dict[Any, Any], str]] = []
-    for part in parts:
-        if part.text is not None:
-            messages.append(
-                {
-                    "type": "text",
-                    "text": part.text,
-                }
-            )
-        else:
-            # TODO: Handle inline_data if that's a thing?
-            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages, function_call_resp
+def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+    else:
+        parts = response_candidate.content.parts
+
+        if len(parts) == 1 and parts[0].text:
+            content: Union[str, List[Union[str, Dict]]] = parts[0].text
+        else:
+            content = [proto.Message.to_dict(part) for part in parts]
+    return AIMessage(content=content, additional_kwargs={})


 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

     generations: List[ChatGeneration] = []

-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
-
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content, additional_kwargs = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(
-                content=parts_content,
-                role=content.role,
-                additional_kwargs=additional_kwargs or {},
-            )
-        else:
-            msg = role_map[content.role](
-                content=parts_content,
-                additional_kwargs=additional_kwargs or {},
-            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        if candidate.safety_ratings:
-            generation_info["safety_ratings"] = [
-                type(rating).to_dict(rating) for rating in candidate.safety_ratings
-            ]
-        generations.append(generation_t(message=msg, generation_info=generation_info))
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            ChatGeneration(
+                message=_parse_response_candidate(candidate),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
@@ -455,7 +410,9 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+        generations = [
+            ChatGeneration(message=AIMessage(content=""), generation_info={})
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)


@@ -616,13 +573,7 @@ def _stream(
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -646,13 +597,7 @@ async def _astream(
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)

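On the response side, the new _parse_response_candidate surfaces a Gemini function_call Part as an OpenAI-style additional_kwargs["function_call"] dict, re-encoding args as a JSON string for parity with other function-calling LLMs. A hedged, self-contained sketch of that behavior (the parse_candidate helper and the get_weather payload are invented for illustration):

import json

import google.ai.generativelanguage as glm
import proto
from langchain_core.messages import AIMessage


def parse_candidate(candidate: glm.Candidate) -> AIMessage:
    # Mirrors _parse_response_candidate: a function_call Part becomes
    # AIMessage.additional_kwargs["function_call"], with "args" popped and
    # re-encoded as a JSON "arguments" string.
    first_part = candidate.content.parts[0]
    if first_part.function_call:
        function_call = proto.Message.to_dict(first_part.function_call)
        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
        return AIMessage(
            content="", additional_kwargs={"function_call": function_call}
        )
    return AIMessage(content=first_part.text)


candidate = glm.Candidate(
    content=glm.Content(
        role="model",
        parts=[
            glm.Part(
                function_call=glm.FunctionCall(
                    name="get_weather", args={"city": "Berlin"}
                )
            )
        ],
    )
)
print(parse_candidate(candidate).additional_kwargs)
# expected shape: {'function_call': {'name': 'get_weather',
#                                    'arguments': '{"city": "Berlin"}'}}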