Skip to content

Commit 54dc31d

Browse files
committed
Ruff and Lint fixes
1 parent 788a0b1 commit 54dc31d

File tree

6 files changed

+282
-247
lines changed

6 files changed

+282
-247
lines changed

custom_components/retry/__init__.py

Lines changed: 102 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import asyncio
66
import datetime
7-
import functools
87
import logging
98
import threading
10-
import voluptuous as vol
9+
from typing import TYPE_CHECKING, Any
1110

11+
import homeassistant.util.dt as dt_util
12+
import voluptuous as vol
1213
from homeassistant.components.hassio.const import ATTR_DATA
13-
from homeassistant.components.group import DOMAIN as GROUP_DOMAIN
1414
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
1515
from homeassistant.const import (
1616
ATTR_DOMAIN,
@@ -34,24 +34,29 @@
3434
)
3535
from homeassistant.helpers import (
3636
config_validation as cv,
37+
)
38+
from homeassistant.helpers import (
3739
event,
38-
issue_registry as ir,
3940
script,
4041
)
41-
from homeassistant.helpers.entity import Entity
42+
from homeassistant.helpers import (
43+
issue_registry as ir,
44+
)
4245
from homeassistant.helpers.entity_component import DATA_INSTANCES, EntityComponent
4346
from homeassistant.helpers.service import async_extract_referenced_entity_ids
4447
from homeassistant.helpers.template import Template, result_as_boolean
45-
from homeassistant.helpers.typing import ConfigType
46-
import homeassistant.util.dt as dt_util
48+
49+
if TYPE_CHECKING:
50+
from homeassistant.helpers.entity import Entity
51+
from homeassistant.helpers.typing import ConfigType
4752

4853
from .const import (
4954
ACTIONS_SERVICE,
5055
ATTR_BACKOFF,
5156
ATTR_EXPECTED_STATE,
5257
ATTR_ON_ERROR,
53-
ATTR_RETRY_ID,
5458
ATTR_RETRIES,
59+
ATTR_RETRY_ID,
5560
ATTR_STATE_DELAY,
5661
ATTR_STATE_GRACE,
5762
ATTR_VALIDATION,
@@ -66,16 +71,18 @@
6671
DEFAULT_BACKOFF = "[[ 2 ** attempt ]]"
6772
DEFAULT_RETRIES = 7
6873
DEFAULT_STATE_GRACE = 0.2
74+
GROUP_DOMAIN = "group"
6975

70-
_running_retries: dict[str, (str, int)] = {}
76+
_running_retries: dict[str, tuple[str, int]] = {}
7177
_running_retries_write_lock = threading.Lock()
7278

7379

74-
def _template_parameter(value: any | None) -> str:
80+
def _template_parameter(value: Any) -> str:
7581
"""Render template parameter."""
7682
output = cv.template(value).async_render(parse_result=False)
7783
if not isinstance(output, str):
78-
raise vol.Invalid("template rendered value should be a string")
84+
message = "template rendered value should be a string"
85+
raise vol.Invalid(message)
7986
return output
8087

8188

@@ -96,23 +103,23 @@ def _fix_template_tokens(value: str) -> str:
96103
DEFAULT_BACKOFF_FIXED = _fix_template_tokens(DEFAULT_BACKOFF)
97104

98105

99-
def _backoff_parameter(value: any | None) -> Template:
106+
def _backoff_parameter(value: Any | None) -> Template:
100107
"""Convert backoff parameter to template."""
101108
return cv.template(_fix_template_tokens(cv.string(value)))
102109

103110

104-
def _validation_parameter(value: any | None) -> Template:
111+
def _validation_parameter(value: Any | None) -> Template:
105112
"""Convert validation parameter to template."""
106113
return cv.dynamic_template(_fix_template_tokens(cv.string(value)))
107114

108115

109116
SERVICE_SCHEMA_BASE_FIELDS = {
110-
vol.Required(ATTR_RETRIES, default=DEFAULT_RETRIES): cv.positive_int,
111-
vol.Required(ATTR_BACKOFF, default=DEFAULT_BACKOFF): _backoff_parameter,
117+
vol.Required(ATTR_RETRIES, default=DEFAULT_RETRIES): cv.positive_int, # type: ignore[reportArgumentType]
118+
vol.Required(ATTR_BACKOFF, default=DEFAULT_BACKOFF): _backoff_parameter, # type: ignore[reportArgumentType]
112119
vol.Optional(ATTR_EXPECTED_STATE): vol.All(cv.ensure_list, [_template_parameter]),
113120
vol.Optional(ATTR_VALIDATION): _validation_parameter,
114-
vol.Required(ATTR_STATE_DELAY, default=0): cv.positive_float,
115-
vol.Required(ATTR_STATE_GRACE, default=DEFAULT_STATE_GRACE): cv.positive_float,
121+
vol.Required(ATTR_STATE_DELAY, default=0): cv.positive_float, # type: ignore[reportArgumentType]
122+
vol.Required(ATTR_STATE_GRACE, default=DEFAULT_STATE_GRACE): cv.positive_float, # type: ignore[reportArgumentType]
116123
vol.Optional(ATTR_RETRY_ID): vol.Any(cv.string, None),
117124
vol.Optional(ATTR_ON_ERROR): cv.SCRIPT_SCHEMA,
118125
}
@@ -151,7 +158,7 @@ def __init__(
151158
self,
152159
hass: HomeAssistant,
153160
config_entry: ConfigEntry | None,
154-
data: dict[str, any],
161+
data: dict[str, Any],
155162
) -> None:
156163
"""Initialize the object."""
157164
self.config_entry = config_entry
@@ -161,8 +168,8 @@ def __init__(
161168

162169
@staticmethod
163170
def _retry_service_data(
164-
hass: HomeAssistant, data: dict[str, any]
165-
) -> dict[str, any]:
171+
hass: HomeAssistant, data: dict[str, Any]
172+
) -> dict[str, Any]:
166173
"""Compose retry parameters."""
167174
retry_data = {
168175
key: data[key] for key in data if key in SERVICE_SCHEMA_BASE_FIELDS
@@ -176,8 +183,8 @@ def _retry_service_data(
176183
return retry_data
177184

178185
def _inner_service_data(
179-
self, hass: HomeAssistant, data: dict[str, any]
180-
) -> dict[str, any]:
186+
self, hass: HomeAssistant, data: dict[str, Any]
187+
) -> dict[str, Any]:
181188
"""Compose inner service parameters."""
182189
inner_data = {
183190
key: value
@@ -205,7 +212,9 @@ def _expand_group(self, hass: HomeAssistant, entity_id: str) -> list[str]:
205212
and entity_obj.platform is not None
206213
and entity_obj.platform.platform_name == GROUP_DOMAIN
207214
):
208-
for member_id in entity_obj.extra_state_attributes.get(ATTR_ENTITY_ID, []):
215+
for member_id in getattr(entity_obj, "extra_state_attributes", {}).get(
216+
ATTR_ENTITY_ID, []
217+
):
209218
entity_ids.extend(self._expand_group(hass, member_id))
210219
else:
211220
entity_ids.append(entity_id)
@@ -256,7 +265,7 @@ def __init__(
256265
if key in self._inner_data:
257266
del self._inner_data[key]
258267
self._inner_data = {
259-
**{ATTR_ENTITY_ID: entity_id},
268+
ATTR_ENTITY_ID: entity_id,
260269
**self._inner_data,
261270
}
262271
self._entity_id = entity_id
@@ -267,31 +276,39 @@ def __init__(
267276
if self._entity_id:
268277
self._retry_id = self._entity_id
269278
else:
270-
self._retry_id = f"{params.retry_data[ATTR_DOMAIN]}.{params.retry_data[ATTR_SERVICE]}"
279+
self._retry_id = (
280+
f"{params.retry_data[ATTR_DOMAIN]}."
281+
+ params.retry_data[ATTR_SERVICE]
282+
)
283+
self._service_call_str_value = None
271284
self._start_id()
272285

273286
async def _async_validate(self) -> None:
274-
"""Verify that the entity is available, in the expected state, and pass the validation."""
287+
"""Check the entity is available has expected state and pass validation."""
275288
if self._entity_id:
276289
if (
277290
ent_obj := _get_entity(self._hass, self._entity_id)
278291
) is None or not ent_obj.available:
279-
raise InvalidStateError(f"{self._entity_id} is not available")
292+
message = f"{self._entity_id} is not available"
293+
raise InvalidStateError(message)
280294
else:
281295
ent_obj = None
282296
if (state_delay := self._params.retry_data[ATTR_STATE_DELAY]) > 0:
283297
await asyncio.sleep(state_delay)
284298
if not self._check_state(ent_obj) or not self._check_validation():
285299
await asyncio.sleep(self._params.retry_data[ATTR_STATE_GRACE])
286300
if not self._check_state(ent_obj):
287-
raise InvalidStateError(
288-
f'{self._entity_id} state is "{ent_obj.state}" but '
289-
f'expecting one of "{self._params.retry_data[ATTR_EXPECTED_STATE]}"'
301+
message = (
302+
f'{self._entity_id} state is "{getattr(ent_obj, "state", "None")}" '
303+
"but expecting one of "
304+
f'"{self._params.retry_data[ATTR_EXPECTED_STATE]}"'
290305
)
306+
raise InvalidStateError(message)
291307
if not self._check_validation():
292-
raise InvalidStateError(
308+
message = (
293309
f'"{self._params.retry_data[ATTR_VALIDATION].template}" is False'
294310
)
311+
raise InvalidStateError(message)
295312

296313
def _check_state(self, entity: Entity | None) -> bool:
297314
"""Check if the entity's state is expected."""
@@ -301,7 +318,7 @@ def _check_state(self, entity: Entity | None) -> bool:
301318
if entity.state == expected:
302319
return True
303320
try:
304-
if float(entity.state) == float(expected):
321+
if float(entity.state) == float(expected): # type: ignore[reportArgumentType]
305322
return True
306323
except ValueError:
307324
pass
@@ -318,12 +335,19 @@ def _check_validation(self) -> bool:
318335
)
319336

320337
@property
321-
@functools.cache
322338
def _service_call_str(self) -> str:
339+
if self._service_call_str_value is None:
340+
self._service_call_str_value = self._compose_service_call_str()
341+
return self._service_call_str_value
342+
343+
def _compose_service_call_str(self) -> str:
323344
"""Return a string with the service call parameters."""
324345
service_call = (
325-
f"{self._params.retry_data[ATTR_DOMAIN]}.{self._params.retry_data[ATTR_SERVICE]}"
326-
f"({', '.join([f'{key}={value}' for key, value in self._inner_data.items()])})"
346+
f"{self._params.retry_data[ATTR_DOMAIN]}."
347+
f"{self._params.retry_data[ATTR_SERVICE]}"
348+
f"({', '.join(
349+
[f'{key}={value}' for key, value in self._inner_data.items()]
350+
)})"
327351
)
328352
retry_params = []
329353
if (
@@ -333,7 +357,9 @@ def _service_call_str(self) -> str:
333357
retry_params.append(f"expected_state={expected_state[0]}")
334358
else:
335359
retry_params.append(
336-
f"expected_state in ({', '.join(state for state in expected_state)})"
360+
f"expected_state in ({', '.join(
361+
state for state in expected_state
362+
)})"
337363
)
338364
for name, value, default in (
339365
(
@@ -362,13 +388,14 @@ def _service_call_str(self) -> str:
362388
):
363389
if value != default:
364390
if isinstance(value, str):
365-
value = f'"{value}"'
366-
retry_params.append(f"{name}={value}")
391+
retry_params.append(f'{name}="{value}"')
392+
else:
393+
retry_params.append(f"{name}={value}")
367394
if len(retry_params) > 0:
368395
service_call += f"[{', '.join(retry_params)}]"
369396
return service_call
370397

371-
def _log(self, level: int, prefix: str, stack_info: bool = False) -> None:
398+
def _log(self, level: int, prefix: str, stack_info: bool = False) -> None: # noqa: FBT001, FBT002
372399
"""Log entry."""
373400
LOGGER.log(
374401
level,
@@ -420,7 +447,8 @@ def _end_id(self) -> None:
420447

421448
def _set_id(self, count: int) -> None:
422449
"""Set the retry_id entry with a counter."""
423-
_running_retries[self._retry_id] = (self._context.id, count)
450+
if self._retry_id:
451+
_running_retries[self._retry_id] = (self._context.id, count)
424452

425453
def _check_id(self) -> bool:
426454
"""Check if self is the retry ID running job."""
@@ -430,7 +458,7 @@ def _check_id(self) -> bool:
430458
)
431459

432460
@callback
433-
async def async_retry(self, *_) -> None:
461+
async def async_retry(self, _: datetime.datetime | None = None) -> None:
434462
"""One service call attempt."""
435463
if not self._check_id():
436464
self._log(logging.INFO, "Cancelled")
@@ -440,62 +468,63 @@ async def async_retry(self, *_) -> None:
440468
self._params.retry_data[ATTR_DOMAIN],
441469
self._params.retry_data[ATTR_SERVICE],
442470
self._inner_data.copy(),
443-
True,
444-
Context(self._context.user_id, self._context.id),
471+
blocking=True,
472+
context=Context(self._context.user_id, self._context.id),
445473
)
446474
await self._async_validate()
447475
self._log(
448476
logging.DEBUG if self._attempt == 1 else logging.INFO, "Succeeded"
449477
)
450478
self._end_id()
451-
return
452-
except Exception: # pylint: disable=broad-except
479+
except Exception: # noqa: BLE001
453480
self._log(
454481
logging.WARNING
455482
if self._attempt < self._params.retry_data[ATTR_RETRIES]
456483
else logging.ERROR,
457484
"Failed",
458-
True,
485+
stack_info=True,
459486
)
460-
if self._attempt == self._params.retry_data[ATTR_RETRIES]:
461-
if not self._params.config_entry.options.get(CONF_DISABLE_REPAIR):
462-
self._repair()
463-
self._end_id()
464-
if (on_error := self._params.retry_data.get(ATTR_ON_ERROR)) is not None:
465-
await script.Script(
466-
self._hass, on_error, CALL_SERVICE, DOMAIN
467-
).async_run(
468-
run_variables={ATTR_ENTITY_ID: self._entity_id}
469-
if self._entity_id
470-
else None,
471-
context=Context(self._context.user_id, self._context.id),
472-
)
473-
return
474-
next_retry = dt_util.now() + datetime.timedelta(
475-
seconds=float(
476-
self._params.retry_data[ATTR_BACKOFF].async_render(
477-
variables={"attempt": self._attempt - 1}
487+
if self._attempt == self._params.retry_data[ATTR_RETRIES]:
488+
if not getattr(self._params.config_entry, "options", {}).get(
489+
CONF_DISABLE_REPAIR
490+
):
491+
self._repair()
492+
self._end_id()
493+
if (on_error := self._params.retry_data.get(ATTR_ON_ERROR)) is not None:
494+
await script.Script(
495+
self._hass, on_error, CALL_SERVICE, DOMAIN
496+
).async_run(
497+
run_variables={ATTR_ENTITY_ID: self._entity_id}
498+
if self._entity_id
499+
else None,
500+
context=Context(self._context.user_id, self._context.id),
501+
)
502+
return
503+
next_retry = dt_util.now() + datetime.timedelta(
504+
seconds=float(
505+
self._params.retry_data[ATTR_BACKOFF].async_render(
506+
variables={"attempt": self._attempt - 1}
507+
)
478508
)
479509
)
480-
)
481-
self._attempt += 1
482-
event.async_track_point_in_time(self._hass, self.async_retry, next_retry)
510+
self._attempt += 1
511+
event.async_track_point_in_time(self._hass, self.async_retry, next_retry)
483512

484513

485-
def _wrap_service_calls(
486-
hass: HomeAssistant, sequence: list[dict], retry_params: dict[str, any]
514+
def _wrap_service_calls( # noqa: PLR0912
515+
hass: HomeAssistant, sequence: list[dict], retry_params: dict[str, Any]
487516
) -> None:
488517
"""Warp any service call with retry."""
489518
for action in sequence:
490519
action_type = cv.determine_script_action(action)
491520
match action_type:
492521
case cv.SCRIPT_ACTION_CALL_SERVICE:
493522
if action[ATTR_SERVICE] == f"{DOMAIN}.{ACTIONS_SERVICE}":
494-
raise IntegrationError("Nested retry.actions are disallowed")
523+
message = "Nested retry.actions are disallowed"
524+
raise IntegrationError(message)
495525
if action[ATTR_SERVICE] == f"{DOMAIN}.{CALL_SERVICE}":
496-
raise IntegrationError(
497-
"retry.call inside retry.actions is disallowed"
498-
)
526+
message = "retry.call inside retry.actions is disallowed"
527+
raise IntegrationError(message)
499528
action[ATTR_DATA] = action.get(ATTR_DATA, {})
500529
action[ATTR_DATA][ATTR_SERVICE] = action[ATTR_SERVICE]
501530
action[ATTR_DATA].update(retry_params)

0 commit comments

Comments
 (0)