Skip to content

Commit

Permalink
Add conversation agent to Wyoming (home-assistant#124373)
Browse files Browse the repository at this point in the history
* Add conversation agent to Wyoming

* Remove error

* Remove conversation platform from satellite list

* Clean up

* Update homeassistant/components/wyoming/conversation.py

Co-authored-by: Paulus Schoutsen <[email protected]>

* Remove unnecessary attribute

---------

Co-authored-by: Paulus Schoutsen <[email protected]>
  • Loading branch information
synesthesiam and balloob authored Oct 16, 2024
1 parent bcac851 commit 11ac8f8
Show file tree
Hide file tree
Showing 6 changed files with 553 additions and 1 deletion.
194 changes: 194 additions & 0 deletions homeassistant/components/wyoming/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Support for Wyoming intent recognition services."""

import logging

from wyoming.asr import Transcript
from wyoming.client import AsyncTcpClient
from wyoming.handle import Handled, NotHandled
from wyoming.info import HandleProgram, IntentProgram
from wyoming.intent import Intent, NotRecognized

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid

from .const import DOMAIN
from .data import WyomingService
from .error import WyomingError
from .models import DomainDataItem

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Wyoming conversation."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities(
[
WyomingConversationEntity(config_entry, item.service),
]
)


class WyomingConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
"""Wyoming conversation agent."""

_attr_has_entity_name = True

def __init__(
self,
config_entry: ConfigEntry,
service: WyomingService,
) -> None:
"""Set up provider."""
super().__init__()

self.service = service

self._intent_service: IntentProgram | None = None
self._handle_service: HandleProgram | None = None

for maybe_intent in self.service.info.intent:
if maybe_intent.installed:
self._intent_service = maybe_intent
break

for maybe_handle in self.service.info.handle:
if maybe_handle.installed:
self._handle_service = maybe_handle
break

model_languages: set[str] = set()

if self._intent_service is not None:
for intent_model in self._intent_service.models:
if intent_model.installed:
model_languages.update(intent_model.languages)

self._attr_name = self._intent_service.name
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
elif self._handle_service is not None:
for handle_model in self._handle_service.models:
if handle_model.installed:
model_languages.update(handle_model.languages)

self._attr_name = self._handle_service.name

self._supported_languages = list(model_languages)
self._attr_unique_id = f"{config_entry.entry_id}-conversation"

@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return self._supported_languages

async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
conversation_id = user_input.conversation_id or ulid.ulid_now()
intent_response = intent.IntentResponse(language=user_input.language)

try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
await client.write_event(
Transcript(
user_input.text, context={"conversation_id": conversation_id}
).event()
)

while True:
event = await client.read_event()
if event is None:
_LOGGER.debug("Connection lost")
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Connection to service was lost",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)

if Intent.is_type(event.type):
# Success
recognized_intent = Intent.from_event(event)
_LOGGER.debug("Recognized intent: %s", recognized_intent)

intent_type = recognized_intent.name
intent_slots = {
e.name: {"value": e.value}
for e in recognized_intent.entities
}
intent_response = await intent.async_handle(
self.hass,
DOMAIN,
intent_type,
intent_slots,
text_input=user_input.text,
language=user_input.language,
)

if (not intent_response.speech) and recognized_intent.text:
intent_response.async_set_speech(recognized_intent.text)

break

if NotRecognized.is_type(event.type):
not_recognized = NotRecognized.from_event(event)
intent_response.async_set_error(
intent.IntentResponseErrorCode.NO_INTENT_MATCH,
not_recognized.text,
)
break

if Handled.is_type(event.type):
# Success
handled = Handled.from_event(event)
intent_response.async_set_speech(handled.text)
break

if NotHandled.is_type(event.type):
not_handled = NotHandled.from_event(event)
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
not_handled.text,
)
break

except (OSError, WyomingError) as err:
_LOGGER.exception("Unexpected error while communicating with service")
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error communicating with service: {err}",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)
except intent.IntentError as err:
_LOGGER.exception("Unexpected error while handling intent")
intent_response.async_set_error(
intent.IntentResponseErrorCode.FAILED_TO_HANDLE,
f"Error handling intent: {err}",
)
return conversation.ConversationResult(
response=intent_response,
conversation_id=user_input.conversation_id,
)

# Success
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
16 changes: 16 additions & 0 deletions homeassistant/components/wyoming/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,19 @@ def __init__(self, host: str, port: int, info: Info) -> None:
self.platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
self.platforms.append(Platform.WAKE_WORD)
if any(intent.installed for intent in info.intent) or any(
handle.installed for handle in info.handle
):
self.platforms.append(Platform.CONVERSATION)

def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use."""
return (
any(asr for asr in self.info.asr if asr.installed)
or any(tts for tts in self.info.tts if tts.installed)
or any(wake for wake in self.info.wake if wake.installed)
or any(intent for intent in self.info.intent if intent.installed)
or any(handle for handle in self.info.handle if handle.installed)
or ((self.info.satellite is not None) and self.info.satellite.installed)
)

Expand All @@ -70,6 +76,16 @@ def get_name(self) -> str | None:
if wake_installed:
return wake_installed[0].name

# intent recognition (text -> intent)
intent_installed = [intent for intent in self.info.intent if intent.installed]
if intent_installed:
return intent_installed[0].name

# intent handling (text -> text)
handle_installed = [handle for handle in self.info.handle if handle.installed]
if handle_installed:
return handle_installed[0].name

return None

@classmethod
Expand Down
46 changes: 46 additions & 0 deletions tests/components/wyoming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
AsrModel,
AsrProgram,
Attribution,
HandleModel,
HandleProgram,
Info,
IntentModel,
IntentProgram,
Satellite,
TtsProgram,
TtsVoice,
Expand Down Expand Up @@ -87,6 +91,48 @@
)
]
)
INTENT_INFO = Info(
intent=[
IntentProgram(
name="Test Intent",
description="Test Intent",
installed=True,
attribution=TEST_ATTR,
models=[
IntentModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
version=None,
)
],
version=None,
)
]
)
HANDLE_INFO = Info(
handle=[
HandleProgram(
name="Test Handle",
description="Test Handle",
installed=True,
attribution=TEST_ATTR,
models=[
HandleModel(
name="Test Model",
description="Test Model",
installed=True,
attribution=TEST_ATTR,
languages=["en-US"],
version=None,
)
],
version=None,
)
]
)
SATELLITE_INFO = Info(
satellite=Satellite(
name="Test Satellite",
Expand Down
67 changes: 66 additions & 1 deletion tests/components/wyoming/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
from . import (
HANDLE_INFO,
INTENT_INFO,
SATELLITE_INFO,
STT_INFO,
TTS_INFO,
WAKE_WORD_INFO,
)

from tests.common import MockConfigEntry

Expand Down Expand Up @@ -83,6 +90,36 @@ def wake_word_config_entry(hass: HomeAssistant) -> ConfigEntry:
return entry


@pytest.fixture
def intent_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Intent",
)
entry.add_to_hass(hass)
return entry


@pytest.fixture
def handle_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Handle",
)
entry.add_to_hass(hass)
return entry


@pytest.fixture
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
"""Initialize Wyoming STT."""
Expand Down Expand Up @@ -115,6 +152,34 @@ async def init_wyoming_wake_word(
await hass.config_entries.async_setup(wake_word_config_entry.entry_id)


@pytest.fixture
async def init_wyoming_intent(
hass: HomeAssistant, intent_config_entry: ConfigEntry
) -> ConfigEntry:
"""Initialize Wyoming intent recognizer."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=INTENT_INFO,
):
await hass.config_entries.async_setup(intent_config_entry.entry_id)

return intent_config_entry


@pytest.fixture
async def init_wyoming_handle(
hass: HomeAssistant, handle_config_entry: ConfigEntry
) -> ConfigEntry:
"""Initialize Wyoming intent handler."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=HANDLE_INFO,
):
await hass.config_entries.async_setup(handle_config_entry.entry_id)

return handle_config_entry


@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata."""
Expand Down
7 changes: 7 additions & 0 deletions tests/components/wyoming/snapshots/test_conversation.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# serializer version: 1
# name: test_connection_lost
'Connection to service was lost'
# ---
# name: test_oserror
'Error communicating with service: Boom!'
# ---
Loading

0 comments on commit 11ac8f8

Please sign in to comment.