Skip to content

Commit

Permalink
Refactor agent chat to prepare for handoff/swarm (microsoft#3949)
Browse files Browse the repository at this point in the history
Add handoff message type to chat message types
Add Swarm group chat that uses handoff message to select next speaker
Remove tool call and tool call result message types from chat message types
Remove BaseToolUseChatAgent, move tool call handling from group chat's chat agent container upward to the ToolUseAssistantAgent implementation, which subclasses BaseChatAgent directly.
Renaming for better clarity

---------

Co-authored-by: Victor Dibia <[email protected]>
  • Loading branch information
ekzhu and victordibia authored Oct 25, 2024
1 parent 0756ebd commit f31ff66
Show file tree
Hide file tree
Showing 19 changed files with 363 additions and 293 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import Sequence

from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat

Expand Down Expand Up @@ -51,21 +50,3 @@ async def run(
termination_condition=termination_condition,
)
return result


class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
"""Base class for a chat agent that can use tools.
Subclass this base class to create an agent class that uses tools by returning
ToolCallMessage message from the :meth:`on_messages` method and receiving
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
"""

def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
super().__init__(name, description)
self._registered_tools = registered_tools

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
return self._registered_tools
Original file line number Diff line number Diff line change
@@ -1,29 +1,52 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict

from .. import EVENT_LOGGER_NAME
from ..messages import (
ChatMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)

class ToolUseAssistantAgent(BaseToolUseChatAgent):

class ToolCallEvent(BaseModel):
"""A tool call event."""

tool_calls: List[FunctionCall]
"""The tool call message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
"""A tool call result event."""

tool_call_results: List[FunctionExecutionResult]
"""The tool call result message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolUseAssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
It responds with a StopMessage when 'terminate' is detected in the response.
Expand All @@ -45,46 +68,50 @@ def __init__(
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
):
tools: List[Tool] = []
super().__init__(name=name, description=description)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
for tool in registered_tools:
if isinstance(tool, Tool):
tools.append(tool)
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
tools.append(FunctionTool(tool, description=description))
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
super().__init__(name=name, description=description, registered_tools=tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = [tool.schema for tool in tools]
self._model_context: List[LLMMessage] = []

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ToolCallResultMessage):
self._model_context.append(FunctionExecutionResultMessage(content=msg.content))
elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(f"Unsupported message type: {type(msg)}")
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# TODO: add special handling for handoff messages
self._model_context.append(UserMessage(content=msg.content, source=msg.source))

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token
)
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

# Detect tool calls.
if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
return ToolCallMessage(content=result.content, source=self.name)
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
self._model_context.append(FunctionExecutionResultMessage(content=results))
# Generate an inference result based on the current model context.
result = await self._model_client.create(
self._model_context, tools=self._tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
Expand All @@ -93,3 +120,20 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
return StopMessage(content=result.content, source=self.name)

return TextMessage(content=result.content, source=self.name)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
try:
if not self._tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from ._chat_agent import ChatAgent, ToolUseChatAgent
from ._chat_agent import ChatAgent
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition

__all__ = [
"ChatAgent",
"ToolUseChatAgent",
"Team",
"TerminatedException",
"TerminationCondition",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List, Protocol, Sequence, runtime_checkable
from typing import Protocol, Sequence, runtime_checkable

from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..messages import ChatMessage
from ._task import TaskResult, TaskRunner
Expand Down Expand Up @@ -38,13 +37,3 @@ async def run(
) -> TaskResult:
"""Run the agent with the given task and return the result."""
...


@runtime_checkable
class ToolUseChatAgent(ChatAgent, Protocol):
"""Protocol for a chat agent that can use tools."""

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
...
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import sys
from datetime import datetime

from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..messages import ChatMessage, StopMessage, TextMessage
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)


Expand All @@ -25,7 +24,7 @@ def serialize_chat_message(message: ChatMessage) -> str:

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
if isinstance(record.msg, GroupChatPublishEvent):
if record.msg.source is None:
sys.stdout.write(
f"\n{'-'*75} \n"
Expand All @@ -41,19 +40,15 @@ def emit(self, record: logging.LogRecord) -> None:
sys.stdout.flush()
elif isinstance(record.msg, ToolCallEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, ToolCallResultEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call Result:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
from datetime import datetime
from typing import Any

from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)


Expand All @@ -21,7 +20,7 @@ def __init__(self, filename: str) -> None:

def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent):
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
log_entry = json.dumps(
{
"timestamp": ts,
Expand All @@ -31,7 +30,7 @@ def emit(self, record: logging.LogRecord) -> None:
},
default=self.json_serializer,
)
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
log_entry = json.dumps(
{
"timestamp": ts,
Expand All @@ -41,6 +40,24 @@ def emit(self, record: logging.LogRecord) -> None:
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_calls": record.msg.model_dump(),
"type": "ToolCallEvent",
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallResultEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_call_results": record.msg.model_dump(),
"type": "ToolCallResultEvent",
},
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(
Expand Down
29 changes: 10 additions & 19 deletions python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components import Image
from pydantic import BaseModel


Expand All @@ -26,37 +25,29 @@ class MultiModalMessage(BaseMessage):
"""The content of the message."""


class ToolCallMessage(BaseMessage):
"""A message containing a list of function calls."""

content: List[FunctionCall]
"""The list of function calls."""


class ToolCallResultMessage(BaseMessage):
"""A message containing the results of function calls."""

content: List[FunctionExecutionResult]
"""The list of function execution results."""


class StopMessage(BaseMessage):
"""A message requesting stop of a conversation."""

content: str
"""The content for the stop message."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
class HandoffMessage(BaseMessage):
"""A message requesting handoff of a conversation to another agent."""

content: str
"""The agent name to handoff the conversation to."""


ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
"""A message used by agents in a team."""


__all__ = [
"BaseMessage",
"TextMessage",
"MultiModalMessage",
"ToolCallMessage",
"ToolCallResultMessage",
"StopMessage",
"HandoffMessage",
"ChatMessage",
]
Loading

0 comments on commit f31ff66

Please sign in to comment.