|
2 | 2 | import asyncio
|
3 | 3 | from datetime import timedelta
|
4 | 4 | import logging
|
| 5 | +from typing import Any |
5 | 6 | from unittest.mock import ANY, Mock, patch
|
6 | 7 |
|
7 | 8 | import pytest
|
@@ -307,6 +308,7 @@ async def async_update(self):
|
307 | 308 | entity = AsyncEntity()
|
308 | 309 | await handle.async_add_entities([entity])
|
309 | 310 | assert entity.parallel_updates is None
|
| 311 | + assert handle._update_in_sequence is False |
310 | 312 |
|
311 | 313 |
|
312 | 314 | async def test_parallel_updates_async_platform_with_constant(
|
@@ -336,6 +338,7 @@ async def async_update(self):
|
336 | 338 | await handle.async_add_entities([entity])
|
337 | 339 | assert entity.parallel_updates is not None
|
338 | 340 | assert entity.parallel_updates._value == 2
|
| 341 | + assert handle._update_in_sequence is False |
339 | 342 |
|
340 | 343 |
|
341 | 344 | async def test_parallel_updates_sync_platform(hass: HomeAssistant) -> None:
|
@@ -412,6 +415,104 @@ async def update(self):
|
412 | 415 | assert entity.parallel_updates._value == 2
|
413 | 416 |
|
414 | 417 |
|
| 418 | +async def test_parallel_updates_async_platform_updates_in_parallel( |
| 419 | + hass: HomeAssistant, |
| 420 | +) -> None: |
| 421 | + """Test an async platform is updated in parallel.""" |
| 422 | + platform = MockPlatform() |
| 423 | + |
| 424 | + mock_entity_platform(hass, "test_domain.async_platform", platform) |
| 425 | + |
| 426 | + component = EntityComponent(_LOGGER, DOMAIN, hass) |
| 427 | + component._platforms = {} |
| 428 | + |
| 429 | + await component.async_setup({DOMAIN: {"platform": "async_platform"}}) |
| 430 | + await hass.async_block_till_done() |
| 431 | + |
| 432 | + handle = list(component._platforms.values())[-1] |
| 433 | + updating = [] |
| 434 | + peak_update_count = 0 |
| 435 | + |
| 436 | + class AsyncEntity(MockEntity): |
| 437 | + """Mock entity that has async_update.""" |
| 438 | + |
| 439 | + async def async_update(self): |
| 440 | + pass |
| 441 | + |
| 442 | + async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None: |
| 443 | + nonlocal peak_update_count |
| 444 | + updating.append(self.entity_id) |
| 445 | + await asyncio.sleep(0) |
| 446 | + peak_update_count = max(len(updating), peak_update_count) |
| 447 | + await asyncio.sleep(0) |
| 448 | + updating.remove(self.entity_id) |
| 449 | + |
| 450 | + entity1 = AsyncEntity() |
| 451 | + entity2 = AsyncEntity() |
| 452 | + entity3 = AsyncEntity() |
| 453 | + |
| 454 | + await handle.async_add_entities([entity1, entity2, entity3]) |
| 455 | + |
| 456 | + assert entity1.parallel_updates is None |
| 457 | + assert entity2.parallel_updates is None |
| 458 | + assert entity3.parallel_updates is None |
| 459 | + |
| 460 | + assert handle._update_in_sequence is False |
| 461 | + |
| 462 | + await handle._update_entity_states(dt_util.utcnow()) |
| 463 | + assert peak_update_count > 1 |
| 464 | + |
| 465 | + |
| 466 | +async def test_parallel_updates_sync_platform_updates_in_sequence( |
| 467 | + hass: HomeAssistant, |
| 468 | +) -> None: |
| 469 | + """Test a sync platform is updated in sequence.""" |
| 470 | + platform = MockPlatform() |
| 471 | + |
| 472 | + mock_entity_platform(hass, "test_domain.platform", platform) |
| 473 | + |
| 474 | + component = EntityComponent(_LOGGER, DOMAIN, hass) |
| 475 | + component._platforms = {} |
| 476 | + |
| 477 | + await component.async_setup({DOMAIN: {"platform": "platform"}}) |
| 478 | + await hass.async_block_till_done() |
| 479 | + |
| 480 | + handle = list(component._platforms.values())[-1] |
| 481 | + updating = [] |
| 482 | + peak_update_count = 0 |
| 483 | + |
| 484 | + class SyncEntity(MockEntity): |
| 485 | + """Mock entity that has update.""" |
| 486 | + |
| 487 | + def update(self): |
| 488 | + pass |
| 489 | + |
| 490 | + async def async_update_ha_state(self, *args: Any, **kwargs: Any) -> None: |
| 491 | + nonlocal peak_update_count |
| 492 | + updating.append(self.entity_id) |
| 493 | + await asyncio.sleep(0) |
| 494 | + peak_update_count = max(len(updating), peak_update_count) |
| 495 | + await asyncio.sleep(0) |
| 496 | + updating.remove(self.entity_id) |
| 497 | + |
| 498 | + entity1 = SyncEntity() |
| 499 | + entity2 = SyncEntity() |
| 500 | + entity3 = SyncEntity() |
| 501 | + |
| 502 | + await handle.async_add_entities([entity1, entity2, entity3]) |
| 503 | + assert entity1.parallel_updates is not None |
| 504 | + assert entity1.parallel_updates._value == 1 |
| 505 | + assert entity2.parallel_updates is not None |
| 506 | + assert entity2.parallel_updates._value == 1 |
| 507 | + assert entity3.parallel_updates is not None |
| 508 | + assert entity3.parallel_updates._value == 1 |
| 509 | + |
| 510 | + assert handle._update_in_sequence is True |
| 511 | + |
| 512 | + await handle._update_entity_states(dt_util.utcnow()) |
| 513 | + assert peak_update_count == 1 |
| 514 | + |
| 515 | + |
415 | 516 | async def test_raise_error_on_update(hass: HomeAssistant) -> None:
|
416 | 517 | """Test the add entity if they raise an error on update."""
|
417 | 518 | updates = []
|
|
0 commit comments