Skip to content

Commit 083cf7a

Browse files
authored
Fix refactoring error with updating polling entities in sequence (home-assistant#93693)
* Fix refactoring error with updating in sequence see home-assistant#93649 * coverage * make sure entities are being updated in parallel * make sure entities are being updated in sequence
1 parent 49c3a88 commit 083cf7a

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

homeassistant/helpers/entity_platform.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(
136136
self._process_updates: asyncio.Lock | None = None
137137

138138
self.parallel_updates: asyncio.Semaphore | None = None
139-
self._update_in_parallel: bool = True
139+
self._update_in_sequence: bool = False
140140

141141
# Platform is None for the EntityComponent "catch-all" EntityPlatform
142142
# which powers entity_component.add_entities
@@ -187,7 +187,7 @@ def _get_parallel_updates_semaphore(
187187

188188
if parallel_updates is not None:
189189
self.parallel_updates = asyncio.Semaphore(parallel_updates)
190-
self._update_in_parallel = parallel_updates != 1
190+
self._update_in_sequence = parallel_updates == 1
191191

192192
return self.parallel_updates
193193

@@ -846,11 +846,13 @@ async def _update_entity_states(self, now: datetime) -> None:
846846
return
847847

848848
async with self._process_updates:
849-
if self._update_in_parallel or len(self.entities) <= 1:
850-
# If we know are going to update sequentially, we want to update
851-
# to avoid scheduling the coroutines as tasks that will we know
852-
# are going to wait on the semaphore lock.
849+
if self._update_in_sequence or len(self.entities) <= 1:
850+
# If we know we will update sequentially, we want to avoid scheduling
851+
# the coroutines as tasks that will wait on the semaphore lock.
853852
for entity in list(self.entities.values()):
853+
# If the entity is removed from hass during the previous
854+
# entity being updated, we need to skip updating the
855+
# entity.
854856
if entity.should_poll and entity.hass:
855857
await entity.async_update_ha_state(True)
856858
return

tests/helpers/test_entity_platform.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import asyncio
33
from datetime import timedelta
44
import logging
5+
from typing import Any
56
from unittest.mock import ANY, Mock, patch
67

78
import pytest
@@ -307,6 +308,7 @@ async def async_update(self):
307308
entity = AsyncEntity()
308309
await handle.async_add_entities([entity])
309310
assert entity.parallel_updates is None
311+
assert handle._update_in_sequence is False
310312

311313

312314
async def test_parallel_updates_async_platform_with_constant(
@@ -336,6 +338,7 @@ async def async_update(self):
336338
await handle.async_add_entities([entity])
337339
assert entity.parallel_updates is not None
338340
assert entity.parallel_updates._value == 2
341+
assert handle._update_in_sequence is False
339342

340343

341344
async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None:
@@ -412,6 +415,104 @@ async def update(self):
412415
assert entity.parallel_updates._value == 2
413416

414417

418+
async def test_parallel_updates_async_platform_updates_in_parallel(
419+
hass: HomeAssistant,
420+
) -> None:
421+
"""Test an async platform is updated in parallel."""
422+
platform = MockPlatform()
423+
424+
mock_entity_platform(hass, "test_domain.async_platform", platform)
425+
426+
component = EntityComponent(_LOGGER, DOMAIN, hass)
427+
component._platforms = {}
428+
429+
await component.async_setup({DOMAIN: {"platform": "async_platform"}})
430+
await hass.async_block_till_done()
431+
432+
handle = list(component._platforms.values())[-1]
433+
updating = []
434+
peak_update_count = 0
435+
436+
class AsyncEntity(MockEntity):
437+
"""Mock entity that has async_update."""
438+
439+
async def async_update(self):
440+
pass
441+
442+
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
443+
nonlocal peak_update_count
444+
updating.append(self.entity_id)
445+
await asyncio.sleep(0)
446+
peak_update_count = max(len(updating), peak_update_count)
447+
await asyncio.sleep(0)
448+
updating.remove(self.entity_id)
449+
450+
entity1 = AsyncEntity()
451+
entity2 = AsyncEntity()
452+
entity3 = AsyncEntity()
453+
454+
await handle.async_add_entities([entity1, entity2, entity3])
455+
456+
assert entity1.parallel_updates is None
457+
assert entity2.parallel_updates is None
458+
assert entity3.parallel_updates is None
459+
460+
assert handle._update_in_sequence is False
461+
462+
await handle._update_entity_states(dt_util.utcnow())
463+
assert peak_update_count > 1
464+
465+
466+
async def test_parallel_updates_sync_platform_updates_in_sequence(
467+
hass: HomeAssistant,
468+
) -> None:
469+
"""Test a sync platform is updated in sequence."""
470+
platform = MockPlatform()
471+
472+
mock_entity_platform(hass, "test_domain.platform", platform)
473+
474+
component = EntityComponent(_LOGGER, DOMAIN, hass)
475+
component._platforms = {}
476+
477+
await component.async_setup({DOMAIN: {"platform": "platform"}})
478+
await hass.async_block_till_done()
479+
480+
handle = list(component._platforms.values())[-1]
481+
updating = []
482+
peak_update_count = 0
483+
484+
class SyncEntity(MockEntity):
485+
"""Mock entity that has update."""
486+
487+
def update(self):
488+
pass
489+
490+
async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None:
491+
nonlocal peak_update_count
492+
updating.append(self.entity_id)
493+
await asyncio.sleep(0)
494+
peak_update_count = max(len(updating), peak_update_count)
495+
await asyncio.sleep(0)
496+
updating.remove(self.entity_id)
497+
498+
entity1 = SyncEntity()
499+
entity2 = SyncEntity()
500+
entity3 = SyncEntity()
501+
502+
await handle.async_add_entities([entity1, entity2, entity3])
503+
assert entity1.parallel_updates is not None
504+
assert entity1.parallel_updates._value == 1
505+
assert entity2.parallel_updates is not None
506+
assert entity2.parallel_updates._value == 1
507+
assert entity3.parallel_updates is not None
508+
assert entity3.parallel_updates._value == 1
509+
510+
assert handle._update_in_sequence is True
511+
512+
await handle._update_entity_states(dt_util.utcnow())
513+
assert peak_update_count == 1
514+
515+
415516
async def test_raise_error_on_update(hass: HomeAssistant) -> None:
416517
"""Test the add entity if they raise an error on update."""
417518
updates = []

0 commit comments

Comments
 (0)