forked from home-assistant/core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon.py
1208 lines (941 loc) · 37.2 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Test the helper method for writing tests."""
from __future__ import annotations
import asyncio
import collections
from collections import OrderedDict
from contextlib import contextmanager
from datetime import datetime, timedelta
import functools as ft
from io import StringIO
import json
import logging
import os
import pathlib
import threading
import time
from time import monotonic
import types
from typing import Any, Awaitable, Collection
from unittest.mock import AsyncMock, Mock, patch
from aiohttp.test_utils import unused_port as get_test_instance_port # noqa: F401
from homeassistant import auth, config_entries, core as ha, loader
from homeassistant.auth import (
auth_store,
models as auth_models,
permissions as auth_permissions,
providers as auth_providers,
)
from homeassistant.auth.permissions import system_policies
from homeassistant.components import device_automation, recorder
from homeassistant.components.device_automation import ( # noqa: F401
_async_get_device_automation_capabilities as async_get_device_automation_capabilities,
)
from homeassistant.components.mqtt.models import ReceiveMessage
from homeassistant.config import async_process_component_config
from homeassistant.const import (
DEVICE_DEFAULT_NAME,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED,
STATE_OFF,
STATE_ON,
)
from homeassistant.core import BLOCK_LOG_TIMEOUT, HomeAssistant, State
from homeassistant.helpers import (
area_registry,
device_registry,
entity,
entity_platform,
entity_registry,
intent,
restore_state,
storage,
)
from homeassistant.helpers.json import JSONEncoder
from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util.async_ import run_callback_threadsafe
import homeassistant.util.dt as date_util
from homeassistant.util.unit_system import METRIC_SYSTEM
import homeassistant.util.uuid as uuid_util
import homeassistant.util.yaml.loader as yaml_loader
# Module-level logger used by the helpers in this file.
_LOGGER = logging.getLogger(__name__)
# Registry of live test instances: async_test_home_assistant appends every
# HomeAssistant object it builds and removes it again once
# EVENT_HOMEASSISTANT_CLOSE fires for that instance.
INSTANCES = []
# Example client id / redirect URI pair.
# NOTE(review): not referenced in the visible part of this file; presumably
# imported by auth-related tests - confirm before changing.
CLIENT_ID = "https://example.com/app"
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
async def async_get_device_automations(
    hass: HomeAssistant, automation_type: str, device_id: str
) -> Any:
    """Return the automations of one type for a single device.

    The device_automation component is queried with a one-element id list and
    the entry stored under ``device_id`` is returned (``None`` if that key is
    absent from the result).
    """
    by_device = await device_automation.async_get_device_automations(
        hass, automation_type, [device_id]
    )
    return by_device.get(device_id)
def threadsafe_callback_factory(func):
    """Build a wrapper that runs a callback through `hass.loop`.

    The callback must receive `hass` as its first positional argument. The
    wrapper hands the bound call to `run_callback_threadsafe` and returns the
    `.result()` of the future it gets back.
    """

    @ft.wraps(func)
    def threadsafe(*args, **kwargs):
        """Call func threadsafe."""
        hass = args[0]
        bound_call = ft.partial(func, *args, **kwargs)
        future = run_callback_threadsafe(hass.loop, bound_call)
        return future.result()

    return threadsafe
def threadsafe_coroutine_factory(func):
    """Build a blocking wrapper around a coroutine function.

    The coroutine function must receive `hass` as its first positional
    argument. The wrapper submits the coroutine to `hass.loop` with
    `asyncio.run_coroutine_threadsafe` and waits for its result.
    """

    @ft.wraps(func)
    def threadsafe(*args, **kwargs):
        """Call func threadsafe."""
        hass = args[0]
        coroutine = func(*args, **kwargs)
        future = asyncio.run_coroutine_threadsafe(coroutine, hass.loop)
        return future.result()

    return threadsafe
def get_test_config_dir(*add_path):
    """Build a path inside the ``testing_config`` directory next to this file.

    Any extra path components given are appended below ``testing_config``.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, "testing_config", *add_path)
def get_test_home_assistant():
    """Create a test Home Assistant instance driven by a background loop thread.

    A fresh event loop is created, used to build the instance through
    async_test_home_assistant, and then run forever on a thread named
    "LoopThread". `hass.start` and `hass.stop` are replaced with blocking
    versions suitable for synchronous tests.
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    hass = event_loop.run_until_complete(async_test_home_assistant(event_loop))

    loop_finished = threading.Event()
    original_stop = hass.stop
    # Swap hass._stopped for a Mock whose set() stops the event loop.
    hass._stopped = Mock(set=event_loop.stop)

    def run_loop():
        """Run event loop."""
        # pylint: disable=protected-access
        event_loop._thread_ident = threading.get_ident()
        event_loop.run_forever()
        # run_forever() returned, so the loop is no longer running.
        loop_finished.set()

    def start_hass(*mocks):
        """Start hass."""
        start_future = asyncio.run_coroutine_threadsafe(hass.async_start(), event_loop)
        start_future.result()

    def stop_hass():
        """Stop hass."""
        original_stop()
        # Only close the loop after the loop thread has left run_forever().
        loop_finished.wait()
        event_loop.close()

    hass.start = start_hass
    hass.stop = stop_hass

    loop_thread = threading.Thread(name="LoopThread", target=run_loop, daemon=False)
    loop_thread.start()

    return hass
# pylint: disable=protected-access
async def async_test_home_assistant(loop, load_registries=True):
    """Return a Home Assistant object pointing at test config dir.

    The instance gets an auth manager, Mock-aware replacements for its job and
    task scheduling methods, two extra task-counting helpers, a fixed test
    configuration, an empty config-entries manager and (optionally) loaded
    registries. It is recorded in INSTANCES until it emits
    EVENT_HOMEASSISTANT_CLOSE.

    Args:
        loop: Accepted for callers; not referenced inside this function body.
        load_registries: When true, load the device, entity and area
            registries and wait for pending work before returning.

    Returns:
        The configured HomeAssistant instance with state set to running.
    """
    hass = ha.HomeAssistant()
    store = auth_store.AuthStore(hass)
    hass.auth = auth.AuthManager(hass, store, {}, {})
    ensure_auth_manager_loaded(hass.auth)
    INSTANCES.append(hass)

    # Keep the real implementations so the wrappers below can delegate to them.
    orig_async_add_job = hass.async_add_job
    orig_async_add_executor_job = hass.async_add_executor_job
    orig_async_create_task = hass.async_create_task

    def async_add_job(target, *args):
        """Add job.

        A plain Mock target (possibly wrapped in functools.partial) is called
        immediately and its return value handed back in a completed future;
        anything else goes to the original async_add_job.
        """
        # Unwrap nested partials to find the underlying callable.
        check_target = target
        while isinstance(check_target, ft.partial):
            check_target = check_target.func

        if isinstance(check_target, Mock) and not isinstance(target, AsyncMock):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_job(target, *args)

    def async_add_executor_job(target, *args):
        """Add executor job.

        A Mock target (possibly wrapped in functools.partial) is called
        immediately instead of being sent to the original
        async_add_executor_job.
        """
        # Unwrap nested partials to find the underlying callable.
        check_target = target
        while isinstance(check_target, ft.partial):
            check_target = check_target.func

        if isinstance(check_target, Mock):
            fut = asyncio.Future()
            fut.set_result(target(*args))
            return fut

        return orig_async_add_executor_job(target, *args)

    def async_create_task(coroutine):
        """Create task.

        A plain (non-async) Mock is answered with a future already resolved to
        None; everything else goes to the original async_create_task.
        """
        if isinstance(coroutine, Mock) and not isinstance(coroutine, AsyncMock):
            fut = asyncio.Future()
            fut.set_result(None)
            return fut

        return orig_async_create_task(coroutine)

    async def async_wait_for_task_count(self, max_remaining_tasks: int = 0) -> None:
        """Block until at most max_remaining_tasks remain.

        Based on HomeAssistant.async_block_till_done
        """
        # To flush out any call_soon_threadsafe
        await asyncio.sleep(0)
        # None: not waited yet; 0: waited once; otherwise a monotonic() timestamp.
        start_time: float | None = None

        while len(self._pending_tasks) > max_remaining_tasks:
            # Keep only unfinished tasks; the list is rebuilt from them below.
            pending: Collection[Awaitable[Any]] = [
                task for task in self._pending_tasks if not task.done()
            ]
            self._pending_tasks.clear()
            if len(pending) > max_remaining_tasks:
                remaining_pending = await self._await_count_and_log_pending(
                    pending, max_remaining_tasks=max_remaining_tasks
                )
                self._pending_tasks.extend(remaining_pending)

                if start_time is None:
                    # Avoid calling monotonic() until we know
                    # we may need to start logging blocked tasks.
                    start_time = 0
                elif start_time == 0:
                    # If we have waited twice then we set the start
                    # time
                    start_time = monotonic()
                elif monotonic() - start_time > BLOCK_LOG_TIMEOUT:
                    # We have waited at least three loops and new tasks
                    # continue to block. At this point we start
                    # logging all waiting tasks.
                    for task in pending:
                        _LOGGER.debug("Waiting for task: %s", task)
            else:
                # Few enough unfinished tasks: put them back and yield once.
                self._pending_tasks.extend(pending)
                await asyncio.sleep(0)

    async def _await_count_and_log_pending(
        self, pending: Collection[Awaitable[Any]], max_remaining_tasks: int = 0
    ) -> Collection[Awaitable[Any]]:
        """Block until at most max_remaining_tasks remain and log slow tasks.

        Based on HomeAssistant._await_and_log_pending

        Returns the tasks still pending after the wait, or an empty list when
        the loop ends without an early return.
        """
        wait_time = 0

        # Wait for everything by default; with a non-zero allowance, wake up
        # as soon as any single task finishes.
        return_when = asyncio.ALL_COMPLETED
        if max_remaining_tasks:
            return_when = asyncio.FIRST_COMPLETED

        while len(pending) > max_remaining_tasks:
            _, pending = await asyncio.wait(
                pending, timeout=BLOCK_LOG_TIMEOUT, return_when=return_when
            )
            # With a non-zero allowance a single wait() is all that is done
            # here; the caller re-checks the count and calls again if needed.
            if not pending or max_remaining_tasks:
                return pending
            wait_time += BLOCK_LOG_TIMEOUT

            for task in pending:
                _LOGGER.debug("Waited %s seconds for task: %s", wait_time, task)

        return []

    # Install the wrappers; the two coroutine helpers are bound as methods so
    # they receive hass as `self`.
    hass.async_add_job = async_add_job
    hass.async_add_executor_job = async_add_executor_job
    hass.async_create_task = async_create_task
    hass.async_wait_for_task_count = types.MethodType(async_wait_for_task_count, hass)
    hass._await_count_and_log_pending = types.MethodType(
        _await_count_and_log_pending, hass
    )

    hass.data[loader.DATA_CUSTOM_COMPONENTS] = {}

    # Fixed test configuration.
    hass.config.location_name = "test home"
    hass.config.config_dir = get_test_config_dir()
    hass.config.latitude = 32.87336
    hass.config.longitude = -117.22743
    hass.config.elevation = 0
    hass.config.time_zone = "US/Pacific"
    hass.config.units = METRIC_SYSTEM
    hass.config.media_dirs = {"local": get_test_config_dir("media")}
    hass.config.skip_pip = True

    # Start with an empty config-entries manager and replace the store's
    # stop-listener hook with a no-op.
    hass.config_entries = config_entries.ConfigEntries(hass, {})
    hass.config_entries._entries = {}
    hass.config_entries._store._async_ensure_stop_listener = lambda: None

    # Load the registries
    if load_registries:
        await asyncio.gather(
            device_registry.async_load(hass),
            entity_registry.async_load(hass),
            area_registry.async_load(hass),
        )
        await hass.async_block_till_done()

    hass.state = ha.CoreState.running

    # Mock async_start
    orig_start = hass.async_start

    async def mock_async_start():
        """Start the mocking."""
        # We only mock time during tests and we want to track tasks
        with patch("homeassistant.core._async_create_timer"), patch.object(
            hass, "async_stop_track_tasks"
        ):
            await orig_start()

    hass.async_start = mock_async_start

    @ha.callback
    def clear_instance(event):
        """Clear global instance."""
        INSTANCES.remove(hass)

    # Drop this instance from INSTANCES once it announces it is closing.
    hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance)

    return hass
def async_mock_service(hass, domain, service, schema=None):
    """Register a stub service and return the list that records its calls.

    Every call the stub receives is appended to the returned list.
    """
    recorded_calls = []

    @ha.callback
    def mock_service_log(call):  # pylint: disable=unnecessary-lambda
        """Mock service call."""
        recorded_calls.append(call)

    hass.services.async_register(domain, service, mock_service_log, schema=schema)
    return recorded_calls


# Blocking variant routed through hass.loop by threadsafe_callback_factory.
mock_service = threadsafe_callback_factory(async_mock_service)
@ha.callback
def async_mock_intent(hass, intent_typ):
    """Register a stub intent handler and return the list it records into.

    The handler is registered for ``intent_typ``; each handled intent is
    appended to the returned list.
    """
    handled = []

    class MockIntentHandler(intent.IntentHandler):
        intent_type = intent_typ

        async def async_handle(self, intent_obj):
            """Record the intent and answer with its create_response()."""
            handled.append(intent_obj)
            return intent_obj.create_response()

    intent.async_register(hass, MockIntentHandler())
    return handled
@ha.callback
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
    """Feed a message into the MQTT object stored in ``hass.data["mqtt"]``.

    String payloads are encoded as UTF-8; other payloads are passed unchanged.
    """
    raw_payload = payload.encode("utf-8") if isinstance(payload, str) else payload
    message = ReceiveMessage(topic, raw_payload, qos, retain)
    hass.data["mqtt"]._mqtt_handle_message(message)


# Blocking variant routed through hass.loop by threadsafe_callback_factory.
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
@ha.callback
def async_fire_time_changed(
    hass: HomeAssistant, datetime_: datetime, fire_all: bool = False
) -> None:
    """Fire a time changed event and run the loop timers that are due.

    After firing EVENT_TIME_CHANGED, every non-cancelled TimerHandle scheduled
    on ``hass.loop`` is run and cancelled if ``fire_all`` is set or if the
    handle is due no later than ``datetime_`` (compared as offsets from now).
    """
    utc_now = date_util.as_utc(datetime_)
    hass.bus.async_fire(EVENT_TIME_CHANGED, {"now": utc_now})

    # Iterate over a snapshot: running a handle may schedule or cancel others.
    for handle in list(hass.loop._scheduled):
        if not isinstance(handle, asyncio.TimerHandle) or handle.cancelled():
            continue

        mock_seconds_into_future = datetime_.timestamp() - time.time()
        future_seconds = handle.when() - hass.loop.time()
        if not (fire_all or mock_seconds_into_future >= future_seconds):
            continue

        with patch(
            "homeassistant.helpers.event.time_tracker_utcnow",
            return_value=utc_now,
        ):
            handle._run()
            handle.cancel()


# Blocking variant routed through hass.loop by threadsafe_callback_factory.
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
def load_fixture(filename):
    """Return the text of a file in the ``fixtures`` directory next to this file.

    The file is read as UTF-8.
    """
    fixture_path = os.path.join(os.path.dirname(__file__), "fixtures", filename)
    with open(fixture_path, encoding="utf-8") as fixture_file:
        contents = fixture_file.read()
    return contents
def mock_state_change_event(hass, new_state, old_state=None):
    """Fire a state changed event for ``new_state`` on the hass bus.

    ``old_state`` is only included in the event data when it is truthy. The
    event carries the context of ``new_state``.
    """
    payload = {"entity_id": new_state.entity_id, "new_state": new_state}
    if old_state:
        payload.update(old_state=old_state)

    hass.bus.fire(EVENT_STATE_CHANGED, payload, context=new_state.context)
@ha.callback
def mock_component(hass, component):
    """Mark a component as set up without running its setup code.

    Args:
        hass: Instance whose ``config.components`` set is updated.
        component: Name of the integration to mark as set up.

    Raises:
        AssertionError: If the component is already in
            ``hass.config.components``.
    """
    if component in hass.config.components:
        # The exception used to be constructed but never raised, so marking a
        # component twice went unnoticed.
        raise AssertionError(f"Integration {component} is already setup")

    hass.config.components.add(component)
def mock_registry(hass, mock_entries=None):
    """Install a fresh Entity Registry in ``hass.data`` and return it.

    The registry holds ``mock_entries`` (an empty OrderedDict when none or an
    empty mapping is given) and its index is rebuilt before it is stored.
    """
    entries = mock_entries or OrderedDict()

    fake_registry = entity_registry.EntityRegistry(hass)
    fake_registry.entities = entries
    fake_registry._rebuild_index()

    hass.data[entity_registry.DATA_REGISTRY] = fake_registry
    return fake_registry
def mock_area_registry(hass, mock_entries=None):
    """Install a fresh Area Registry in ``hass.data`` and return it.

    The registry holds ``mock_entries`` (an empty OrderedDict when none or an
    empty mapping is given).
    """
    areas = mock_entries or OrderedDict()

    fake_registry = area_registry.AreaRegistry(hass)
    fake_registry.areas = areas

    hass.data[area_registry.DATA_REGISTRY] = fake_registry
    return fake_registry
def mock_device_registry(hass, mock_entries=None, mock_deleted_entries=None):
    """Install a fresh Device Registry in ``hass.data`` and return it.

    ``mock_entries`` and ``mock_deleted_entries`` become the registry's
    devices and deleted devices (each defaults to an empty OrderedDict), and
    the index is rebuilt before the registry is stored.
    """
    devices = mock_entries or OrderedDict()
    deleted = mock_deleted_entries or OrderedDict()

    fake_registry = device_registry.DeviceRegistry(hass)
    fake_registry.devices = devices
    fake_registry.deleted_devices = deleted
    fake_registry._rebuild_index()

    hass.data[device_registry.DATA_REGISTRY] = fake_registry
    return fake_registry
class MockGroup(auth_models.Group):
    """Auth group with test-friendly defaults."""

    def __init__(self, id=None, name="Mock Group", policy=system_policies.ADMIN_POLICY):
        """Create the group; ``id`` is only forwarded when explicitly given."""
        optional = {} if id is None else {"id": id}
        super().__init__(name=name, policy=policy, **optional)

    def add_to_hass(self, hass):
        """Store this group in the auth manager of ``hass`` and return it."""
        return self.add_to_auth_manager(hass.auth)

    def add_to_auth_manager(self, auth_mgr):
        """Store this group in the given auth manager's store and return it."""
        ensure_auth_manager_loaded(auth_mgr)
        groups = auth_mgr._store._groups
        groups[self.id] = self
        return self
class MockUser(auth_models.User):
    """Auth user with test-friendly defaults."""

    def __init__(
        self,
        id=None,
        is_owner=False,
        is_active=True,
        name="Mock User",
        system_generated=False,
        groups=None,
    ):
        """Create the user; ``id`` is only forwarded when explicitly given."""
        optional = {} if id is None else {"id": id}
        super().__init__(
            is_owner=is_owner,
            is_active=is_active,
            name=name,
            system_generated=system_generated,
            groups=groups or [],
            perm_lookup=None,
            **optional,
        )

    def add_to_hass(self, hass):
        """Store this user in the auth manager of ``hass`` and return it."""
        return self.add_to_auth_manager(hass.auth)

    def add_to_auth_manager(self, auth_mgr):
        """Store this user in the given auth manager's store and return it."""
        ensure_auth_manager_loaded(auth_mgr)
        users = auth_mgr._store._users
        users[self.id] = self
        return self

    def mock_policy(self, policy):
        """Replace the user's permissions with ones built from ``policy``."""
        permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
        self._permissions = permissions
async def register_auth_provider(hass, config):
    """Create an auth provider from ``config`` and add it to ``hass.auth``.

    Raises ValueError when a provider with the same (type, id) key is already
    registered. Returns the new provider.
    """
    provider = await auth_providers.auth_provider_from_config(
        hass, hass.auth._store, config
    )
    assert provider is not None, "Invalid config specified"

    registered = hass.auth._providers
    provider_key = (provider.type, provider.id)
    if provider_key in registered:
        raise ValueError("Provider already registered")

    registered[provider_key] = provider
    return provider
@ha.callback
def ensure_auth_manager_loaded(auth_mgr):
    """Give the auth manager's store its defaults if it has no users yet.

    Nothing happens when the store's ``_users`` is already something other
    than ``None``.
    """
    if auth_mgr._store._users is not None:
        return
    auth_mgr._store._set_defaults()
class MockModule:
    """Stand-in object that mimics an integration module."""

    # pylint: disable=invalid-name
    def __init__(
        self,
        domain=None,
        dependencies=None,
        setup=None,
        requirements=None,
        config_schema=None,
        platform_schema=None,
        platform_schema_base=None,
        async_setup=None,
        async_setup_entry=None,
        async_unload_entry=None,
        async_migrate_entry=None,
        async_remove_entry=None,
        partial_manifest=None,
    ):
        """Populate only the module attributes that were actually supplied.

        When neither ``setup`` nor ``async_setup`` is given, ``async_setup``
        defaults to an AsyncMock returning True.
        """
        self.__name__ = f"homeassistant.components.{domain}"
        self.__file__ = f"homeassistant/components/{domain}"

        self.DOMAIN = domain
        self.DEPENDENCIES = dependencies or []
        self.REQUIREMENTS = requirements or []

        # Overlay to be used when generating manifest from this module
        self._partial_manifest = partial_manifest

        # Schemas are only defined on the module when provided.
        for attr_name, schema in (
            ("CONFIG_SCHEMA", config_schema),
            ("PLATFORM_SCHEMA", platform_schema),
            ("PLATFORM_SCHEMA_BASE", platform_schema_base),
        ):
            if schema is not None:
                setattr(self, attr_name, schema)

        if setup:
            # We run this in executor, wrap it in function
            self.setup = lambda *args: setup(*args)

        if async_setup is not None:
            self.async_setup = async_setup
        elif setup is None:
            self.async_setup = AsyncMock(return_value=True)

        # Config-entry hooks are only defined on the module when provided.
        for attr_name, hook in (
            ("async_setup_entry", async_setup_entry),
            ("async_unload_entry", async_unload_entry),
            ("async_migrate_entry", async_migrate_entry),
            ("async_remove_entry", async_remove_entry),
        ):
            if hook is not None:
                setattr(self, attr_name, hook)

    def mock_manifest(self):
        """Build a manifest dict for this module.

        Starts from the legacy-module manifest and overlays the partial
        manifest passed to the constructor, if any.
        """
        manifest = dict(loader.manifest_from_legacy_module(self.DOMAIN, self))
        manifest.update(self._partial_manifest or {})
        return manifest
class MockPlatform:
    """Stand-in object that mimics a platform module."""

    __name__ = "homeassistant.components.light.bla"
    __file__ = "homeassistant/components/blah/light"

    # pylint: disable=invalid-name
    def __init__(
        self,
        setup_platform=None,
        dependencies=None,
        platform_schema=None,
        async_setup_platform=None,
        async_setup_entry=None,
        scan_interval=None,
    ):
        """Populate only the platform attributes that were actually supplied.

        When neither ``setup_platform`` nor ``async_setup_platform`` is given,
        ``async_setup_platform`` defaults to an AsyncMock returning None.
        """
        self.DEPENDENCIES = dependencies or []

        for attr_name, value in (
            ("PLATFORM_SCHEMA", platform_schema),
            ("SCAN_INTERVAL", scan_interval),
        ):
            if value is not None:
                setattr(self, attr_name, value)

        if setup_platform is not None:
            # We run this in executor, wrap it in function
            self.setup_platform = lambda *args: setup_platform(*args)

        if async_setup_entry is not None:
            self.async_setup_entry = async_setup_entry

        if async_setup_platform is not None:
            self.async_setup_platform = async_setup_platform
        elif setup_platform is None:
            self.async_setup_platform = AsyncMock(return_value=None)
class MockEntityPlatform(entity_platform.EntityPlatform):
    """EntityPlatform whose constructor arguments all have test defaults."""

    def __init__(
        self,
        hass,
        logger=None,
        domain="test_domain",
        platform_name="test_platform",
        platform=None,
        scan_interval=timedelta(seconds=15),
        entity_namespace=None,
    ):
        """Initialize the platform, defaulting the logger when none is given."""
        # Otherwise the constructor will blow up.
        platform_is_mock = isinstance(platform, Mock)
        if platform_is_mock and isinstance(platform.PARALLEL_UPDATES, Mock):
            platform.PARALLEL_UPDATES = 0

        effective_logger = (
            logger
            if logger is not None
            else logging.getLogger("homeassistant.helpers.entity_platform")
        )

        super().__init__(
            hass=hass,
            logger=effective_logger,
            domain=domain,
            platform_name=platform_name,
            platform=platform,
            scan_interval=scan_interval,
            entity_namespace=entity_namespace,
        )
class MockToggleEntity(entity.ToggleEntity):
    """Toggle entity that logs every property read and method call.

    Each access is appended to ``calls`` as a ``(name, kwargs)`` tuple.
    """

    def __init__(self, name, state, unique_id=None):
        """Store the name (or the default device name) and the initial state."""
        self.calls = []
        self._state = state
        self._name = name or DEVICE_DEFAULT_NAME

    @property
    def name(self):
        """Record the access and return the entity name."""
        self.calls.append(("name", {}))
        return self._name

    @property
    def state(self):
        """Record the access and return the stored state."""
        self.calls.append(("state", {}))
        return self._state

    @property
    def is_on(self):
        """Record the access and report whether the stored state is STATE_ON."""
        self.calls.append(("is_on", {}))
        return self._state == STATE_ON

    def turn_on(self, **kwargs):
        """Record the call and set the state to STATE_ON."""
        self.calls.append(("turn_on", kwargs))
        self._state = STATE_ON

    def turn_off(self, **kwargs):
        """Record the call and set the state to STATE_OFF."""
        self.calls.append(("turn_off", kwargs))
        self._state = STATE_OFF

    def last_call(self, method=None):
        """Return the most recent recorded call, optionally for one method.

        Returns None when nothing matching has been recorded.
        """
        if not self.calls:
            return None
        if method is None:
            return self.calls[-1]

        matching = (entry for entry in reversed(self.calls) if entry[0] == method)
        return next(matching, None)
class MockConfigEntry(config_entries.ConfigEntry):
    """Helper for creating config entries that adds some defaults."""

    def __init__(
        self,
        *,
        domain="test",
        data=None,
        version=1,
        entry_id=None,
        source=config_entries.SOURCE_USER,
        title="Mock Title",
        state=None,
        options=None,
        pref_disable_new_entities=None,
        pref_disable_polling=None,
        unique_id=None,
        disabled_by=None,
        reason=None,
    ):
        """Initialize a mock config entry.

        All arguments are keyword-only. ``entry_id`` defaults to a random hex
        uuid, ``data`` and ``options`` default to a fresh empty dict, and
        ``source`` / ``state`` are only forwarded to ConfigEntry when they are
        not None. ``reason`` is set on the instance after construction when
        given.
        """
        kwargs = {
            "entry_id": entry_id or uuid_util.random_uuid_hex(),
            "domain": domain,
            "data": data or {},
            "pref_disable_new_entities": pref_disable_new_entities,
            "pref_disable_polling": pref_disable_polling,
            # A new dict per entry: the former `options={}` default was one
            # dict object shared by every entry created without options.
            "options": {} if options is None else options,
            "version": version,
            "title": title,
            "unique_id": unique_id,
            "disabled_by": disabled_by,
        }
        if source is not None:
            kwargs["source"] = source
        if state is not None:
            kwargs["state"] = state
        super().__init__(**kwargs)
        if reason is not None:
            self.reason = reason

    def add_to_hass(self, hass):
        """Test helper to add entry to hass."""
        hass.config_entries._entries[self.entry_id] = self

    def add_to_manager(self, manager):
        """Test helper to add entry to entry manager."""
        manager._entries[self.entry_id] = self
def patch_yaml_files(files_dict, endswith=True):
    """Patch load_yaml with a dictionary of yaml files.

    Args:
        files_dict: Mapping of file name (or file name suffix) to the YAML
            text that should be served for it.
        endswith: When true, a requested name that is not an exact key may
            also match a key it ends with; longer keys are tried first so the
            most specific suffix wins.

    Returns:
        The ``patch.object`` patcher that replaces ``open`` in the yaml loader
        module; use it as a context manager or decorator.
    """
    # match using endswith, start search with longest string
    matchlist = sorted(files_dict, key=len, reverse=True) if endswith else []

    def mock_open_f(fname, **_):
        """Mock open() in the yaml module, used by load_yaml."""
        if isinstance(fname, pathlib.Path):
            fname = str(fname)

        # Return the mocked file on full match
        if fname in files_dict:
            _LOGGER.debug("patch_yaml_files match %s", fname)
            res = StringIO(files_dict[fname])
            setattr(res, "name", fname)
            return res

        # Match using endswith
        for ends in matchlist:
            if fname.endswith(ends):
                _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname)
                res = StringIO(files_dict[ends])
                setattr(res, "name", fname)
                return res

        # Fallback for hass.components (i.e. services.yaml)
        if "homeassistant/components" in fname:
            _LOGGER.debug("patch_yaml_files using real file: %s", fname)
            # The caller owns and closes the returned file object.
            return open(fname, encoding="utf-8")

        # Not found
        raise FileNotFoundError(f"File not found: {fname}")

    return patch.object(yaml_loader, "open", mock_open_f, create=True)
def mock_coro(return_value=None, exception=None):
    """Return an already-completed future.

    The future holds ``exception`` when one is given, otherwise
    ``return_value``.
    """
    future = asyncio.Future()
    if exception is None:
        future.set_result(return_value)
        return future

    future.set_exception(exception)
    return future
@contextmanager
def assert_setup_component(count, domain=None):
    """Capture validated config during setup and assert the platform count.

    - count: The amount of valid platforms that should be setup
    - domain: The domain to count is optional. It can be automatically
              determined most of the time

    Use as a context manager around setup.setup_component
        with assert_setup_component(0) as result_config:
            setup_component(hass, domain, start_config)
            # using result_config is optional
    """
    captured = {}

    async def mock_psc(hass, config_input, integration):
        """Run the real config processing and remember its per-domain result."""
        domain_input = integration.domain
        validated = await async_process_component_config(
            hass, config_input, integration
        )
        if validated is None:
            captured[domain_input] = None
        else:
            captured[domain_input] = validated.get(domain_input)
        _LOGGER.debug(
            "Configuration for %s, Validated: %s, Original %s",
            domain_input,
            captured[domain_input],
            config_input.get(domain_input),
        )
        return validated

    with patch("homeassistant.config.async_process_component_config", mock_psc):
        yield captured

    # Without an explicit domain, exactly one domain must have been processed.
    if domain is None:
        assert len(captured) == 1, "assert_setup_component requires DOMAIN: {}".format(
            list(captured.keys())
        )
        domain = list(captured)[0]

    res = captured.get(domain)
    res_len = len(res) if res is not None else 0
    assert (
        res_len == count
    ), f"setup_component failed, expected {count} got {res_len}: {res}"
def init_recorder_component(hass, add_config=None):
    """Set up the recorder against an in-memory SQLite database.

    ``add_config`` is copied and its DB URL entry is overridden; schema
    migration is patched out while the component is set up.
    """
    recorder_config = {} if not add_config else dict(add_config)
    recorder_config[recorder.CONF_DB_URL] = "sqlite://"  # In memory DB
    full_config = {recorder.DOMAIN: recorder_config}

    with patch("homeassistant.components.recorder.migration.migrate_schema"):
        assert setup_component(hass, recorder.DOMAIN, full_config)
        assert recorder.DOMAIN in hass.config.components

    _LOGGER.info("In-memory recorder successfully started")
async def async_init_recorder_component(hass, add_config=None):
    """Set up the recorder against an in-memory SQLite database (async).

    ``add_config`` is copied and its DB URL entry is overridden; schema
    migration is patched out while the component is set up.
    """
    recorder_config = {} if not add_config else dict(add_config)
    recorder_config[recorder.CONF_DB_URL] = "sqlite://"  # In memory DB
    full_config = {recorder.DOMAIN: recorder_config}

    with patch("homeassistant.components.recorder.migration.migrate_schema"):
        assert await async_setup_component(hass, recorder.DOMAIN, full_config)
        assert recorder.DOMAIN in hass.config.components

    _LOGGER.info("In-memory recorder successfully started")
def mock_restore_cache(hass, states):
    """Seed the restore-state data in ``hass.data`` with the given states.

    Each state is stored under its entity_id with the current UTC time as its
    timestamp. Asserts that no two states share an entity_id.
    """
    restore_data = restore_state.RestoreStateData(hass)
    stored_at = date_util.utcnow()

    stored_states = {}
    for state in states:
        state_dict = state.as_dict()
        # Round-trip the attributes through JSON using the project encoder.
        encoded_attributes = json.dumps(state_dict["attributes"], cls=JSONEncoder)
        state_dict["attributes"] = json.loads(encoded_attributes)
        stored_states[state.entity_id] = restore_state.StoredState(
            State.from_dict(state_dict), stored_at
        )

    restore_data.last_states = stored_states
    _LOGGER.debug("Restore cache: %s", restore_data.last_states)
    assert len(restore_data.last_states) == len(states), f"Duplicate entity_id? {states}"

    hass.data[restore_state.DATA_RESTORE_STATE_TASK] = restore_data
class MockEntity(entity.Entity):
    """Entity whose property values can be overridden via constructor kwargs.

    Any property not overridden falls back to the base Entity value.
    """

    def __init__(self, **values):
        """Keep the overrides; apply ``entity_id`` directly when given."""
        self._values = values

        if "entity_id" in values:
            self.entity_id = values["entity_id"]

    def _handle(self, attr):
        """Return the override for ``attr``, or the base class value."""
        if attr not in self._values:
            return getattr(super(), attr)
        return self._values[attr]

    @property
    def name(self):
        """Return the entity name."""
        return self._handle("name")

    @property
    def should_poll(self):
        """Return the state of the polling."""
        return self._handle("should_poll")

    @property
    def unique_id(self):
        """Return the entity's unique ID."""
        return self._handle("unique_id")

    @property
    def state(self):
        """Return the entity state."""
        return self._handle("state")

    @property
    def available(self):
        """Return whether the entity is available."""
        return self._handle("available")

    @property
    def device_info(self):
        """Return the info linking the entity to a device."""
        return self._handle("device_info")

    @property
    def device_class(self):
        """Return how the device should be classified."""
        return self._handle("device_class")

    @property
    def unit_of_measurement(self):
        """Return the unit the entity state is expressed in."""
        return self._handle("unit_of_measurement")

    @property
    def capability_attributes(self):
        """Return the capability attributes."""
        return self._handle("capability_attributes")

    @property
    def supported_features(self):
        """Return the supported features."""
        return self._handle("supported_features")

    @property
    def entity_registry_enabled_default(self):
        """Return if the entity should be enabled when first added to the entity registry."""
        return self._handle("entity_registry_enabled_default")
@contextmanager
def mock_storage(data=None):
"""Mock storage.
Data is a dict {'key': {'version': version, 'data': data}}
Written data will be converted to JSON to ensure JSON parsing works.
"""
if data is None:
data = {}
orig_load = storage.Store._async_load