Skip to content

Commit

Permalink
Expose async_scanner_devices_by_address from the bluetooth api (home-…
Browse files Browse the repository at this point in the history
…assistant#83733)

Co-authored-by: J. Nick Koston <[email protected]>
fixes undefined
  • Loading branch information
dbuezas authored Jan 9, 2023
1 parent 06a35fb commit 112b2c2
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 46 deletions.
5 changes: 4 additions & 1 deletion homeassistant/components/bluetooth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@
async_register_scanner,
async_scanner_by_source,
async_scanner_count,
async_scanner_devices_by_address,
async_track_unavailable,
)
from .base_scanner import BaseHaRemoteScanner, BaseHaScanner
from .base_scanner import BaseHaRemoteScanner, BaseHaScanner, BluetoothScannerDevice
from .const import (
BLUETOOTH_DISCOVERY_COOLDOWN_SECONDS,
CONF_ADAPTER,
Expand Down Expand Up @@ -99,6 +100,7 @@
"async_track_unavailable",
"async_scanner_by_source",
"async_scanner_count",
"async_scanner_devices_by_address",
"BaseHaScanner",
"BaseHaRemoteScanner",
"BluetoothCallbackMatcher",
Expand All @@ -107,6 +109,7 @@
"BluetoothServiceInfoBleak",
"BluetoothScanningMode",
"BluetoothCallback",
"BluetoothScannerDevice",
"HaBluetoothConnector",
"SOURCE_LOCAL",
"FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS",
Expand Down
10 changes: 9 additions & 1 deletion homeassistant/components/bluetooth/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback

from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import DATA_MANAGER
from .manager import BluetoothManager
from .match import BluetoothCallbackMatcher
Expand Down Expand Up @@ -93,6 +93,14 @@ def async_ble_device_from_address(
return _get_manager(hass).async_ble_device_from_address(address, connectable)


@hass_callback
def async_scanner_devices_by_address(
hass: HomeAssistant, address: str, connectable: bool = True
) -> list[BluetoothScannerDevice]:
"""Return all discovered BluetoothScannerDevice for an address."""
return _get_manager(hass).async_scanner_devices_by_address(address, connectable)


@hass_callback
def async_address_present(
hass: HomeAssistant, address: str, connectable: bool = True
Expand Down
10 changes: 10 additions & 0 deletions homeassistant/components/bluetooth/base_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
import datetime
from datetime import timedelta
import logging
Expand Down Expand Up @@ -39,6 +40,15 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class BluetoothScannerDevice:
"""Data for a bluetooth device from a given scanner."""

scanner: BaseHaScanner
ble_device: BLEDevice
advertisement: AdvertisementData


class BaseHaScanner(ABC):
"""Base class for Ha Scanners."""

Expand Down
28 changes: 16 additions & 12 deletions homeassistant/components/bluetooth/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from homeassistant.util.dt import monotonic_time_coarse

from .advertisement_tracker import AdvertisementTracker
from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice
from .const import (
FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS,
UNAVAILABLE_TRACK_SECONDS,
Expand Down Expand Up @@ -217,18 +217,22 @@ def async_stop(self, event: Event) -> None:
uninstall_multiple_bleak_catcher()

@hass_callback
def async_get_scanner_discovered_devices_and_advertisement_data_by_address(
def async_scanner_devices_by_address(
self, address: str, connectable: bool
) -> list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]]:
"""Get scanner, devices, and advertisement_data by address."""
types_ = (True,) if connectable else (True, False)
results: list[tuple[BaseHaScanner, BLEDevice, AdvertisementData]] = []
for type_ in types_:
for scanner in self._get_scanners_by_type(type_):
devices_and_adv_data = scanner.discovered_devices_and_advertisement_data
if device_adv_data := devices_and_adv_data.get(address):
results.append((scanner, *device_adv_data))
return results
) -> list[BluetoothScannerDevice]:
"""Get BluetoothScannerDevice by address."""
scanners = self._get_scanners_by_type(True)
if not connectable:
scanners.extend(self._get_scanners_by_type(False))
return [
BluetoothScannerDevice(scanner, *device_adv)
for scanner in scanners
if (
device_adv := scanner.discovered_devices_and_advertisement_data.get(
address
)
)
]

@hass_callback
def _async_all_discovered_addresses(self, connectable: bool) -> Iterable[str]:
Expand Down
46 changes: 16 additions & 30 deletions homeassistant/components/bluetooth/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
from bleak import BleakClient, BleakError
from bleak.backends.client import BaseBleakClient, get_platform_client_backend_type
from bleak.backends.device import BLEDevice
from bleak.backends.scanner import (
AdvertisementData,
AdvertisementDataCallback,
BaseBleakScanner,
)
from bleak.backends.scanner import AdvertisementDataCallback, BaseBleakScanner
from bleak_retry_connector import (
NO_RSSI_VALUE,
ble_device_description,
Expand All @@ -28,7 +24,7 @@
from homeassistant.helpers.frame import report

from . import models
from .base_scanner import BaseHaScanner
from .base_scanner import BaseHaScanner, BluetoothScannerDevice

FILTER_UUIDS: Final = "UUIDs"
_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -149,9 +145,7 @@ def __del__(self) -> None:


def _rssi_sorter_with_connection_failure_penalty(
scanner_device_advertisement_data: tuple[
BaseHaScanner, BLEDevice, AdvertisementData
],
device: BluetoothScannerDevice,
connection_failure_count: dict[BaseHaScanner, int],
rssi_diff: int,
) -> float:
Expand All @@ -168,9 +162,8 @@ def _rssi_sorter_with_connection_failure_penalty(
best adapter twice before moving on to the next best adapter since
the first failure may be a transient service resolution issue.
"""
scanner, _, advertisement_data = scanner_device_advertisement_data
base_rssi = advertisement_data.rssi or NO_RSSI_VALUE
if connect_failures := connection_failure_count.get(scanner):
base_rssi = device.advertisement.rssi or NO_RSSI_VALUE
if connect_failures := connection_failure_count.get(device.scanner):
if connect_failures > 1 and not rssi_diff:
rssi_diff = 1
return base_rssi - (rssi_diff * connect_failures * 0.51)
Expand Down Expand Up @@ -300,46 +293,39 @@ def _async_get_best_available_backend_and_device(
that has a free connection slot.
"""
address = self.__address
scanner_device_advertisement_datas = manager.async_get_scanner_discovered_devices_and_advertisement_data_by_address( # noqa: E501
address, True
)
sorted_scanner_device_advertisement_datas = sorted(
scanner_device_advertisement_datas,
key=lambda scanner_device_advertisement_data: (
scanner_device_advertisement_data[2].rssi or NO_RSSI_VALUE
),
devices = manager.async_scanner_devices_by_address(self.__address, True)
sorted_devices = sorted(
devices,
key=lambda device: device.advertisement.rssi or NO_RSSI_VALUE,
reverse=True,
)

# If we have connection failures we adjust the rssi sorting
# to prefer the adapter/scanner with the less failures so
# we don't keep trying to connect with an adapter
# that is failing
if (
self.__connect_failures
and len(sorted_scanner_device_advertisement_datas) > 1
):
if self.__connect_failures and len(sorted_devices) > 1:
# We use the rssi diff between to the top two
# to adjust the rssi sorter so that each failure
# will reduce the rssi sorter by the diff amount
rssi_diff = (
sorted_scanner_device_advertisement_datas[0][2].rssi
- sorted_scanner_device_advertisement_datas[1][2].rssi
sorted_devices[0].advertisement.rssi
- sorted_devices[1].advertisement.rssi
)
adjusted_rssi_sorter = partial(
_rssi_sorter_with_connection_failure_penalty,
connection_failure_count=self.__connect_failures,
rssi_diff=rssi_diff,
)
sorted_scanner_device_advertisement_datas = sorted(
scanner_device_advertisement_datas,
sorted_devices = sorted(
devices,
key=adjusted_rssi_sorter,
reverse=True,
)

for (scanner, ble_device, _) in sorted_scanner_device_advertisement_datas:
for device in sorted_devices:
if backend := self._async_get_backend_for_ble_device(
manager, scanner, ble_device
manager, device.scanner, device.ble_device
):
return backend

Expand Down
125 changes: 123 additions & 2 deletions tests/components/bluetooth/test_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
"""Tests for the Bluetooth integration API."""


from bleak.backends.scanner import AdvertisementData, BLEDevice

from homeassistant.components import bluetooth
from homeassistant.components.bluetooth import async_scanner_by_source
from homeassistant.components.bluetooth import (
BaseHaRemoteScanner,
BaseHaScanner,
HaBluetoothConnector,
async_scanner_by_source,
async_scanner_devices_by_address,
)

from . import FakeScanner
from . import FakeScanner, MockBleakClient, _get_manager, generate_advertisement_data


async def test_scanner_by_source(hass, enable_bluetooth):
Expand All @@ -16,3 +24,116 @@ async def test_scanner_by_source(hass, enable_bluetooth):
assert async_scanner_by_source(hass, "hci2") is hci2_scanner
cancel_hci2()
assert async_scanner_by_source(hass, "hci2") is None


async def test_async_scanner_devices_by_address_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with connectable devices."""
manager = _get_manager()

class FakeInjectableScanner(BaseHaRemoteScanner):
def inject_advertisement(
self, device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Inject an advertisement."""
self._async_on_advertisement(
device.address,
advertisement_data.rssi,
device.name,
advertisement_data.service_uuids,
advertisement_data.service_data,
advertisement_data.manufacturer_data,
advertisement_data.tx_power,
{"scanner_specific_data": "test"},
)

new_info_callback = manager.scanner_adv_received
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeInjectableScanner(
hass, "esp32", "esp32", new_info_callback, connector, False
)
unsetup = scanner.async_setup()
cancel = manager.async_register_scanner(scanner, True)
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
scanner.inject_advertisement(switchbot_device, switchbot_device_adv)
assert async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
) == async_scanner_devices_by_address(hass, "44:44:33:11:23:45", connectable=False)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
unsetup()
cancel()


async def test_async_scanner_devices_by_address_non_connectable(hass, enable_bluetooth):
"""Test getting scanner devices by address with non-connectable devices."""
manager = _get_manager()
switchbot_device = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)

class FakeStaticScanner(BaseHaScanner):
@property
def discovered_devices(self) -> list[BLEDevice]:
"""Return a list of discovered devices."""
return [switchbot_device]

@property
def discovered_devices_and_advertisement_data(
self,
) -> dict[str, tuple[BLEDevice, AdvertisementData]]:
"""Return a list of discovered devices and their advertisement data."""
return {switchbot_device.address: (switchbot_device, switchbot_device_adv)}

connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
scanner = FakeStaticScanner(hass, "esp32", "esp32", connector)
cancel = manager.async_register_scanner(scanner, False)

assert scanner.discovered_devices_and_advertisement_data == {
switchbot_device.address: (switchbot_device, switchbot_device_adv)
}

assert (
async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=True
)
== []
)
devices = async_scanner_devices_by_address(
hass, switchbot_device.address, connectable=False
)
assert len(devices) == 1
assert devices[0].scanner == scanner
assert devices[0].ble_device.name == switchbot_device.name
assert devices[0].advertisement.local_name == switchbot_device_adv.local_name
cancel()

0 comments on commit 112b2c2

Please sign in to comment.