forked from home-assistant/core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiscovery_flow.py
146 lines (117 loc) · 4.51 KB
/
discovery_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""The discovery flow helper."""
from __future__ import annotations
from collections.abc import Coroutine
import dataclasses
from typing import TYPE_CHECKING, Any, NamedTuple, Self
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_limited_concurrency
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigFlowContext, ConfigFlowResult
# Passed as the concurrency limit to gather_with_limited_concurrency when the
# FlowDispatcher initializes the flows it queued before startup.
FLOW_INIT_LIMIT = 20
# hass.data key under which async_create_flow stores the shared FlowDispatcher.
DISCOVERY_FLOW_DISPATCHER: HassKey[FlowDispatcher] = HassKey(
    "discovery_flow_dispatcher"
)
@dataclasses.dataclass(kw_only=True, slots=True)
class DiscoveryKey:
"""Serializable discovery key."""
domain: str
key: str | tuple[str, ...]
version: int
@classmethod
def from_json_dict(cls, json_dict: dict[str, Any]) -> Self:
"""Construct from JSON dict."""
if type(key := json_dict["key"]) is list:
key = tuple(key)
return cls(domain=json_dict["domain"], key=key, version=json_dict["version"])
@bind_hass
@callback
def async_create_flow(
    hass: HomeAssistant,
    domain: str,
    context: ConfigFlowContext,
    data: Any,
    *,
    discovery_key: DiscoveryKey | None = None,
) -> None:
    """Create a discovery flow, or queue it until Home Assistant has started.

    If no ``FlowDispatcher`` is stored in ``hass.data`` and Home Assistant is
    not running yet, one is created and set up. While that dispatcher has not
    started, the flow is queued on it. Otherwise the flow is initialized right
    away in a background task, unless ``_async_init_flow`` decides to skip it.

    Args:
        hass: The Home Assistant instance.
        domain: Integration domain to start the flow for.
        context: Flow context; it must contain ``"source"`` when the flow is
            queued, because the dispatcher groups queued flows by it.
        data: Discovery payload handed to the flow.
        discovery_key: Optional key added to a copy of the context under
            ``"discovery_key"``.
    """
    dispatcher: FlowDispatcher | None = hass.data.get(DISCOVERY_FLOW_DISPATCHER)
    if dispatcher is None and hass.state is not CoreState.running:
        # First discovery before Home Assistant is running: create the shared
        # dispatcher and register it for the started event.
        dispatcher = FlowDispatcher(hass)
        hass.data[DISCOVERY_FLOW_DISPATCHER] = dispatcher
        dispatcher.async_setup()

    if discovery_key:
        # Build a new dict so the caller's context is left untouched.
        context = context | {"discovery_key": discovery_key}

    if dispatcher is not None and not dispatcher.started:
        # The dispatcher has not handled EVENT_HOMEASSISTANT_STARTED yet.
        dispatcher.async_create(domain, context, data)
        return

    init_coro = _async_init_flow(hass, domain, context, data)
    if init_coro is None:
        # Duplicate of an in-progress discovery flow, or Home Assistant is stopping.
        return
    hass.async_create_background_task(
        init_coro, f"discovery flow {domain} {context}", eager_start=True
    )
@callback
def _async_init_flow(
    hass: HomeAssistant, domain: str, context: ConfigFlowContext, data: Any
) -> Coroutine[None, None, ConfigFlowResult] | None:
    """Return the coroutine that initializes a discovery flow, or None to skip it.

    The flow is skipped when a discovery flow with matching data is already in
    progress, or when Home Assistant is stopping. The returned coroutine has
    not been scheduled; the caller is responsible for awaiting it.
    """
    flow_manager = hass.config_entries.flow
    # Don't spawn a flow whose initial discovery data matches one already in
    # progress. Each new flow may probe the device again, and zeroconf/ssdp
    # updates can arrive several times a minute, which can overload devices.
    already_in_progress = flow_manager.async_has_matching_discovery_flow(
        domain, context, data
    )
    if already_in_progress or hass.is_stopping:
        return None
    return flow_manager.async_init(domain, context=context, data=data)
class PendingFlowKey(NamedTuple):
    """Key for pending flows.

    FlowDispatcher groups queued flows by integration domain and by the
    ``"source"`` entry of the flow context.
    """

    domain: str  # Integration domain the queued flow belongs to.
    source: str  # Value of context["source"] for the queued flow.
class PendingFlowValue(NamedTuple):
    """Value for pending flows.

    One queued flow: the context and data that FlowDispatcher later passes to
    ``_async_init_flow`` once it starts.
    """

    context: ConfigFlowContext  # Flow context; contains "source".
    data: Any  # Discovery payload; queued flows with equal data are deduplicated.
class FlowDispatcher:
    """Hold discovery flows created during startup and start them afterwards.

    ``async_create_flow`` queues flows here until this dispatcher handles
    ``EVENT_HOMEASSISTANT_STARTED``. From then on ``started`` is True and new
    flows bypass the queue.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Create a dispatcher with an empty queue that has not started yet."""
        self.hass = hass
        # Set to True by _async_start.
        self.started = False
        # Queued flows, grouped by (domain, context["source"]).
        self.pending_flows: dict[PendingFlowKey, list[PendingFlowValue]] = {}

    @callback
    def async_setup(self) -> None:
        """Set up the flow dispatcher to start once Home Assistant has started."""
        self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self._async_start)

    async def _async_start(self, event: Event) -> None:
        """Initialize every queued flow, with bounded concurrency."""
        # Swap in a fresh dict and mark the dispatcher started before creating
        # any flows. async_create_flow bypasses the queue once `started` is True.
        queued, self.pending_flows = self.pending_flows, {}
        self.started = True

        init_coros = []
        for flow_key, queued_values in queued.items():
            for value in queued_values:
                coro = _async_init_flow(
                    self.hass, flow_key.domain, value.context, value.data
                )
                # None means the flow is a duplicate or Home Assistant is stopping.
                if coro is not None:
                    init_coros.append(coro)
        await gather_with_limited_concurrency(FLOW_INIT_LIMIT, *init_coros)

    @callback
    def async_create(self, domain: str, context: ConfigFlowContext, data: Any) -> None:
        """Queue a flow unless an equal-data flow is already queued for its key.

        The queue key is ``(domain, context["source"])``. Only ``data`` is
        compared when looking for duplicates; the first queued entry wins.
        """
        flow_key = PendingFlowKey(domain, context["source"])
        queue = self.pending_flows.setdefault(flow_key, [])
        for queued in queue:
            if queued.data == data:
                return
        queue.append(PendingFlowValue(context, data))