Skip to content

Commit

Permalink
Assist Satellite to use ChatSession for conversation ID (home-assista…
Browse files Browse the repository at this point in the history
…nt#137142)

* Assist Satellite to use ChatSession for conversation ID

* Adjust for changes on the main branch

* Ensure the initial message is in the chat log
  • Loading branch information
balloob authored Feb 3, 2025
1 parent 4531a46 commit 8acab6c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 52 deletions.
108 changes: 58 additions & 50 deletions homeassistant/components/assist_satellite/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from enum import StrEnum
import logging
import time
from typing import Any, Final, Literal, final
from typing import Any, Literal, final

from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
Expand All @@ -28,14 +28,12 @@
)
from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.helpers import chat_session, entity
from homeassistant.helpers.entity import EntityDescription

from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError

_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes

_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -114,7 +112,6 @@ class AssistSatelliteEntity(entity.Entity):
_attr_vad_sensitivity_entity_id: str | None = None

_conversation_id: str | None = None
_conversation_id_time: float | None = None

_run_has_tts: bool = False
_is_announcing = False
Expand Down Expand Up @@ -260,6 +257,21 @@ async def async_internal_start_conversation(
else:
self._extra_system_prompt = start_message or None

with (
# Not passing in a conversation ID will force a new one to be created
chat_session.async_get_chat_session(self.hass) as session,
conversation.async_get_chat_log(self.hass, session) as chat_log,
):
self._conversation_id = session.conversation_id

if start_message:
async for _tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=self.entity_id, content=start_message
)
):
pass # no tool responses.

try:
await self.async_start_conversation(announcement)
finally:
Expand Down Expand Up @@ -325,51 +337,52 @@ async def async_accept_pipeline_from_satellite(

assert self._context is not None

# Reset conversation id if necessary
if self._conversation_id_time and (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
self._conversation_id_time = None

# Set entity state based on pipeline events
self._run_has_tts = False

assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)

try:
await self._pipeline_task
finally:
self._pipeline_task = None
with chat_session.async_get_chat_session(
self.hass, self._conversation_id
) as session:
# Store the conversation ID. If it is no longer valid, async_get_chat_session will reset it
self._conversation_id = session.conversation_id
self._pipeline_task = (
self.platform.config_entry.async_create_background_task(
self.hass,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=session.conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
conversation_extra_system_prompt=extra_system_prompt,
),
f"{self.entity_id}_pipeline",
)
)

try:
await self._pipeline_task
finally:
self._pipeline_task = None

async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running."""
Expand All @@ -393,11 +406,6 @@ def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.INTENT_END:
assert event.data is not None
# Update timeout
self._conversation_id_time = time.monotonic()
self._conversation_id = event.data["intent_output"]["conversation_id"]
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True
Expand Down
4 changes: 3 additions & 1 deletion tests/components/assist_satellite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ async def async_start_conversation(
self, start_announcement: AssistSatelliteConfiguration
) -> None:
"""Start a conversation from the satellite."""
self.start_conversations.append((self._extra_system_prompt, start_announcement))
self.start_conversations.append(
(self._conversation_id, self._extra_system_prompt, start_announcement)
)


@pytest.fixture
Expand Down
15 changes: 14 additions & 1 deletion tests/components/assist_satellite/test_entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Test the Assist Satellite entity."""

import asyncio
from unittest.mock import patch
from collections.abc import Generator
from unittest.mock import Mock, patch

import pytest

Expand Down Expand Up @@ -31,6 +32,14 @@
from .conftest import MockAssistSatellite


@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
    """Pin the value returned by ``chat_session.ulid_now``.

    Patches ``homeassistant.helpers.chat_session.ulid_now`` while the fixture
    is active so that every call returns ``"mock-conversation-id"``, then
    yields the patched mock so a test can inspect it.
    """
    ulid_patcher = patch(
        "homeassistant.helpers.chat_session.ulid_now",
        return_value="mock-conversation-id",
    )
    # The patch is reverted when the generator resumes after the test.
    with ulid_patcher as patched_ulid_now:
        yield patched_ulid_now


@pytest.fixture(autouse=True)
async def set_pipeline_tts(hass: HomeAssistant, init_components: ConfigEntry) -> None:
"""Set up a pipeline with a TTS engine."""
Expand Down Expand Up @@ -487,6 +496,7 @@ async def test_vad_sensitivity_entity_not_found(
"extra_system_prompt": "Better system prompt",
},
(
"mock-conversation-id",
"Better system prompt",
AssistSatelliteAnnouncement(
message="Hello",
Expand All @@ -502,6 +512,7 @@ async def test_vad_sensitivity_entity_not_found(
"start_media_id": "media-source://given",
},
(
"mock-conversation-id",
"Hello",
AssistSatelliteAnnouncement(
message="Hello",
Expand All @@ -514,6 +525,7 @@ async def test_vad_sensitivity_entity_not_found(
(
{"start_media_id": "http://example.com/given.mp3"},
(
"mock-conversation-id",
None,
AssistSatelliteAnnouncement(
message="",
Expand All @@ -525,6 +537,7 @@ async def test_vad_sensitivity_entity_not_found(
),
],
)
@pytest.mark.usefixtures("mock_chat_session_conversation_id")
async def test_start_conversation(
hass: HomeAssistant,
init_components: ConfigEntry,
Expand Down

0 comments on commit 8acab6c

Please sign in to comment.