forked from langchain-ai/langgraph
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
251 additions
and
0 deletions.
There are no files selected for viewing
52 changes: 52 additions & 0 deletions
52
examples/chatbot-simulation-evaluation/_testing/simulation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import List | ||
|
||
import openai | ||
from langchain_core.messages import HumanMessage | ||
from langchain_openai import ChatOpenAI | ||
from simulation_utils import ( | ||
create_chat_simulator, | ||
create_simulated_user, | ||
langchain_to_openai_messages, | ||
) | ||
|
||
openai_client = openai.Client() | ||
|
||
|
||
def my_chat_bot(messages: list) -> str:
    """Answer as an airline customer-support agent.

    Converts the incoming LangChain messages to OpenAI dict format, prepends
    the support-agent system prompt, and returns the model's text reply.
    """
    converted = langchain_to_openai_messages(messages)
    system_message = {
        "role": "system",
        "content": "You are a customer support agent for an airline.",
    }
    completion = openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[system_message] + converted,
    )
    return completion.choices[0].message.content
|
||
|
||
# Smoke-test the assistant with a single OpenAI-format message before wiring
# it into the simulator.
my_chat_bot([{"role": "user", "content": "hi!"}])


# Persona prompt for the simulated user; {instructions} is filled per-run via
# the "inputs" passed to simulator.invoke below.
system_prompt_template = """You are a customer of an airline company. \
You are interacting with a user who is a customer support person. \
{instructions}
Your task is to get a big discount on your next flight. \
When you are finished with the conversation, respond with a single word 'FINISHED'"""

simulated_user = create_simulated_user(
    system_prompt_template, llm=ChatOpenAI(model="gpt-3.5-turbo")
)

# My chat bot accepts a list of LangChain messages.
# Simulated user accepts a list of LangChain messages.
# TODO: Pass additional arguments to the simulated user
simulator = create_chat_simulator(my_chat_bot, simulated_user)
# Fixed typo in the instruction string: "cusss" -> "cuss".
simulator.invoke(
    {
        "instructions": "You are extremely disgruntled and will cuss and swear to get your way. Try to get a discount by any means necessary."
    }
)
199 changes: 199 additions & 0 deletions
199
examples/chatbot-simulation-evaluation/_testing/simulation_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
import functools | ||
from typing import Annotated, Any, Callable, Dict, List, Optional, Union | ||
|
||
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from langchain_core.runnables import Runnable, RunnableLambda | ||
from langchain_core.runnables import chain as as_runnable | ||
from langchain_openai import ChatOpenAI | ||
from typing_extensions import TypedDict | ||
|
||
from langgraph.graph import END, StateGraph | ||
|
||
|
||
def langchain_to_openai_messages(messages: List[BaseMessage]):
    """Translate LangChain messages into OpenAI chat-format dicts.

    Parameters:
        messages (List[BaseMessage]): Messages to translate; items that are
            already plain dicts are passed through untouched.

    Returns:
        List[dict]: The messages in OpenAI dict format.
    """
    from langchain.adapters.openai import convert_message_to_dict  # noqa: I001

    converted = []
    for message in messages:
        if isinstance(message, BaseMessage):
            converted.append(convert_message_to_dict(message))
        else:
            converted.append(message)
    return converted
|
||
|
||
def create_simulated_user(
    system_prompt: str, llm: Runnable | None = None
) -> Runnable[Dict, AIMessage]:
    """Build the simulated-user runnable for a chatbot simulation.

    Args:
        system_prompt (str): System prompt defining the simulated user's persona.
        llm (Runnable | None, optional): Language model driving the simulated user.
            Defaults to gpt-3.5-turbo.

    Returns:
        Runnable[Dict, AIMessage]: Prompt-plus-model pipeline acting as the user.
    """
    model = llm or ChatOpenAI(model="gpt-3.5-turbo")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    return (prompt | model).with_config(run_name="simulated_user")
|
||
|
||
# A single message or a list of them — the shapes nodes may return.
Messages = Union[list[AnyMessage], AnyMessage]


def add_messages(left: Messages, right: Messages) -> Messages:
    """State reducer: append *right* after *left*, coercing bare messages to lists."""
    left_list = left if isinstance(left, list) else [left]
    right_list = right if isinstance(right, list) else [right]
    return left_list + right_list
|
||
|
||
class SimulationState(TypedDict):
    """
    Represents the state of a simulation.

    Attributes:
        messages (List[AnyMessage]): A list of messages in the simulation.
            The ``add_messages`` annotation is the reducer, so messages
            returned by nodes are appended rather than overwritten.
        inputs (Optional[dict[str, Any]]): Optional inputs for the simulation,
            forwarded to the simulated user on each turn.
    """

    # Conversation so far; add_messages merges node outputs into this list.
    messages: Annotated[List[AnyMessage], add_messages]
    # Extra template variables (e.g. "instructions") for the simulated user.
    inputs: Optional[dict[str, Any]]
|
||
|
||
def create_chat_simulator(
    assistant: (
        Callable[[List[AnyMessage]], str | AIMessage]
        | Runnable[List[AnyMessage], str | AIMessage]
    ),
    simulated_user: Runnable[Dict, AIMessage],
    *,
    input_key: Optional[str] = None,
    max_turns: int = 6,
    should_continue: Optional[Callable[[SimulationState], str]] = None,
):
    """Creates a chat simulator for evaluating a chatbot.

    Args:
        assistant: The chatbot assistant function or runnable object.
        simulated_user: The simulated user object.
        input_key: Key in the dataset example holding the leading user input, if any.
        max_turns: The maximum number of turns in the chat simulation. Default is 6.
        should_continue: Optional function to determine if the simulation should continue.
            If not provided, a default function will be used.

    Returns:
        The compiled chat simulation graph.
    """
    builder = StateGraph(SimulationState)
    builder.add_node(
        "assistant", _fetch_messages | assistant | _coerce_to_message
    )
    builder.add_node(
        "user",
        _create_simulated_user_node(simulated_user),
    )
    builder.add_edge("assistant", "user")
    continue_fn = should_continue or functools.partial(
        _should_continue, max_turns=max_turns
    )
    builder.add_conditional_edges("user", continue_fn)
    # If your dataset has a 'leading question/input', then we route first to the
    # assistant; otherwise, we let the simulated user take the lead.
    entry = "user" if input_key is None else "assistant"
    builder.set_entry_point(entry)

    preprocessor = RunnableLambda(_prepare_example).bind(input_key=input_key)
    return preprocessor | builder.compile()
|
||
|
||
## Private methods | ||
|
||
|
||
def _prepare_example(inputs: dict[str, Any], input_key: Optional[str] = None): | ||
if input_key is not None: | ||
if input_key not in inputs: | ||
raise ValueError( | ||
f"Dataset's example input must contain the provided input key: '{input_key}'.\nFound: {list(input.keys())}" | ||
) | ||
messages = [HumanMessage(content=inputs[input_key])] | ||
return { | ||
"inputs": {k: v for k, v in inputs.items() if k != input_key}, | ||
"messages": messages, | ||
} | ||
return {"inputs": inputs} | ||
|
||
|
||
def _invoke_simulated_user(state: SimulationState, simulated_user: Runnable):
    """Invoke the simulated user on the current conversation state.

    Args:
        state: Current simulation state (messages plus optional extra inputs).
        simulated_user: Runnable (or plain callable, wrapped on the fly)
            producing the simulated user's reply.

    Returns:
        The simulated user's response message.
    """
    runnable = (
        simulated_user
        if isinstance(simulated_user, Runnable)
        else RunnableLambda(simulated_user)
    )
    # Copy so we never mutate the dict stored in the graph state, and tolerate
    # an explicit None "inputs" value (SimulationState declares it Optional) —
    # the original `state.get("inputs", {})` returned None in that case and
    # crashed on item assignment.
    inputs = dict(state.get("inputs") or {})
    inputs["messages"] = state["messages"]
    return runnable.invoke(inputs)
|
||
|
||
def _swap_roles(messages: List[AnyMessage]):
    """Flip AI/human roles so the simulated user sees itself as the AI speaker."""
    return [
        HumanMessage(content=m.content)
        if isinstance(m, AIMessage)
        else AIMessage(content=m.content)
        for m in messages
    ]
|
||
|
||
@as_runnable
def _fetch_messages(state: SimulationState):
    """Extract the conversation messages from the simulation state.

    NOTE(review): the original docstring ("Invoke the simulated user node.")
    was copy-pasted from _invoke_simulated_user and did not describe this node.
    """
    return state["messages"]
|
||
|
||
def _convert_to_human_message(message: BaseMessage):
    # The simulated user's reply arrives as an AI-authored message; re-wrap it
    # as a HumanMessage so the shared transcript reads from the assistant's
    # perspective (see _create_simulated_user_node).
    return HumanMessage(content=message.content)
|
||
|
||
def _create_simulated_user_node(simulated_user: Runnable):
    """Wrap *simulated_user* as a graph node: swap roles, invoke the user model,
    then re-wrap its reply as a human message. The node accepts a
    {"messages": [...]} argument and returns a single message."""
    swap = RunnableLambda(_swap_roles)
    invoke_user = RunnableLambda(_invoke_simulated_user).bind(
        simulated_user=simulated_user
    )
    return swap | invoke_user | _convert_to_human_message
|
||
|
||
def _coerce_to_message(assistant_output: str | BaseMessage):
    """Normalize the assistant's output to a message object."""
    if not isinstance(assistant_output, str):
        return assistant_output
    return AIMessage(content=assistant_output)
|
||
|
||
def _should_continue(state: SimulationState, max_turns: int = 6):
    """Route to END once the transcript exceeds *max_turns* messages or the
    last message is exactly 'FINISHED'; otherwise loop back to the assistant."""
    transcript = state["messages"]
    # TODO support other stop criteria
    over_limit = len(transcript) > max_turns
    if over_limit or transcript[-1].content.strip() == "FINISHED":
        return END
    return "assistant"