Skip to content

Commit

Permalink
Deduplicate entities derived from GroupEntity (home-assistant#98893)
Browse files Browse the repository at this point in the history
  • Loading branch information
emontnemery authored Aug 23, 2023
1 parent ee1b6a6 commit 3c10d0e
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 249 deletions.
62 changes: 61 additions & 1 deletion homeassistant/components/group/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import abstractmethod
import asyncio
from collections.abc import Collection, Iterable
from collections.abc import Callable, Collection, Iterable, Mapping
from contextvars import ContextVar
import logging
from typing import Any, Protocol, cast
Expand Down Expand Up @@ -473,9 +473,60 @@ class GroupEntity(Entity):
"""Representation of a Group of entities."""

_attr_should_poll = False
_entity_ids: list[str]

@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""

for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)

@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
if event:
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
preview_callback(*self._async_generate_attributes())

async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)

async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(entity_id, state)

@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)
self.async_defer_or_update_ha_state()

self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)

async def _update_at_start(_: HomeAssistant) -> None:
self.async_update_group_state()
Expand All @@ -493,9 +544,18 @@ def async_defer_or_update_ha_state(self) -> None:
self.async_write_ha_state()

@abstractmethod
@callback
def async_update_group_state(self) -> None:
"""Abstract method to update the entity."""

@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
) -> None:
"""Update dictionaries with supported features."""


class Group(Entity):
"""Track a group of entity ids."""
Expand Down
50 changes: 2 additions & 48 deletions homeassistant/components/group/binary_sensor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Platform allowing several binary sensor to be grouped into one binary sensor."""
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any

import voluptuous as vol

from homeassistant.components.binary_sensor import (
Expand All @@ -24,14 +21,10 @@
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import GroupEntity

Expand Down Expand Up @@ -116,45 +109,6 @@ def __init__(
if mode:
self.mode = all

@callback
def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""

@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData] | None,
) -> None:
"""Handle child updates."""
self.async_update_group_state()
preview_callback(*self._async_generate_attributes())

async_state_changed_listener(None)
return async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)

async def async_added_to_hass(self) -> None:
"""Register callbacks."""

@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()

self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)

await super().async_added_to_hass()

@callback
def async_update_group_state(self) -> None:
"""Query all members and determine the binary sensor group state."""
Expand Down
45 changes: 5 additions & 40 deletions homeassistant/components/group/cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import GroupEntity
from .util import attribute_equal, reduce_attribute
Expand Down Expand Up @@ -112,7 +108,7 @@ class CoverGroup(GroupEntity, CoverEntity):

def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a CoverGroup entity."""
self._entities = entities
self._entity_ids = entities
self._covers: dict[str, set[str]] = {
KEY_OPEN_CLOSE: set(),
KEY_STOP: set(),
Expand All @@ -128,30 +124,18 @@ def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> Non
self._attr_extra_state_attributes = {ATTR_ENTITY_ID: entities}
self._attr_unique_id = unique_id

@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)

@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
update_state: bool = True,
) -> None:
"""Update dictionaries with supported features."""
if not new_state:
for values in self._covers.values():
values.discard(entity_id)
for values in self._tilts.values():
values.discard(entity_id)
if update_state:
self.async_defer_or_update_ha_state()
return

features = new_state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
Expand Down Expand Up @@ -182,25 +166,6 @@ def async_update_supported_features(
else:
self._tilts[KEY_POSITION].discard(entity_id)

if update_state:
self.async_defer_or_update_ha_state()

async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)

await super().async_added_to_hass()

async def async_open_cover(self, **kwargs: Any) -> None:
"""Move the covers up."""
data = {ATTR_ENTITY_ID: self._covers[KEY_OPEN_CLOSE]}
Expand Down Expand Up @@ -278,7 +243,7 @@ def async_update_group_state(self) -> None:

states = [
state.state
for entity_id in self._entities
for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None
]

Expand All @@ -292,7 +257,7 @@ def async_update_group_state(self) -> None:
self._attr_is_closed = True
self._attr_is_closing = False
self._attr_is_opening = False
for entity_id in self._entities:
for entity_id in self._entity_ids:
if not (state := self.hass.states.get(entity_id)):
continue
if state.state == STATE_OPEN:
Expand Down Expand Up @@ -347,7 +312,7 @@ def async_update_group_state(self) -> None:
self._attr_supported_features = supported_features

if not self._attr_assumed_state:
for entity_id in self._entities:
for entity_id in self._entity_ids:
if (state := self.hass.states.get(entity_id)) is None:
continue
if state and state.attributes.get(ATTR_ASSUMED_STATE):
Expand Down
41 changes: 4 additions & 37 deletions homeassistant/components/group/fan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import GroupEntity
from .util import (
Expand Down Expand Up @@ -108,7 +104,7 @@ class FanGroup(GroupEntity, FanEntity):

def __init__(self, unique_id: str | None, name: str, entities: list[str]) -> None:
"""Initialize a FanGroup entity."""
self._entities = entities
self._entity_ids = entities
self._fans: dict[int, set[str]] = {flag: set() for flag in SUPPORTED_FLAGS}
self._percentage = None
self._oscillating = None
Expand Down Expand Up @@ -144,21 +140,11 @@ def oscillating(self) -> bool | None:
"""Return whether or not the fan is currently oscillating."""
return self._oscillating

@callback
def _update_supported_features_event(
self, event: EventType[EventStateChangedData]
) -> None:
self.async_set_context(event.context)
self.async_update_supported_features(
event.data["entity_id"], event.data["new_state"]
)

@callback
def async_update_supported_features(
self,
entity_id: str,
new_state: State | None,
update_state: bool = True,
) -> None:
"""Update dictionaries with supported features."""
if not new_state:
Expand All @@ -172,25 +158,6 @@ def async_update_supported_features(
else:
self._fans[feature].discard(entity_id)

if update_state:
self.async_defer_or_update_ha_state()

async def async_added_to_hass(self) -> None:
"""Register listeners."""
for entity_id in self._entities:
if (new_state := self.hass.states.get(entity_id)) is None:
continue
self.async_update_supported_features(
entity_id, new_state, update_state=False
)
self.async_on_remove(
async_track_state_change_event(
self.hass, self._entities, self._update_supported_features_event
)
)

await super().async_added_to_hass()

async def async_set_percentage(self, percentage: int) -> None:
"""Set the speed of the fan, as a percentage."""
if percentage == 0:
Expand Down Expand Up @@ -250,7 +217,7 @@ async def _async_call_all_entities(self, service: str) -> None:
await self.hass.services.async_call(
DOMAIN,
service,
{ATTR_ENTITY_ID: self._entities},
{ATTR_ENTITY_ID: self._entity_ids},
blocking=True,
context=self._context,
)
Expand All @@ -275,7 +242,7 @@ def async_update_group_state(self) -> None:

states = [
state
for entity_id in self._entities
for entity_id in self._entity_ids
if (state := self.hass.states.get(entity_id)) is not None
]
self._attr_assumed_state |= not states_equal(states)
Expand Down
25 changes: 1 addition & 24 deletions homeassistant/components/group/light.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import (
EventStateChangedData,
async_track_state_change_event,
)
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, EventType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType

from . import GroupEntity
from .util import find_state_attributes, mean_tuple, reduce_attribute
Expand Down Expand Up @@ -153,25 +149,6 @@ def __init__(
if mode:
self.mode = all

async def async_added_to_hass(self) -> None:
"""Register callbacks."""

@callback
def async_state_changed_listener(
event: EventType[EventStateChangedData],
) -> None:
"""Handle child updates."""
self.async_set_context(event.context)
self.async_defer_or_update_ha_state()

self.async_on_remove(
async_track_state_change_event(
self.hass, self._entity_ids, async_state_changed_listener
)
)

await super().async_added_to_hass()

async def async_turn_on(self, **kwargs: Any) -> None:
"""Forward the turn_on command to all lights in the light group."""
data = {
Expand Down
Loading

0 comments on commit 3c10d0e

Please sign in to comment.