Skip to content

Commit

Permalink
Fix device name customization on ZHA add devices page (home-assistant…
Browse files Browse the repository at this point in the history
…#25180)

* ensure new device exists

* clean up dev reg handling

* update test

* fix tests
  • Loading branch information
dmulcahey authored and balloob committed Jul 16, 2019
1 parent 56841da commit ac91423
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
4 changes: 2 additions & 2 deletions homeassistant/components/zha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ async def async_setup_entry(hass, config_entry):
# pylint: disable=W0611, W0612
import zhaquirks # noqa

zha_gateway = ZHAGateway(hass, config)
await zha_gateway.async_initialize(config_entry)
zha_gateway = ZHAGateway(hass, config, config_entry)
await zha_gateway.async_initialize()

device_registry = await \
hass.helpers.device_registry.async_get_registry()
Expand Down
32 changes: 23 additions & 9 deletions homeassistant/components/zha/core/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from homeassistant.components.system_log import LogEntry, _figure_out_source
from homeassistant.core import callback
from homeassistant.helpers.device_registry import (
async_get_registry as get_dev_reg)
CONNECTION_ZIGBEE, async_get_registry as get_dev_reg)
from homeassistant.helpers.dispatcher import async_dispatcher_send

from ..api import async_get_device_info
Expand Down Expand Up @@ -46,13 +46,14 @@
class ZHAGateway:
"""Gateway that handles events that happen on the ZHA Zigbee network."""

def __init__(self, hass, config):
def __init__(self, hass, config, config_entry):
"""Initialize the gateway."""
self._hass = hass
self._config = config
self._devices = {}
self._device_registry = collections.defaultdict(list)
self.zha_storage = None
self.ha_device_registry = None
self.application_controller = None
self.radio_description = None
hass.data[DATA_ZHA][DATA_ZHA_GATEWAY] = self
Expand All @@ -62,14 +63,16 @@ def __init__(self, hass, config):
}
self.debug_enabled = False
self._log_relay_handler = LogRelayHandler(hass, self)
self._config_entry = config_entry

async def async_initialize(self, config_entry):
async def async_initialize(self):
"""Initialize controller and connect radio."""
self.zha_storage = await async_get_registry(self._hass)
self.ha_device_registry = await get_dev_reg(self._hass)

usb_path = config_entry.data.get(CONF_USB_PATH)
usb_path = self._config_entry.data.get(CONF_USB_PATH)
baudrate = self._config.get(CONF_BAUDRATE, DEFAULT_BAUDRATE)
radio_type = config_entry.data.get(CONF_RADIO_TYPE)
radio_type = self._config_entry.data.get(CONF_RADIO_TYPE)

radio_details = RADIO_TYPES[radio_type][RADIO]()
radio = radio_details[RADIO]
Expand Down Expand Up @@ -147,11 +150,10 @@ async def _async_remove_device(self, device, entity_refs):
for entity_ref in entity_refs:
remove_tasks.append(entity_ref.remove_future)
await asyncio.wait(remove_tasks)
ha_device_registry = await get_dev_reg(self._hass)
reg_device = ha_device_registry.async_get_device(
reg_device = self.ha_device_registry.async_get_device(
{(DOMAIN, str(device.ieee))}, set())
if reg_device is not None:
ha_device_registry.async_remove_device(reg_device.id)
self.ha_device_registry.async_remove_device(reg_device.id)

def device_removed(self, device):
"""Handle device being removed from the network."""
Expand Down Expand Up @@ -241,6 +243,14 @@ def _async_get_or_create_device(self, zigpy_device, is_new_join):
if zha_device is None:
zha_device = ZHADevice(self._hass, zigpy_device, self)
self._devices[zigpy_device.ieee] = zha_device
self.ha_device_registry.async_get_or_create(
config_entry_id=self._config_entry.entry_id,
connections={(CONNECTION_ZIGBEE, str(zha_device.ieee))},
identifiers={(DOMAIN, str(zha_device.ieee))},
name=zha_device.name,
manufacturer=zha_device.manufacturer,
model=zha_device.model
)
if not is_new_join:
entry = self.zha_storage.async_get_or_create(zha_device)
zha_device.async_update_last_seen(entry.last_seen)
Expand Down Expand Up @@ -322,7 +332,11 @@ async def async_device_initialized(self, device, is_new_join):
)

if is_new_join:
device_info = async_get_device_info(self._hass, zha_device)
device_info = async_get_device_info(
self._hass,
zha_device,
self.ha_device_registry
)
async_dispatcher_send(
self._hass,
ZHA_GW_MSG,
Expand Down
8 changes: 6 additions & 2 deletions tests/components/zha/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from homeassistant.components.zha.core.const import (
DOMAIN, DATA_ZHA, COMPONENTS
)
from homeassistant.helpers.device_registry import (
async_get_registry as get_dev_reg)
from homeassistant.components.zha.core.gateway import ZHAGateway
from homeassistant.components.zha.core.registries import \
establish_device_mappings
Expand All @@ -24,7 +26,7 @@ def config_entry_fixture(hass):


@pytest.fixture(name='zha_gateway')
async def zha_gateway_fixture(hass):
async def zha_gateway_fixture(hass, config_entry):
"""Fixture representing a zha gateway.
Create a ZHAGateway object that can be used to interact with as if we
Expand All @@ -37,8 +39,10 @@ async def zha_gateway_fixture(hass):
hass.data[DATA_ZHA].get(component, {})
)
zha_storage = await async_get_registry(hass)
gateway = ZHAGateway(hass, {})
dev_reg = await get_dev_reg(hass)
gateway = ZHAGateway(hass, {}, config_entry)
gateway.zha_storage = zha_storage
gateway.ha_device_registry = dev_reg
return gateway


Expand Down

0 comments on commit ac91423

Please sign in to comment.