Skip to content

Commit

Permalink
Make tts/stt/conversation optional on pipeline (home-assistant#91555)
Browse files Browse the repository at this point in the history
  • Loading branch information
bramkragten authored Apr 17, 2023
1 parent afc9e43 commit e3ff7d0
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
6 changes: 3 additions & 3 deletions homeassistant/components/assist_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
STORAGE_VERSION = 1

STORAGE_FIELDS = {
vol.Required("conversation_engine"): str,
vol.Optional("conversation_engine", default=None): vol.Any(str, None),
vol.Required("language"): str,
vol.Required("name"): str,
vol.Required("stt_engine"): str,
vol.Required("tts_engine"): str,
vol.Optional("stt_engine", default=None): vol.Any(str, None),
vol.Optional("tts_engine", default=None): vol.Any(str, None),
}

STORED_PIPELINE_RUNS = 10
Expand Down
12 changes: 6 additions & 6 deletions tests/components/assist_pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
"tts_engine": "tts_engine_2",
},
{
"conversation_engine": "conversation_engine_3",
"conversation_engine": None,
"language": "language_3",
"name": "name_3",
"stt_engine": "stt_engine_3",
"tts_engine": "tts_engine_3",
"stt_engine": None,
"tts_engine": None,
},
]
pipeline_ids = []
Expand Down Expand Up @@ -91,12 +91,12 @@ async def test_loading_datasets_from_storage(
"tts_engine": "tts_engine_2",
},
{
"conversation_engine": "conversation_engine_3",
"conversation_engine": None,
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
"language": "language_3",
"name": "name_3",
"stt_engine": "stt_engine_3",
"tts_engine": "tts_engine_3",
"stt_engine": None,
"tts_engine": None,
},
],
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
Expand Down
61 changes: 61 additions & 0 deletions tests/components/assist_pipeline/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,35 @@ async def test_add_pipeline(
tts_engine="test_tts_engine",
)

await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"language": "test_language",
"name": "test_name",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": None,
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": None,
"tts_engine": None,
}

assert len(pipeline_store.data) == 2
pipeline = pipeline_store.data[msg["result"]["id"]]
assert pipeline == Pipeline(
conversation_engine=None,
id=msg["result"]["id"],
language="test_language",
name="test_name",
stt_engine=None,
tts_engine=None,
)


async def test_delete_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
Expand Down Expand Up @@ -808,6 +837,38 @@ async def test_update_pipeline(
tts_engine="new_tts_engine",
)

await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/update",
"conversation_engine": None,
"language": "new_language",
"name": "new_name",
"pipeline_id": pipeline_id,
"stt_engine": None,
"tts_engine": None,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": None,
"id": pipeline_id,
"language": "new_language",
"name": "new_name",
"stt_engine": None,
"tts_engine": None,
}

pipeline = pipeline_store.data[pipeline_id]
assert pipeline == Pipeline(
conversation_engine=None,
id=pipeline_id,
language="new_language",
name="new_name",
stt_engine=None,
tts_engine=None,
)


async def test_set_preferred_pipeline(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
Expand Down

0 comments on commit e3ff7d0

Please sign in to comment.