Skip to content

Commit

Permalink
Add ComponentProtocol to improve type checking (home-assistant#90586)
Browse files Browse the repository at this point in the history
  • Loading branch information
epenet authored Mar 31, 2023
1 parent 03137fe commit 611d413
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 17 deletions.
4 changes: 2 additions & 2 deletions homeassistant/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
)
from .helpers.entity_values import EntityValues
from .helpers.typing import ConfigType
from .loader import Integration, IntegrationNotFound
from .loader import ComponentProtocol, Integration, IntegrationNotFound
from .requirements import RequirementsNotFound, async_get_integration_with_requirements
from .util.package import is_docker_env
from .util.unit_system import get_unit_system, validate_unit_system
Expand Down Expand Up @@ -681,7 +681,7 @@ def _log_pkg_error(package: str, component: str, config: dict, message: str) ->
_LOGGER.error(message)


def _identify_config_schema(module: ModuleType) -> str | None:
def _identify_config_schema(module: ComponentProtocol) -> str | None:
"""Extract the schema and identify list or dict based."""
if not isinstance(module.CONFIG_SCHEMA, vol.Schema):
return None
Expand Down
10 changes: 4 additions & 6 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ async def async_setup(
result = await component.async_setup_entry(hass, self)

if not isinstance(result, bool):
_LOGGER.error(
_LOGGER.error( # type: ignore[unreachable]
"%s.async_setup_entry did not return boolean", integration.domain
)
result = False
Expand Down Expand Up @@ -546,8 +546,7 @@ async def async_unload(

await self._async_process_on_unload()

# https://github.com/python/mypy/issues/11839
return result # type: ignore[no-any-return]
return result
except Exception as ex: # pylint: disable=broad-except
_LOGGER.exception(
"Error unloading entry %s for %s", self.title, integration.domain
Expand Down Expand Up @@ -628,15 +627,14 @@ async def async_migrate(self, hass: HomeAssistant) -> bool:
try:
result = await component.async_migrate_entry(hass, self)
if not isinstance(result, bool):
_LOGGER.error(
_LOGGER.error( # type: ignore[unreachable]
"%s.async_migrate_entry did not return boolean", self.domain
)
return False
if result:
# pylint: disable-next=protected-access
hass.config_entries._async_schedule_save()
# https://github.com/python/mypy/issues/11839
return result # type: ignore[no-any-return]
return result
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
"Error migrating entry %s for %s", self.title, self.domain
Expand Down
70 changes: 62 additions & 8 deletions homeassistant/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import pathlib
import sys
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, TypeVar, cast

from awesomeversion import (
AwesomeVersion,
AwesomeVersionException,
AwesomeVersionStrategy,
)
import voluptuous as vol

from . import generated
from .generated.application_credentials import APPLICATION_CREDENTIALS
Expand All @@ -35,7 +36,10 @@

# Typing imports that create a circular dependency
if TYPE_CHECKING:
from .config_entries import ConfigEntry
from .core import HomeAssistant
from .helpers import device_registry as dr
from .helpers.typing import ConfigType

_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])

Expand Down Expand Up @@ -260,6 +264,52 @@ async def async_get_config_flows(
return flows


class ComponentProtocol(Protocol):
"""Define the format of an integration."""

CONFIG_SCHEMA: vol.Schema
DOMAIN: str

async def async_setup_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up a config entry."""

async def async_unload_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload a config entry."""

async def async_migrate_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Migrate an old config entry."""

async def async_remove_entry(
self, hass: HomeAssistant, config_entry: ConfigEntry
) -> None:
"""Remove a config entry."""

async def async_remove_config_entry_device(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
device_entry: dr.DeviceEntry,
) -> bool:
"""Remove a config entry device."""

async def async_reset_platform(
self, hass: HomeAssistant, integration_name: str
) -> None:
"""Release resources."""

async def async_setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""

def setup(self, hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up integration."""


async def async_get_integration_descriptions(
hass: HomeAssistant,
) -> dict[str, Any]:
Expand Down Expand Up @@ -750,14 +800,18 @@ async def resolve_dependencies(self) -> bool:

return self._all_dependencies_resolved

def get_component(self) -> ModuleType:
def get_component(self) -> ComponentProtocol:
"""Return the component."""
cache: dict[str, ModuleType] = self.hass.data.setdefault(DATA_COMPONENTS, {})
cache: dict[str, ComponentProtocol] = self.hass.data.setdefault(
DATA_COMPONENTS, {}
)
if self.domain in cache:
return cache[self.domain]

try:
cache[self.domain] = importlib.import_module(self.pkg_path)
cache[self.domain] = cast(
ComponentProtocol, importlib.import_module(self.pkg_path)
)
except ImportError:
raise
except Exception as err:
Expand Down Expand Up @@ -922,7 +976,7 @@ def __init__(self, from_domain: str, to_domain: str) -> None:

def _load_file(
hass: HomeAssistant, comp_or_platform: str, base_paths: list[str]
) -> ModuleType | None:
) -> ComponentProtocol | None:
"""Try to load specified file.
Looks in config dir first, then built-in components.
Expand Down Expand Up @@ -957,7 +1011,7 @@ def _load_file(

cache[comp_or_platform] = module

return module
return cast(ComponentProtocol, module)

except ImportError as err:
# This error happens if for example custom_components/switch
Expand All @@ -981,7 +1035,7 @@ def _load_file(
class ModuleWrapper:
"""Class to wrap a Python module and auto fill in hass argument."""

def __init__(self, hass: HomeAssistant, module: ModuleType) -> None:
def __init__(self, hass: HomeAssistant, module: ComponentProtocol) -> None:
"""Initialize the module wrapper."""
self._hass = hass
self._module = module
Expand Down Expand Up @@ -1010,7 +1064,7 @@ def __getattr__(self, comp_name: str) -> ModuleWrapper:
integration = self._hass.data.get(DATA_INTEGRATIONS, {}).get(comp_name)

if isinstance(integration, Integration):
component: ModuleType | None = integration.get_component()
component: ComponentProtocol | None = integration.get_component()
else:
# Fallback to importing old-school
component = _load_file(self._hass, comp_name, _lookup_path(self._hass))
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def log_error(msg: str) -> None:
SLOW_SETUP_WARNING,
)

task = None
task: Awaitable[bool] | None = None
result: Any | bool = True
try:
if hasattr(component, "async_setup"):
Expand Down
8 changes: 8 additions & 0 deletions pylint/plugins/hass_enforce_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ class ClassTypeHintMatch:
},
return_type="bool",
),
TypeHintMatch(
function_name="async_reset_platform",
arg_types={
0: "HomeAssistant",
1: "str",
},
return_type=None,
),
],
"__any_platform__": [
TypeHintMatch(
Expand Down

0 comments on commit 611d413

Please sign in to comment.