Skip to content

Commit

Permalink
Add token usage to messages (microsoft#4028)
Browse files Browse the repository at this point in the history
* Add token usage to messages

* small test edit
  • Loading branch information
ekzhu authored Nov 1, 2024
1 parent e9c16fe commit ca7caa7
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ async def on_messages_stream(
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
yield ToolCallMessage(content=result.content, source=self.name)
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage))
yield ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage)

# Execute the tool calls.
results = await asyncio.gather(
Expand Down Expand Up @@ -303,7 +303,8 @@ async def on_messages_stream(

assert isinstance(result.content, str)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
chat_message=TextMessage(content=result.content, source=self.name, model_usage=result.usage),
inner_messages=inner_messages,
)

async def _execute_tool_call(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List

from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel


Expand All @@ -11,6 +11,9 @@ class BaseMessage(BaseModel):
source: str
"""The name of the agent that sent this message."""

model_usage: RequestUsage | None = None
"""The model client usage incurred when producing this message."""


class TextMessage(BaseMessage):
"""A text message."""
Expand Down
24 changes: 20 additions & 4 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
Expand All @@ -88,7 +88,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
Expand All @@ -100,7 +100,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
Expand All @@ -113,9 +113,17 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 5
assert result.messages[1].model_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert isinstance(result.messages[3], TextMessage)
assert result.messages[3].model_usage is not None
assert result.messages[3].model_usage.completion_tokens == 5
assert result.messages[3].model_usage.prompt_tokens == 10

# Test streaming.
mock._curr_index = 0 # pyright: ignore
Expand Down Expand Up @@ -158,7 +166,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85),
),
]
mock = _MockChatCompletion(chat_completions)
Expand All @@ -173,9 +181,17 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
result = await tool_use_agent.run("task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].model_usage is None
assert isinstance(result.messages[1], ToolCallMessage)
assert result.messages[1].model_usage is not None
assert result.messages[1].model_usage.completion_tokens == 43
assert result.messages[1].model_usage.prompt_tokens == 42
assert isinstance(result.messages[2], ToolCallResultMessage)
assert result.messages[2].model_usage is None
assert isinstance(result.messages[3], HandoffMessage)
assert result.messages[3].content == handoff.message
assert result.messages[3].target == handoff.target
assert result.messages[3].model_usage is None

# Test streaming.
mock._curr_index = 0 # pyright: ignore
Expand Down

0 comments on commit ca7caa7

Please sign in to comment.