Skip to content

Commit

Permalink
Make sure entity platform services work for all platforms of d… (home…
Browse files Browse the repository at this point in the history
…-assistant#33176)

* Make sure entity platform services work for all platforms of domain

* Register a bad service handler

* Fix cleaning up

* Tiny cleanup
  • Loading branch information
balloob authored Mar 23, 2020
1 parent 2360fd4 commit 1ff245d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 9 deletions.
8 changes: 7 additions & 1 deletion homeassistant/helpers/entity_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,13 @@ async def _async_reset(self) -> None:
This method must be run in the event loop.
"""
tasks = [platform.async_reset() for platform in self._platforms.values()]
tasks = []

for key, platform in self._platforms.items():
if key == self.domain:
tasks.append(platform.async_reset())
else:
tasks.append(platform.async_destroy())

if tasks:
await asyncio.wait(tasks)
Expand Down
35 changes: 27 additions & 8 deletions homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SLOW_SETUP_WARNING = 10
SLOW_SETUP_MAX_WAIT = 60
PLATFORM_NOT_READY_RETRIES = 10
DATA_ENTITY_PLATFORM = "entity_platform"


class EntityPlatform:
Expand Down Expand Up @@ -57,15 +58,15 @@ def __init__(
self._async_cancel_retry_setup: Optional[CALLBACK_TYPE] = None
self._process_updates: Optional[asyncio.Lock] = None

self.parallel_updates: Optional[asyncio.Semaphore] = None

# Platform is None for the EntityComponent "catch-all" EntityPlatform
# which powers entity_component.add_entities
if platform is None:
self.parallel_updates_created = True
self.parallel_updates: Optional[asyncio.Semaphore] = None
return
self.parallel_updates_created = platform is None

self.parallel_updates_created = False
self.parallel_updates = None
hass.data.setdefault(DATA_ENTITY_PLATFORM, {}).setdefault(
self.platform_name, []
).append(self)

@callback
def _get_parallel_updates_semaphore(
Expand Down Expand Up @@ -464,6 +465,14 @@ async def async_reset(self) -> None:
self._async_unsub_polling()
self._async_unsub_polling = None

async def async_destroy(self) -> None:
"""Destroy an entity platform.
Call before discarding the object.
"""
await self.async_reset()
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name].remove(self)

async def async_remove_entity(self, entity_id: str) -> None:
"""Remove entity id from platform."""
await self.entities[entity_id].async_remove()
Expand All @@ -488,14 +497,24 @@ async def async_extract_from_service(self, service_call, expand_group=True):

@callback
def async_register_entity_service(self, name, schema, func, required_features=None):
"""Register an entity service."""
"""Register an entity service.
Services will automatically be shared by all platforms of the same domain.
"""
if self.hass.services.has_service(self.platform_name, name):
return

if isinstance(schema, dict):
schema = cv.make_entity_service_schema(schema)

async def handle_service(call):
"""Handle the service."""
await service.entity_service_call(
self.hass, [self], func, call, required_features
self.hass,
self.hass.data[DATA_ENTITY_PLATFORM][self.platform_name],
func,
call,
required_features,
)

self.hass.services.async_register(
Expand Down
35 changes: 35 additions & 0 deletions tests/helpers/test_entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from homeassistant.const import UNIT_PERCENTAGE
from homeassistant.core import callback
from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers import entity_platform, entity_registry
from homeassistant.helpers.entity import async_generate_entity_id
Expand Down Expand Up @@ -847,3 +848,37 @@ async def test_platform_with_no_setup(hass, caplog):
"The mock-platform platform for the mock-integration integration does not support platform setup."
in caplog.text
)


async def test_platforms_sharing_services(hass):
"""Test platforms share services."""
entity_platform1 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity1 = MockEntity(entity_id="mock_integration.entity_1")
await entity_platform1.async_add_entities([entity1])

entity_platform2 = MockEntityPlatform(
hass, domain="mock_integration", platform_name="mock_platform", platform=None
)
entity2 = MockEntity(entity_id="mock_integration.entity_2")
await entity_platform2.async_add_entities([entity2])

entities = []

@callback
def handle_service(entity, data):
entities.append(entity)

entity_platform1.async_register_entity_service("hello", {}, handle_service)
entity_platform2.async_register_entity_service(
"hello", {}, Mock(side_effect=AssertionError("Should not be called"))
)

await hass.services.async_call(
"mock_platform", "hello", {"entity_id": "all"}, blocking=True
)

assert len(entities) == 2
assert entity1 in entities
assert entity2 in entities

0 comments on commit 1ff245d

Please sign in to comment.