4
4
5
5
import asyncio
6
6
import datetime
7
- import functools
8
7
import logging
9
8
import threading
10
- import voluptuous as vol
9
+ from typing import TYPE_CHECKING , Any
11
10
11
+ import homeassistant .util .dt as dt_util
12
+ import voluptuous as vol
12
13
from homeassistant .components .hassio .const import ATTR_DATA
13
- from homeassistant .components .group import DOMAIN as GROUP_DOMAIN
14
14
from homeassistant .config_entries import SOURCE_IMPORT , ConfigEntry
15
15
from homeassistant .const import (
16
16
ATTR_DOMAIN ,
34
34
)
35
35
from homeassistant .helpers import (
36
36
config_validation as cv ,
37
+ )
38
+ from homeassistant .helpers import (
37
39
event ,
38
- issue_registry as ir ,
39
40
script ,
40
41
)
41
- from homeassistant .helpers .entity import Entity
42
+ from homeassistant .helpers import (
43
+ issue_registry as ir ,
44
+ )
42
45
from homeassistant .helpers .entity_component import DATA_INSTANCES , EntityComponent
43
46
from homeassistant .helpers .service import async_extract_referenced_entity_ids
44
47
from homeassistant .helpers .template import Template , result_as_boolean
45
- from homeassistant .helpers .typing import ConfigType
46
- import homeassistant .util .dt as dt_util
48
+
49
+ if TYPE_CHECKING :
50
+ from homeassistant .helpers .entity import Entity
51
+ from homeassistant .helpers .typing import ConfigType
47
52
48
53
from .const import (
49
54
ACTIONS_SERVICE ,
50
55
ATTR_BACKOFF ,
51
56
ATTR_EXPECTED_STATE ,
52
57
ATTR_ON_ERROR ,
53
- ATTR_RETRY_ID ,
54
58
ATTR_RETRIES ,
59
+ ATTR_RETRY_ID ,
55
60
ATTR_STATE_DELAY ,
56
61
ATTR_STATE_GRACE ,
57
62
ATTR_VALIDATION ,
66
71
DEFAULT_BACKOFF = "[[ 2 ** attempt ]]"
67
72
DEFAULT_RETRIES = 7
68
73
DEFAULT_STATE_GRACE = 0.2
74
+ GROUP_DOMAIN = "group"
69
75
70
- _running_retries : dict [str , ( str , int ) ] = {}
76
+ _running_retries : dict [str , tuple [ str , int ] ] = {}
71
77
_running_retries_write_lock = threading .Lock ()
72
78
73
79
74
- def _template_parameter (value : any | None ) -> str :
80
+ def _template_parameter (value : Any ) -> str :
75
81
"""Render template parameter."""
76
82
output = cv .template (value ).async_render (parse_result = False )
77
83
if not isinstance (output , str ):
78
- raise vol .Invalid ("template rendered value should be a string" )
84
+ message = "template rendered value should be a string"
85
+ raise vol .Invalid (message )
79
86
return output
80
87
81
88
@@ -96,23 +103,23 @@ def _fix_template_tokens(value: str) -> str:
96
103
DEFAULT_BACKOFF_FIXED = _fix_template_tokens (DEFAULT_BACKOFF )
97
104
98
105
99
- def _backoff_parameter (value : any | None ) -> Template :
106
+ def _backoff_parameter (value : Any | None ) -> Template :
100
107
"""Convert backoff parameter to template."""
101
108
return cv .template (_fix_template_tokens (cv .string (value )))
102
109
103
110
104
- def _validation_parameter (value : any | None ) -> Template :
111
+ def _validation_parameter (value : Any | None ) -> Template :
105
112
"""Convert validation parameter to template."""
106
113
return cv .dynamic_template (_fix_template_tokens (cv .string (value )))
107
114
108
115
109
116
SERVICE_SCHEMA_BASE_FIELDS = {
110
- vol .Required (ATTR_RETRIES , default = DEFAULT_RETRIES ): cv .positive_int ,
111
- vol .Required (ATTR_BACKOFF , default = DEFAULT_BACKOFF ): _backoff_parameter ,
117
+ vol .Required (ATTR_RETRIES , default = DEFAULT_RETRIES ): cv .positive_int , # type: ignore[reportArgumentType]
118
+ vol .Required (ATTR_BACKOFF , default = DEFAULT_BACKOFF ): _backoff_parameter , # type: ignore[reportArgumentType]
112
119
vol .Optional (ATTR_EXPECTED_STATE ): vol .All (cv .ensure_list , [_template_parameter ]),
113
120
vol .Optional (ATTR_VALIDATION ): _validation_parameter ,
114
- vol .Required (ATTR_STATE_DELAY , default = 0 ): cv .positive_float ,
115
- vol .Required (ATTR_STATE_GRACE , default = DEFAULT_STATE_GRACE ): cv .positive_float ,
121
+ vol .Required (ATTR_STATE_DELAY , default = 0 ): cv .positive_float , # type: ignore[reportArgumentType]
122
+ vol .Required (ATTR_STATE_GRACE , default = DEFAULT_STATE_GRACE ): cv .positive_float , # type: ignore[reportArgumentType]
116
123
vol .Optional (ATTR_RETRY_ID ): vol .Any (cv .string , None ),
117
124
vol .Optional (ATTR_ON_ERROR ): cv .SCRIPT_SCHEMA ,
118
125
}
@@ -151,7 +158,7 @@ def __init__(
151
158
self ,
152
159
hass : HomeAssistant ,
153
160
config_entry : ConfigEntry | None ,
154
- data : dict [str , any ],
161
+ data : dict [str , Any ],
155
162
) -> None :
156
163
"""Initialize the object."""
157
164
self .config_entry = config_entry
@@ -161,8 +168,8 @@ def __init__(
161
168
162
169
@staticmethod
163
170
def _retry_service_data (
164
- hass : HomeAssistant , data : dict [str , any ]
165
- ) -> dict [str , any ]:
171
+ hass : HomeAssistant , data : dict [str , Any ]
172
+ ) -> dict [str , Any ]:
166
173
"""Compose retry parameters."""
167
174
retry_data = {
168
175
key : data [key ] for key in data if key in SERVICE_SCHEMA_BASE_FIELDS
@@ -176,8 +183,8 @@ def _retry_service_data(
176
183
return retry_data
177
184
178
185
def _inner_service_data (
179
- self , hass : HomeAssistant , data : dict [str , any ]
180
- ) -> dict [str , any ]:
186
+ self , hass : HomeAssistant , data : dict [str , Any ]
187
+ ) -> dict [str , Any ]:
181
188
"""Compose inner service parameters."""
182
189
inner_data = {
183
190
key : value
@@ -205,7 +212,9 @@ def _expand_group(self, hass: HomeAssistant, entity_id: str) -> list[str]:
205
212
and entity_obj .platform is not None
206
213
and entity_obj .platform .platform_name == GROUP_DOMAIN
207
214
):
208
- for member_id in entity_obj .extra_state_attributes .get (ATTR_ENTITY_ID , []):
215
+ for member_id in getattr (entity_obj , "extra_state_attributes" , {}).get (
216
+ ATTR_ENTITY_ID , []
217
+ ):
209
218
entity_ids .extend (self ._expand_group (hass , member_id ))
210
219
else :
211
220
entity_ids .append (entity_id )
@@ -256,7 +265,7 @@ def __init__(
256
265
if key in self ._inner_data :
257
266
del self ._inner_data [key ]
258
267
self ._inner_data = {
259
- ** { ATTR_ENTITY_ID : entity_id } ,
268
+ ATTR_ENTITY_ID : entity_id ,
260
269
** self ._inner_data ,
261
270
}
262
271
self ._entity_id = entity_id
@@ -267,31 +276,39 @@ def __init__(
267
276
if self ._entity_id :
268
277
self ._retry_id = self ._entity_id
269
278
else :
270
- self ._retry_id = f"{ params .retry_data [ATTR_DOMAIN ]} .{ params .retry_data [ATTR_SERVICE ]} "
279
+ self ._retry_id = (
280
+ f"{ params .retry_data [ATTR_DOMAIN ]} ."
281
+ + params .retry_data [ATTR_SERVICE ]
282
+ )
283
+ self ._service_call_str_value = None
271
284
self ._start_id ()
272
285
273
286
async def _async_validate (self ) -> None :
274
- """Verify that the entity is available, in the expected state, and pass the validation."""
287
+ """Check the entity is available has expected state and pass validation."""
275
288
if self ._entity_id :
276
289
if (
277
290
ent_obj := _get_entity (self ._hass , self ._entity_id )
278
291
) is None or not ent_obj .available :
279
- raise InvalidStateError (f"{ self ._entity_id } is not available" )
292
+ message = f"{ self ._entity_id } is not available"
293
+ raise InvalidStateError (message )
280
294
else :
281
295
ent_obj = None
282
296
if (state_delay := self ._params .retry_data [ATTR_STATE_DELAY ]) > 0 :
283
297
await asyncio .sleep (state_delay )
284
298
if not self ._check_state (ent_obj ) or not self ._check_validation ():
285
299
await asyncio .sleep (self ._params .retry_data [ATTR_STATE_GRACE ])
286
300
if not self ._check_state (ent_obj ):
287
- raise InvalidStateError (
288
- f'{ self ._entity_id } state is "{ ent_obj .state } " but '
289
- f'expecting one of "{ self ._params .retry_data [ATTR_EXPECTED_STATE ]} "'
301
+ message = (
302
+ f'{ self ._entity_id } state is "{ getattr (ent_obj , "state" , "None" )} " '
303
+ "but expecting one of "
304
+ f'"{ self ._params .retry_data [ATTR_EXPECTED_STATE ]} "'
290
305
)
306
+ raise InvalidStateError (message )
291
307
if not self ._check_validation ():
292
- raise InvalidStateError (
308
+ message = (
293
309
f'"{ self ._params .retry_data [ATTR_VALIDATION ].template } " is False'
294
310
)
311
+ raise InvalidStateError (message )
295
312
296
313
def _check_state (self , entity : Entity | None ) -> bool :
297
314
"""Check if the entity's state is expected."""
@@ -301,7 +318,7 @@ def _check_state(self, entity: Entity | None) -> bool:
301
318
if entity .state == expected :
302
319
return True
303
320
try :
304
- if float (entity .state ) == float (expected ):
321
+ if float (entity .state ) == float (expected ): # type: ignore[reportArgumentType]
305
322
return True
306
323
except ValueError :
307
324
pass
@@ -318,12 +335,19 @@ def _check_validation(self) -> bool:
318
335
)
319
336
320
337
@property
321
- @functools .cache
322
338
def _service_call_str (self ) -> str :
339
+ if self ._service_call_str_value is None :
340
+ self ._service_call_str_value = self ._compose_service_call_str ()
341
+ return self ._service_call_str_value
342
+
343
+ def _compose_service_call_str (self ) -> str :
323
344
"""Return a string with the service call parameters."""
324
345
service_call = (
325
- f"{ self ._params .retry_data [ATTR_DOMAIN ]} .{ self ._params .retry_data [ATTR_SERVICE ]} "
326
- f"({ ', ' .join ([f'{ key } ={ value } ' for key , value in self ._inner_data .items ()])} )"
346
+ f"{ self ._params .retry_data [ATTR_DOMAIN ]} ."
347
+ f"{ self ._params .retry_data [ATTR_SERVICE ]} "
348
+ f"({ ', ' .join (
349
+ [f'{ key } ={ value } ' for key , value in self ._inner_data .items ()]
350
+ )} )"
327
351
)
328
352
retry_params = []
329
353
if (
@@ -333,7 +357,9 @@ def _service_call_str(self) -> str:
333
357
retry_params .append (f"expected_state={ expected_state [0 ]} " )
334
358
else :
335
359
retry_params .append (
336
- f"expected_state in ({ ', ' .join (state for state in expected_state )} )"
360
+ f"expected_state in ({ ', ' .join (
361
+ state for state in expected_state
362
+ )} )"
337
363
)
338
364
for name , value , default in (
339
365
(
@@ -362,13 +388,14 @@ def _service_call_str(self) -> str:
362
388
):
363
389
if value != default :
364
390
if isinstance (value , str ):
365
- value = f'"{ value } "'
366
- retry_params .append (f"{ name } ={ value } " )
391
+ retry_params .append (f'{ name } ="{ value } "' )
392
+ else :
393
+ retry_params .append (f"{ name } ={ value } " )
367
394
if len (retry_params ) > 0 :
368
395
service_call += f"[{ ', ' .join (retry_params )} ]"
369
396
return service_call
370
397
371
- def _log (self , level : int , prefix : str , stack_info : bool = False ) -> None :
398
+ def _log (self , level : int , prefix : str , stack_info : bool = False ) -> None : # noqa: FBT001, FBT002
372
399
"""Log entry."""
373
400
LOGGER .log (
374
401
level ,
@@ -420,7 +447,8 @@ def _end_id(self) -> None:
420
447
421
448
def _set_id (self , count : int ) -> None :
422
449
"""Set the retry_id entry with a counter."""
423
- _running_retries [self ._retry_id ] = (self ._context .id , count )
450
+ if self ._retry_id :
451
+ _running_retries [self ._retry_id ] = (self ._context .id , count )
424
452
425
453
def _check_id (self ) -> bool :
426
454
"""Check if self is the retry ID running job."""
@@ -430,7 +458,7 @@ def _check_id(self) -> bool:
430
458
)
431
459
432
460
@callback
433
- async def async_retry (self , * _ ) -> None :
461
+ async def async_retry (self , _ : datetime . datetime | None = None ) -> None :
434
462
"""One service call attempt."""
435
463
if not self ._check_id ():
436
464
self ._log (logging .INFO , "Cancelled" )
@@ -440,62 +468,63 @@ async def async_retry(self, *_) -> None:
440
468
self ._params .retry_data [ATTR_DOMAIN ],
441
469
self ._params .retry_data [ATTR_SERVICE ],
442
470
self ._inner_data .copy (),
443
- True ,
444
- Context (self ._context .user_id , self ._context .id ),
471
+ blocking = True ,
472
+ context = Context (self ._context .user_id , self ._context .id ),
445
473
)
446
474
await self ._async_validate ()
447
475
self ._log (
448
476
logging .DEBUG if self ._attempt == 1 else logging .INFO , "Succeeded"
449
477
)
450
478
self ._end_id ()
451
- return
452
- except Exception : # pylint: disable=broad-except
479
+ except Exception : # noqa: BLE001
453
480
self ._log (
454
481
logging .WARNING
455
482
if self ._attempt < self ._params .retry_data [ATTR_RETRIES ]
456
483
else logging .ERROR ,
457
484
"Failed" ,
458
- True ,
485
+ stack_info = True ,
459
486
)
460
- if self ._attempt == self ._params .retry_data [ATTR_RETRIES ]:
461
- if not self ._params .config_entry .options .get (CONF_DISABLE_REPAIR ):
462
- self ._repair ()
463
- self ._end_id ()
464
- if (on_error := self ._params .retry_data .get (ATTR_ON_ERROR )) is not None :
465
- await script .Script (
466
- self ._hass , on_error , CALL_SERVICE , DOMAIN
467
- ).async_run (
468
- run_variables = {ATTR_ENTITY_ID : self ._entity_id }
469
- if self ._entity_id
470
- else None ,
471
- context = Context (self ._context .user_id , self ._context .id ),
472
- )
473
- return
474
- next_retry = dt_util .now () + datetime .timedelta (
475
- seconds = float (
476
- self ._params .retry_data [ATTR_BACKOFF ].async_render (
477
- variables = {"attempt" : self ._attempt - 1 }
487
+ if self ._attempt == self ._params .retry_data [ATTR_RETRIES ]:
488
+ if not getattr (self ._params .config_entry , "options" , {}).get (
489
+ CONF_DISABLE_REPAIR
490
+ ):
491
+ self ._repair ()
492
+ self ._end_id ()
493
+ if (on_error := self ._params .retry_data .get (ATTR_ON_ERROR )) is not None :
494
+ await script .Script (
495
+ self ._hass , on_error , CALL_SERVICE , DOMAIN
496
+ ).async_run (
497
+ run_variables = {ATTR_ENTITY_ID : self ._entity_id }
498
+ if self ._entity_id
499
+ else None ,
500
+ context = Context (self ._context .user_id , self ._context .id ),
501
+ )
502
+ return
503
+ next_retry = dt_util .now () + datetime .timedelta (
504
+ seconds = float (
505
+ self ._params .retry_data [ATTR_BACKOFF ].async_render (
506
+ variables = {"attempt" : self ._attempt - 1 }
507
+ )
478
508
)
479
509
)
480
- )
481
- self ._attempt += 1
482
- event .async_track_point_in_time (self ._hass , self .async_retry , next_retry )
510
+ self ._attempt += 1
511
+ event .async_track_point_in_time (self ._hass , self .async_retry , next_retry )
483
512
484
513
485
- def _wrap_service_calls (
486
- hass : HomeAssistant , sequence : list [dict ], retry_params : dict [str , any ]
514
+ def _wrap_service_calls ( # noqa: PLR0912
515
+ hass : HomeAssistant , sequence : list [dict ], retry_params : dict [str , Any ]
487
516
) -> None :
488
517
"""Warp any service call with retry."""
489
518
for action in sequence :
490
519
action_type = cv .determine_script_action (action )
491
520
match action_type :
492
521
case cv .SCRIPT_ACTION_CALL_SERVICE :
493
522
if action [ATTR_SERVICE ] == f"{ DOMAIN } .{ ACTIONS_SERVICE } " :
494
- raise IntegrationError ("Nested retry.actions are disallowed" )
523
+ message = "Nested retry.actions are disallowed"
524
+ raise IntegrationError (message )
495
525
if action [ATTR_SERVICE ] == f"{ DOMAIN } .{ CALL_SERVICE } " :
496
- raise IntegrationError (
497
- "retry.call inside retry.actions is disallowed"
498
- )
526
+ message = "retry.call inside retry.actions is disallowed"
527
+ raise IntegrationError (message )
499
528
action [ATTR_DATA ] = action .get (ATTR_DATA , {})
500
529
action [ATTR_DATA ][ATTR_SERVICE ] = action [ATTR_SERVICE ]
501
530
action [ATTR_DATA ].update (retry_params )
0 commit comments