|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | from collections import defaultdict
|
4 | 5 | from collections.abc import MutableMapping
|
5 | 6 | from enum import Enum
|
6 | 7 | from pathlib import Path
|
| 8 | +from time import sleep |
7 | 9 | from typing import TYPE_CHECKING, Any, Optional, TypedDict, TypeVar, Union
|
8 | 10 | from urllib.parse import urlencode
|
9 | 11 |
|
|
22 | 24 | )
|
23 | 25 | from .generator import InfrahubGenerator
|
24 | 26 | from .graphql import Mutation
|
| 27 | +from .queries import SCHEMA_HASH_SYNC_STATUS |
25 | 28 | from .transforms import InfrahubTransform
|
26 | 29 | from .utils import duplicates
|
27 | 30 |
|
@@ -616,15 +619,36 @@ async def all(
|
616 | 619 |
|
617 | 620 | return self.cache[branch]
|
618 | 621 |
|
619 |
| - async def load(self, schemas: list[dict], branch: Optional[str] = None) -> SchemaLoadResponse: |
| 622 | + async def load( |
| 623 | + self, schemas: list[dict], branch: Optional[str] = None, wait_until_converged: bool = False |
| 624 | + ) -> SchemaLoadResponse: |
620 | 625 | branch = branch or self.client.default_branch
|
621 | 626 | url = f"{self.client.address}/api/schema/load?branch={branch}"
|
622 | 627 | response = await self.client._post(
|
623 | 628 | url=url, timeout=max(120, self.client.default_timeout), payload={"schemas": schemas}
|
624 | 629 | )
|
625 | 630 |
|
| 631 | + if wait_until_converged: |
| 632 | + await self.wait_until_converged(branch=branch) |
| 633 | + |
626 | 634 | return self._validate_load_schema_response(response=response)
|
627 | 635 |
|
| 636 | + async def wait_until_converged(self, branch: Optional[str] = None) -> None: |
| 637 | + """Wait until the schema has converged on the selected branch or the timeout has been reached""" |
| 638 | + waited = 0 |
| 639 | + while True: |
| 640 | + status = await self.client.execute_graphql(query=SCHEMA_HASH_SYNC_STATUS, branch_name=branch) |
| 641 | + if status["InfrahubStatus"]["summary"]["schema_hash_synced"]: |
| 642 | + self.client.log.info(f"Schema successfully converged after {waited} seconds") |
| 643 | + return |
| 644 | + |
| 645 | + if waited >= self.client.config.schema_converge_timeout: |
| 646 | + self.client.log.warning(f"Schema not converged after {waited} seconds, proceeding regardless") |
| 647 | + return |
| 648 | + |
| 649 | + waited += 1 |
| 650 | + await asyncio.sleep(delay=1) |
| 651 | + |
628 | 652 | async def check(self, schemas: list[dict], branch: Optional[str] = None) -> tuple[bool, Optional[dict]]:
|
629 | 653 | branch = branch or self.client.default_branch
|
630 | 654 | url = f"{self.client.address}/api/schema/check?branch={branch}"
|
@@ -999,15 +1023,36 @@ def fetch(
|
999 | 1023 |
|
1000 | 1024 | return nodes
|
1001 | 1025 |
|
1002 |
| - def load(self, schemas: list[dict], branch: Optional[str] = None) -> SchemaLoadResponse: |
| 1026 | + def load( |
| 1027 | + self, schemas: list[dict], branch: Optional[str] = None, wait_until_converged: bool = False |
| 1028 | + ) -> SchemaLoadResponse: |
1003 | 1029 | branch = branch or self.client.default_branch
|
1004 | 1030 | url = f"{self.client.address}/api/schema/load?branch={branch}"
|
1005 | 1031 | response = self.client._post(
|
1006 | 1032 | url=url, timeout=max(120, self.client.default_timeout), payload={"schemas": schemas}
|
1007 | 1033 | )
|
1008 | 1034 |
|
| 1035 | + if wait_until_converged: |
| 1036 | + self.wait_until_converged(branch=branch) |
| 1037 | + |
1009 | 1038 | return self._validate_load_schema_response(response=response)
|
1010 | 1039 |
|
| 1040 | + def wait_until_converged(self, branch: Optional[str] = None) -> None: |
| 1041 | + """Wait until the schema has converged on the selected branch or the timeout has been reached""" |
| 1042 | + waited = 0 |
| 1043 | + while True: |
| 1044 | + status = self.client.execute_graphql(query=SCHEMA_HASH_SYNC_STATUS, branch_name=branch) |
| 1045 | + if status["InfrahubStatus"]["summary"]["schema_hash_synced"]: |
| 1046 | + self.client.log.info(f"Schema successfully converged after {waited} seconds") |
| 1047 | + return |
| 1048 | + |
| 1049 | + if waited >= self.client.config.schema_converge_timeout: |
| 1050 | + self.client.log.warning(f"Schema not converged after {waited} seconds, proceeding regardless") |
| 1051 | + return |
| 1052 | + |
| 1053 | + waited += 1 |
| 1054 | + sleep(1) |
| 1055 | + |
1011 | 1056 | def check(self, schemas: list[dict], branch: Optional[str] = None) -> tuple[bool, Optional[dict]]:
|
1012 | 1057 | branch = branch or self.client.default_branch
|
1013 | 1058 | url = f"{self.client.address}/api/schema/check?branch={branch}"
|
|
0 commit comments