diff --git a/CHANGELOG.md b/CHANGELOG.md index 6162718b..c242a18c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,29 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [0.14.0](https://github.com/opsmill/infrahub-sdk-python/tree/v0.14.0) - 2024-10-04 + +### Removed + +- Removed deprecated methods InfrahubClient.init and InfrahubClientSync.init ([#33](https://github.com/opsmill/infrahub-sdk-python/issues/33)) + +### Changed + +- Query filters are not validated locally anymore, the validation will be done on the server side instead. ([#9](https://github.com/opsmill/infrahub-sdk-python/issues/9)) +- Method client.get() can now return `None` instead of raising an exception when `raise_when_missing` is set to False + + ```python + response = await client.get( + kind="CoreRepository", name__value="infrahub-demo", raise_when_missing=False + ) + ``` ([#11](https://github.com/opsmill/infrahub-sdk-python/issues/11)) + +### Fixed + +- prefix and address attribute filters are now available in the Python SDK ([#10](https://github.com/opsmill/infrahub-sdk-python/issues/10)) +- Queries using isnull as a filter are now supported by the Python SDK ([#30](https://github.com/opsmill/infrahub-sdk-python/issues/30)) +- `execute_graphql` method for InfrahubClient(Sync) now properly considers the `default_branch` setting ([#46](https://github.com/opsmill/infrahub-sdk-python/issues/46)) + ## [0.13.1.dev0](https://github.com/opsmill/infrahub-sdk-python/tree/v0.13.1.dev0) - 2024-09-24 ### Added diff --git a/changelog/33.removed.md b/changelog/33.removed.md deleted file mode 100644 index c1b747a3..00000000 --- a/changelog/33.removed.md +++ /dev/null @@ -1 +0,0 @@ -Removed depreceted methods InfrahubClient.init and InfrahubClientSync.init diff --git a/changelog/46.fixed.md b/changelog/46.fixed.md deleted file mode 100644 index b21af46d..00000000 --- a/changelog/46.fixed.md +++ /dev/null @@ -1 +0,0 @@ -`execute_graphql` method for 
InfrahubClient(Sync) now properly considers the `default_branch` setting diff --git a/infrahub_sdk/__init__.py b/infrahub_sdk/__init__.py index 4d44fdd2..624f7ee0 100644 --- a/infrahub_sdk/__init__.py +++ b/infrahub_sdk/__init__.py @@ -10,7 +10,6 @@ from infrahub_sdk.exceptions import ( AuthenticationError, Error, - FilterNotFoundError, GraphQLError, NodeNotFoundError, ServerNotReachableError, @@ -50,7 +49,6 @@ "InfrahubNodeSync", "InfrahubRepositoryConfig", "InfrahubSchema", - "FilterNotFoundError", "generate_uuid", "GenericSchema", "GraphQLQueryAnalyzer", diff --git a/infrahub_sdk/batch.py b/infrahub_sdk/batch.py index 81348888..6cfd8f43 100644 --- a/infrahub_sdk/batch.py +++ b/infrahub_sdk/batch.py @@ -10,7 +10,7 @@ class BatchTask: task: Callable[[Any], Awaitable[Any]] args: tuple[Any, ...] kwargs: dict[str, Any] - node: Optional[InfrahubNode] = None + node: Optional[Any] = None async def execute_batch_task_in_pool( @@ -43,9 +43,7 @@ def __init__( def num_tasks(self) -> int: return len(self._tasks) - def add( - self, *args: Any, task: Callable[[Any], Awaitable[Any]], node: Optional[InfrahubNode] = None, **kwargs: Any - ) -> None: + def add(self, *args: Any, task: Callable, node: Optional[Any] = None, **kwargs: Any) -> None: self._tasks.append(BatchTask(task=task, node=node, args=args, kwargs=kwargs)) async def execute(self) -> AsyncGenerator: diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 5bdc5e28..2b2ab0b3 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -5,7 +5,19 @@ import logging from functools import wraps from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Coroutine, MutableMapping, Optional, TypedDict, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Literal, + MutableMapping, + Optional, + TypedDict, + TypeVar, + Union, + overload, +) import httpx import ujson @@ -35,6 +47,7 @@ InfrahubNodeSync, ) from infrahub_sdk.object_store import ObjectStore, ObjectStoreSync 
+from infrahub_sdk.protocols_base import CoreNode, CoreNodeSync from infrahub_sdk.queries import get_commit_update_mutation from infrahub_sdk.query_groups import InfrahubGroupContext, InfrahubGroupContextSync from infrahub_sdk.schema import InfrahubSchema, InfrahubSchemaSync, NodeSchema @@ -46,8 +59,12 @@ if TYPE_CHECKING: from types import TracebackType + # pylint: disable=redefined-builtin disable=too-many-lines +SchemaType = TypeVar("SchemaType", bound=CoreNode) +SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync) + class NodeDiff(ExtensionTypedDict): branch: str @@ -287,14 +304,33 @@ def _initialize(self) -> None: self._request_method: AsyncRequester = self.config.requester or self._default_request_method self.group_context = InfrahubGroupContext(self) + @overload async def create( self, kind: str, + data: Optional[dict] = ..., + branch: Optional[str] = ..., + **kwargs: Any, + ) -> InfrahubNode: ... + + @overload + async def create( + self, + kind: type[SchemaType], + data: Optional[dict] = ..., + branch: Optional[str] = ..., + **kwargs: Any, + ) -> SchemaType: ... 
+ + async def create( + self, + kind: Union[str, type[SchemaType]], data: Optional[dict] = None, branch: Optional[str] = None, **kwargs: Any, - ) -> InfrahubNode: + ) -> Union[InfrahubNode, SchemaType]: branch = branch or self.default_branch + schema = await self.schema.get(kind=kind, branch=branch) if not data and not kwargs: @@ -302,16 +338,119 @@ async def create( return InfrahubNode(client=self, schema=schema, branch=branch, data=data or kwargs) - async def delete(self, kind: str, id: str, branch: Optional[str] = None) -> None: + async def delete(self, kind: Union[str, type[SchemaType]], id: str, branch: Optional[str] = None) -> None: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) node = InfrahubNode(client=self, schema=schema, branch=branch, data={"id": id}) await node.delete() + @overload + async def get( + self, + kind: type[SchemaType], + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[SchemaType]: ... + + @overload + async def get( + self, + kind: type[SchemaType], + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaType: ... 
+ + @overload + async def get( + self, + kind: type[SchemaType], + raise_when_missing: bool = ..., + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaType: ... + + @overload + async def get( + self, + kind: str, + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[InfrahubNode]: ... + + @overload + async def get( + self, + kind: str, + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNode: ... + + @overload async def get( self, kind: str, + raise_when_missing: bool = ..., + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNode: ... 
+ + async def get( + self, + kind: Union[str, type[SchemaType]], + raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, id: Optional[str] = None, @@ -322,7 +461,7 @@ async def get( fragment: bool = False, prefetch_relationships: bool = False, **kwargs: Any, - ) -> InfrahubNode: + ) -> Union[InfrahubNode, SchemaType, None]: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -355,8 +494,10 @@ async def get( **filters, ) - if len(results) == 0: - raise NodeNotFoundError(branch_name=branch, node_type=kind, identifier=filters) + if len(results) == 0 and raise_when_missing: + raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters) + if len(results) == 0 and not raise_when_missing: + return None if len(results) > 1: raise IndexError("More than 1 node returned") @@ -391,9 +532,39 @@ async def _process_nodes_and_relationships( return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes) + @overload + async def all( + self, + kind: type[SchemaType], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + ) -> list[SchemaType]: ... + + @overload async def all( self, kind: str, + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + ) -> list[InfrahubNode]: ... 
+ + async def all( + self, + kind: Union[str, type[SchemaType]], at: Optional[Timestamp] = None, branch: Optional[str] = None, populate_store: bool = False, @@ -403,7 +574,7 @@ async def all( exclude: Optional[list[str]] = None, fragment: bool = False, prefetch_relationships: bool = False, - ) -> list[InfrahubNode]: + ) -> Union[list[InfrahubNode], list[SchemaType]]: """Retrieve all nodes of a given kind Args: @@ -434,9 +605,43 @@ async def all( prefetch_relationships=prefetch_relationships, ) + @overload + async def filters( + self, + kind: type[SchemaType], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + **kwargs: Any, + ) -> list[SchemaType]: ... + + @overload async def filters( self, kind: str, + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + **kwargs: Any, + ) -> list[InfrahubNode]: ... + + async def filters( + self, + kind: Union[str, type[SchemaType]], at: Optional[Timestamp] = None, branch: Optional[str] = None, populate_store: bool = False, @@ -448,7 +653,7 @@ async def filters( prefetch_relationships: bool = False, partial_match: bool = False, **kwargs: Any, - ) -> list[InfrahubNode]: + ) -> Union[list[InfrahubNode], list[SchemaType]]: """Retrieve nodes of a given kind based on provided filters. 
Args: @@ -477,9 +682,6 @@ async def filters( node = InfrahubNode(client=self, schema=schema, branch=branch) filters = kwargs - if filters: - node.validate_filters(filters=filters) - nodes: list[InfrahubNode] = [] related_nodes: list[InfrahubNode] = [] @@ -845,9 +1047,100 @@ async def get_diff_summary( ) return response["DiffSummary"] + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> SchemaType: ... + + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[SchemaType]: ... + + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> SchemaType: ... 
+ + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> CoreNode: ... + + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[CoreNode]: ... + + @overload + async def allocate_next_ip_address( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> Optional[CoreNode]: ... + async def allocate_next_ip_address( self, - resource_pool: InfrahubNode, + resource_pool: CoreNode, + kind: Optional[type[SchemaType]] = None, # pylint: disable=unused-argument identifier: Optional[str] = None, prefix_length: Optional[int] = None, address_type: Optional[str] = None, @@ -856,7 +1149,7 @@ async def allocate_next_ip_address( timeout: Optional[int] = None, tracker: Optional[str] = None, raise_for_error: bool = True, - ) -> Optional[InfrahubNode]: + ) -> Optional[Union[CoreNode, SchemaType]]: """Allocate a new IP address by using the provided resource pool. 
Args: @@ -898,9 +1191,106 @@ async def allocate_next_ip_address( return await self.get(kind=resource_details["kind"], id=resource_details["id"], branch=branch) return None + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> SchemaType: ... + + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[SchemaType]: ... + + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: type[SchemaType], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> SchemaType: ... + + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> CoreNode: ... 
+ + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[CoreNode]: ... + + @overload + async def allocate_next_ip_prefix( + self, + resource_pool: CoreNode, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> Optional[CoreNode]: ... + async def allocate_next_ip_prefix( self, - resource_pool: InfrahubNode, + resource_pool: CoreNode, + kind: Optional[type[SchemaType]] = None, # pylint: disable=unused-argument identifier: Optional[str] = None, prefix_length: Optional[int] = None, member_type: Optional[str] = None, @@ -910,20 +1300,20 @@ async def allocate_next_ip_prefix( timeout: Optional[int] = None, tracker: Optional[str] = None, raise_for_error: bool = True, - ) -> Optional[InfrahubNode]: + ) -> Optional[Union[CoreNode, SchemaType]]: """Allocate a new IP prefix by using the provided resource pool. Args: - resource_pool (InfrahubNode): Node corresponding to the pool to allocate resources from. - identifier (str, optional): Value to perform idempotent allocation, the same resource will be returned for a given identifier. - prefix_length (int, optional): Length of the prefix to allocate. - member_type (str, optional): Member type of the prefix to allocate. - prefix_type (str, optional): Kind of the prefix to allocate. 
- data (dict, optional): A key/value map to use to set attributes values on the allocated prefix. - branch (str, optional): Name of the branch to allocate from. Defaults to default_branch. - timeout (int, optional): Flag to indicate whether to populate the store with the retrieved nodes. - tracker (str, optional): The offset for pagination. - raise_for_error (bool, optional): The limit for pagination. + resource_pool: Node corresponding to the pool to allocate resources from. + identifier: Value to perform idempotent allocation, the same resource will be returned for a given identifier. + prefix_length: Length of the prefix to allocate. + member_type: Member type of the prefix to allocate. + prefix_type: Kind of the prefix to allocate. + data: A key/value map to use to set attribute values on the allocated prefix. + branch: Name of the branch to allocate from. Defaults to default_branch. + timeout: Timeout in seconds for the request. + tracker: The tracker identifier to apply to the request. + raise_for_error: Whether to raise an exception if the allocation request fails. Returns: InfrahubNode: Node corresponding to the allocated resource. """ @@ -1030,13 +1420,31 @@ def _initialize(self) -> None: self._request_method: SyncRequester = self.config.sync_requester or self._default_request_method self.group_context = InfrahubGroupContextSync(self) + @overload def create( self, kind: str, + data: Optional[dict] = ..., + branch: Optional[str] = ..., + **kwargs: Any, + ) -> InfrahubNodeSync: ... + + @overload + def create( + self, + kind: type[SchemaTypeSync], + data: Optional[dict] = ..., + branch: Optional[str] = ..., + **kwargs: Any, + ) -> SchemaTypeSync: ... 
+ + def create( + self, + kind: Union[str, type[SchemaTypeSync]], data: Optional[dict] = None, branch: Optional[str] = None, **kwargs: Any, - ) -> InfrahubNodeSync: + ) -> Union[InfrahubNodeSync, SchemaTypeSync]: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -1045,7 +1453,7 @@ def create( return InfrahubNodeSync(client=self, schema=schema, branch=branch, data=data or kwargs) - def delete(self, kind: str, id: str, branch: Optional[str] = None) -> None: + def delete(self, kind: Union[str, type[SchemaTypeSync]], id: str, branch: Optional[str] = None) -> None: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -1138,9 +1546,39 @@ def execute_graphql( # TODO add a special method to execute mutation that will check if the method returned OK + @overload + def all( + self, + kind: type[SchemaTypeSync], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + ) -> list[SchemaTypeSync]: ... + + @overload def all( self, kind: str, + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + ) -> list[InfrahubNodeSync]: ... 
+ + def all( + self, + kind: Union[str, type[SchemaTypeSync]], at: Optional[Timestamp] = None, branch: Optional[str] = None, populate_store: bool = False, @@ -1150,7 +1588,7 @@ def all( exclude: Optional[list[str]] = None, fragment: bool = False, prefetch_relationships: bool = False, - ) -> list[InfrahubNodeSync]: + ) -> Union[list[InfrahubNodeSync], list[SchemaTypeSync]]: """Retrieve all nodes of a given kind Args: @@ -1210,9 +1648,43 @@ def _process_nodes_and_relationships( return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes) + @overload + def filters( + self, + kind: type[SchemaTypeSync], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + **kwargs: Any, + ) -> list[SchemaTypeSync]: ... + + @overload def filters( self, kind: str, + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + populate_store: bool = ..., + offset: Optional[int] = ..., + limit: Optional[int] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + **kwargs: Any, + ) -> list[InfrahubNodeSync]: ... + + def filters( + self, + kind: Union[str, type[SchemaTypeSync]], at: Optional[Timestamp] = None, branch: Optional[str] = None, populate_store: bool = False, @@ -1224,7 +1696,7 @@ def filters( prefetch_relationships: bool = False, partial_match: bool = False, **kwargs: Any, - ) -> list[InfrahubNodeSync]: + ) -> Union[list[InfrahubNodeSync], list[SchemaTypeSync]]: """Retrieve nodes of a given kind based on provided filters. 
Args: @@ -1253,9 +1725,6 @@ def filters( node = InfrahubNodeSync(client=self, schema=schema, branch=branch) filters = kwargs - if filters: - node.validate_filters(filters=filters) - nodes: list[InfrahubNodeSync] = [] related_nodes: list[InfrahubNodeSync] = [] @@ -1306,9 +1775,112 @@ def filters( return nodes + @overload + def get( + self, + kind: type[SchemaTypeSync], + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[SchemaTypeSync]: ... + + @overload + def get( + self, + kind: type[SchemaTypeSync], + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaTypeSync: ... + + @overload + def get( + self, + kind: type[SchemaTypeSync], + raise_when_missing: bool = ..., + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaTypeSync: ... 
+ + @overload def get( self, kind: str, + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[InfrahubNodeSync]: ... + + @overload + def get( + self, + kind: str, + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNodeSync: ... + + @overload + def get( + self, + kind: str, + raise_when_missing: bool = ..., + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNodeSync: ... 
+ + def get( + self, + kind: Union[str, type[SchemaTypeSync]], + raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, id: Optional[str] = None, @@ -1319,7 +1891,7 @@ def get( fragment: bool = False, prefetch_relationships: bool = False, **kwargs: Any, - ) -> InfrahubNodeSync: + ) -> Union[InfrahubNodeSync, SchemaTypeSync, None]: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -1352,8 +1924,10 @@ def get( **filters, ) - if len(results) == 0: - raise NodeNotFoundError(branch_name=branch, node_type=kind, identifier=filters) + if len(results) == 0 and raise_when_missing: + raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters) + if len(results) == 0 and not raise_when_missing: + return None if len(results) > 1: raise IndexError("More than 1 node returned") @@ -1465,9 +2039,100 @@ def get_diff_summary( ) return response["DiffSummary"] + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> SchemaTypeSync: ... + + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[SchemaTypeSync]: ... 
+ + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> SchemaTypeSync: ... + + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> CoreNodeSync: ... + + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[CoreNodeSync]: ... + + @overload + def allocate_next_ip_address( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + address_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> Optional[CoreNodeSync]: ... 
+ def allocate_next_ip_address( self, - resource_pool: InfrahubNodeSync, + resource_pool: CoreNodeSync, + kind: Optional[type[SchemaTypeSync]] = None, # pylint: disable=unused-argument identifier: Optional[str] = None, prefix_length: Optional[int] = None, address_type: Optional[str] = None, @@ -1476,7 +2141,7 @@ def allocate_next_ip_address( timeout: Optional[int] = None, tracker: Optional[str] = None, raise_for_error: bool = True, - ) -> Optional[InfrahubNodeSync]: + ) -> Optional[Union[CoreNodeSync, SchemaTypeSync]]: """Allocate a new IP address by using the provided resource pool. Args: @@ -1514,9 +2179,106 @@ def allocate_next_ip_address( return self.get(kind=resource_details["kind"], id=resource_details["id"], branch=branch) return None + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> SchemaTypeSync: ... + + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[SchemaTypeSync]: ... 
+ + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: type[SchemaTypeSync], + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> SchemaTypeSync: ... + + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[True] = True, + ) -> CoreNodeSync: ... + + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: Literal[False] = False, + ) -> Optional[CoreNodeSync]: ... + + @overload + def allocate_next_ip_prefix( + self, + resource_pool: CoreNodeSync, + kind: Literal[None] = ..., + identifier: Optional[str] = ..., + prefix_length: Optional[int] = ..., + member_type: Optional[str] = ..., + prefix_type: Optional[str] = ..., + data: Optional[dict[str, Any]] = ..., + branch: Optional[str] = ..., + timeout: Optional[int] = ..., + tracker: Optional[str] = ..., + raise_for_error: bool = ..., + ) -> Optional[CoreNodeSync]: ... 
+ def allocate_next_ip_prefix( self, - resource_pool: InfrahubNodeSync, + resource_pool: CoreNodeSync, + kind: Optional[type[SchemaTypeSync]] = None, # pylint: disable=unused-argument identifier: Optional[str] = None, prefix_length: Optional[int] = None, member_type: Optional[str] = None, @@ -1526,7 +2288,7 @@ def allocate_next_ip_prefix( timeout: Optional[int] = None, tracker: Optional[str] = None, raise_for_error: bool = True, - ) -> Optional[InfrahubNodeSync]: + ) -> Optional[Union[CoreNodeSync, SchemaTypeSync]]: """Allocate a new IP prefix by using the provided resource pool. Args: diff --git a/infrahub_sdk/code_generator.py b/infrahub_sdk/code_generator.py new file mode 100644 index 00000000..0bfb4e34 --- /dev/null +++ b/infrahub_sdk/code_generator.py @@ -0,0 +1,123 @@ +from typing import Any, Mapping, Optional + +import jinja2 + +from infrahub_sdk import protocols as sdk_protocols +from infrahub_sdk.ctl.constants import PROTOCOLS_TEMPLATE +from infrahub_sdk.schema import ( + AttributeSchema, + GenericSchema, + MainSchemaTypes, + NodeSchema, + ProfileSchema, + RelationshipSchema, +) + + +class CodeGenerator: + def __init__(self, schema: dict[str, MainSchemaTypes]): + self.generics: dict[str, GenericSchema] = {} + self.nodes: dict[str, NodeSchema] = {} + self.profiles: dict[str, ProfileSchema] = {} + + for name, schema_type in schema.items(): + if isinstance(schema_type, GenericSchema): + self.generics[name] = schema_type + if isinstance(schema_type, NodeSchema): + self.nodes[name] = schema_type + if isinstance(schema_type, ProfileSchema): + self.profiles[name] = schema_type + + self.base_protocols = [ + e + for e in dir(sdk_protocols) + if not e.startswith("__") + and not e.endswith("__") + and e + not in ("TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable") + ] + + self.sorted_generics = self._sort_and_filter_models(self.generics, filters=["CoreNode"] + self.base_protocols) + self.sorted_nodes = 
self._sort_and_filter_models(self.nodes, filters=["CoreNode"] + self.base_protocols) + self.sorted_profiles = self._sort_and_filter_models( + self.profiles, filters=["CoreProfile"] + self.base_protocols + ) + + def render(self, sync: bool = True) -> str: + jinja2_env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True) + jinja2_env.filters["inheritance"] = self._jinja2_filter_inheritance + jinja2_env.filters["render_attribute"] = self._jinja2_filter_render_attribute + jinja2_env.filters["render_relationship"] = self._jinja2_filter_render_relationship + + template = jinja2_env.from_string(PROTOCOLS_TEMPLATE) + return template.render( + generics=self.sorted_generics, + nodes=self.sorted_nodes, + profiles=self.sorted_profiles, + base_protocols=self.base_protocols, + sync=sync, + ) + + @staticmethod + def _jinja2_filter_inheritance(value: dict[str, Any]) -> str: + inherit_from: list[str] = value.get("inherit_from", []) + + if not inherit_from: + return "CoreNode" + return ", ".join(inherit_from) + + @staticmethod + def _jinja2_filter_render_attribute(value: AttributeSchema) -> str: + attribute_kind_map = { + "boolean": "Boolean", + "datetime": "DateTime", + "dropdown": "Dropdown", + "hashedpassword": "HashedPassword", + "iphost": "IPHost", + "ipnetwork": "IPNetwork", + "json": "JSONAttribute", + "list": "ListAttribute", + "number": "Integer", + "password": "String", + "text": "String", + "textarea": "String", + "url": "URL", + } + + name = value.name + kind = value.kind + + attribute_kind = attribute_kind_map[kind.lower()] + if value.optional: + attribute_kind = f"{attribute_kind}Optional" + + return f"{name}: {attribute_kind}" + + @staticmethod + def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = False) -> str: + name = value.name + cardinality = value.cardinality + + type_ = "RelatedNode" + if cardinality == "many": + type_ = "RelationshipManager" + + if sync: + type_ += "Sync" + + return f"{name}: 
{type_}" + + @staticmethod + def _sort_and_filter_models( + models: Mapping[str, MainSchemaTypes], filters: Optional[list[str]] = None + ) -> list[MainSchemaTypes]: + if filters is None: + filters = ["CoreNode"] + + filtered: list[MainSchemaTypes] = [] + for name, model in models.items(): + if name in filters: + continue + filtered.append(model) + + return sorted(filtered, key=lambda k: k.name) diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 8779b0f9..498e7b0d 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -4,7 +4,7 @@ import logging import sys from pathlib import Path -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional import jinja2 import typer @@ -14,25 +14,29 @@ from rich.traceback import Traceback from infrahub_sdk import __version__ as sdk_version -from infrahub_sdk import protocols as sdk_protocols from infrahub_sdk.async_typer import AsyncTyper +from infrahub_sdk.code_generator import CodeGenerator from infrahub_sdk.ctl import config from infrahub_sdk.ctl.branch import app as branch_app from infrahub_sdk.ctl.check import run as run_check from infrahub_sdk.ctl.client import initialize_client, initialize_client_sync -from infrahub_sdk.ctl.constants import PROTOCOLS_TEMPLATE from infrahub_sdk.ctl.exceptions import QueryNotFoundError from infrahub_sdk.ctl.generator import run as run_generator from infrahub_sdk.ctl.render import list_jinja2_transforms from infrahub_sdk.ctl.repository import app as repository_app from infrahub_sdk.ctl.repository import get_repository_config -from infrahub_sdk.ctl.schema import app as schema +from infrahub_sdk.ctl.schema import app as schema_app +from infrahub_sdk.ctl.schema import load_schemas_from_disk_and_exit from infrahub_sdk.ctl.transform import list_transforms from infrahub_sdk.ctl.utils import catch_exception, execute_graphql_query, parse_cli_vars from infrahub_sdk.ctl.validate import app as 
validate_app from infrahub_sdk.exceptions import GraphQLError, InfrahubTransformNotFoundError from infrahub_sdk.jinja2 import identify_faulty_jinja_code -from infrahub_sdk.schema import AttributeSchema, GenericSchema, InfrahubRepositoryConfig, NodeSchema, RelationshipSchema +from infrahub_sdk.schema import ( + InfrahubRepositoryConfig, + MainSchemaTypes, + SchemaRoot, +) from infrahub_sdk.transforms import get_transform_class_instance from infrahub_sdk.utils import get_branch, write_to_file @@ -43,7 +47,7 @@ app = AsyncTyper(pretty_exceptions_show_locals=False) app.add_typer(branch_app, name="branch") -app.add_typer(schema, name="schema") +app.add_typer(schema_app, name="schema") app.add_typer(validate_app, name="validate") app.add_typer(repository_app, name="repository") app.command(name="dump")(dump) @@ -323,114 +327,36 @@ def transform( @app.command(name="protocols") @catch_exception(console=console) def protocols( # noqa: PLR0915 + schemas: list[Path] = typer.Option(None, help="List of schemas or directory to load."), branch: str = typer.Option(None, help="Branch of schema to export Python protocols for."), + sync: bool = typer.Option(False, help="Generate for sync or async."), _: str = CONFIG_PARAM, out: str = typer.Option("schema_protocols.py", help="Path to a file to save the result."), ) -> None: """Export Python protocols corresponding to a schema.""" - def _jinja2_filter_inheritance(value: dict[str, Any]) -> str: - inherit_from: list[str] = value.get("inherit_from", []) - - if not inherit_from: - return "CoreNode" - return ", ".join(inherit_from) - - def _jinja2_filter_render_attribute(value: AttributeSchema) -> str: - attribute_kind_map = { - "boolean": "bool", - "datetime": "datetime", - "dropdown": "str", - "hashedpassword": "str", - "iphost": "str", - "ipnetwork": "str", - "json": "dict", - "list": "list", - "number": "int", - "password": "str", - "text": "str", - "textarea": "str", - "url": "str", - } - - name = value.name - kind = value.kind - - 
attribute_kind = attribute_kind_map[kind.lower()] - if value.optional: - attribute_kind = f"Optional[{attribute_kind}]" - - return f"{name}: {attribute_kind}" - - def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = False) -> str: - name = value.name - cardinality = value.cardinality - - type_ = "RelatedNode" - if cardinality == "many": - type_ = "RelationshipManager" - - if sync: - type_ += "Sync" - - return f"{name}: {type_}" - - def _sort_and_filter_models( - models: dict[str, Union[GenericSchema, NodeSchema]], filters: Optional[list[str]] = None - ) -> list[Union[GenericSchema, NodeSchema]]: - if filters is None: - filters = ["CoreNode"] - - filtered: list[Union[GenericSchema, NodeSchema]] = [] - for name, model in models.items(): - if name in filters: - continue - filtered.append(model) - - return sorted(filtered, key=lambda k: k.name) + schema: dict[str, MainSchemaTypes] = {} - client = initialize_client_sync() - current_schema = client.schema.all(branch=branch) - - generics: dict[str, GenericSchema] = {} - nodes: dict[str, NodeSchema] = {} - - for name, schema_type in current_schema.items(): - if isinstance(schema_type, GenericSchema): - generics[name] = schema_type - if isinstance(schema_type, NodeSchema): - nodes[name] = schema_type - - base_protocols = [ - e - for e in dir(sdk_protocols) - if not e.startswith("__") - and not e.endswith("__") - and e not in ("TYPE_CHECKING", "CoreNode", "Optional", "Protocol", "Union", "annotations", "runtime_checkable") - ] - sorted_generics = _sort_and_filter_models(generics, filters=["CoreNode"] + base_protocols) - sorted_nodes = _sort_and_filter_models(nodes, filters=["CoreNode"] + base_protocols) - - jinja2_env = jinja2.Environment(loader=jinja2.BaseLoader, trim_blocks=True, lstrip_blocks=True) - jinja2_env.filters["inheritance"] = _jinja2_filter_inheritance - jinja2_env.filters["render_attribute"] = _jinja2_filter_render_attribute - jinja2_env.filters["render_relationship"] = 
_jinja2_filter_render_relationship - - template = jinja2_env.from_string(PROTOCOLS_TEMPLATE) - rendered = template.render(generics=sorted_generics, nodes=sorted_nodes, base_protocols=base_protocols, sync=False) - rendered_sync = template.render( - generics=sorted_generics, nodes=sorted_nodes, base_protocols=base_protocols, sync=True - ) - output_file = Path(out) - output_file_sync = Path(output_file.stem + "_sync" + output_file.suffix) + if schemas: + schemas_data = load_schemas_from_disk_and_exit(schemas=schemas) + + for data in schemas_data: + data.load_content() + schema_root = SchemaRoot(**data.content) + schema.update({item.kind: item for item in schema_root.nodes + schema_root.generics}) + + else: + client = initialize_client_sync() + schema.update(client.schema.fetch(branch=branch)) + + code_generator = CodeGenerator(schema=schema) if out: - write_to_file(output_file, rendered) - write_to_file(output_file_sync, rendered_sync) - console.print(f"Python protocols exported in {output_file} and {output_file_sync}") + output_file = Path(out) + write_to_file(output_file, code_generator.render(sync=sync)) + console.print(f"Python protocols exported in {output_file}") else: - console.print(rendered) - console.print(rendered_sync) + console.print(code_generator.render(sync=sync)) @app.command(name="version") diff --git a/infrahub_sdk/ctl/constants.py b/infrahub_sdk/ctl/constants.py index 24887197..5c82c8a4 100644 --- a/infrahub_sdk/ctl/constants.py +++ b/infrahub_sdk/ctl/constants.py @@ -14,6 +14,30 @@ {% else %} from infrahub_sdk.node import RelatedNode, RelationshipManager {% endif %} + from infrahub_sdk.protocols_base import ( + String, + StringOptional, + Integer, + IntegerOptional, + Boolean, + BooleanOptional, + DateTime, + DateTimeOptional, + Dropdown, + DropdownOptional, + HashedPassword, + HashedPasswordOptional, + IPHost, + IPHostOptional, + IPNetwork, + IPNetworkOptional, + JSONAttribute, + JSONAttributeOptional, + ListAttribute, + ListAttributeOptional, + 
URL, + URLOptional, + ) {% for generic in generics %} @@ -36,6 +60,7 @@ class {{ generic.namespace + generic.name }}(CoreNode): children: RelationshipManager {% endif %} {% endif %} + {% endfor %} @@ -59,5 +84,29 @@ class {{ node.namespace + node.name }}({{ node.inherit_from | join(", ") or "Cor children: RelationshipManager {% endif %} {% endif %} + +{% endfor %} + +{% for node in profiles %} +class {{ node.namespace + node.name }}({{ node.inherit_from | join(", ") or "CoreNode" }}): + {% if not node.attributes|default([]) and not node.relationships|default([]) %} + pass + {% endif %} + {% for attribute in node.attributes|default([]) %} + {{ attribute | render_attribute }} + {% endfor %} + {% for relationship in node.relationships|default([]) %} + {{ relationship | render_relationship(sync) }} + {% endfor %} + {% if node.hierarchical | default(false) %} + {% if sync %} + parent: RelatedNodeSync + children: RelationshipManagerSync + {% else %} + parent: RelatedNode + children: RelationshipManager + {% endif %} + {% endif %} + {% endfor %} """ diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py index f6f2cda3..57ecc379 100644 --- a/infrahub_sdk/ctl/schema.py +++ b/infrahub_sdk/ctl/schema.py @@ -49,7 +49,7 @@ def load_schemas_from_disk(schemas: list[Path]) -> list[SchemaFile]: return schemas_data -def load_schemas_from_disk_and_exit(schemas: list[Path]): +def load_schemas_from_disk_and_exit(schemas: list[Path]) -> list[SchemaFile]: has_error = False try: schemas_data = load_schemas_from_disk(schemas=schemas) diff --git a/infrahub_sdk/ctl/utils.py b/infrahub_sdk/ctl/utils.py index 032a4d2b..73eb3e6d 100644 --- a/infrahub_sdk/ctl/utils.py +++ b/infrahub_sdk/ctl/utils.py @@ -18,7 +18,6 @@ from infrahub_sdk.exceptions import ( AuthenticationError, Error, - FilterNotFoundError, GraphQLError, NodeNotFoundError, SchemaNotFoundError, @@ -57,7 +56,7 @@ def handle_exception(exc: Exception, console: Console, exit_code: int): if isinstance(exc, GraphQLError): 
print_graphql_errors(console=console, errors=exc.errors) raise typer.Exit(code=exit_code) - if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError, FilterNotFoundError)): + if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError)): console.print(f"[red]Error: {str(exc)}") raise typer.Exit(code=exit_code) diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py index b9628beb..1aa92e58 100644 --- a/infrahub_sdk/exceptions.py +++ b/infrahub_sdk/exceptions.py @@ -87,15 +87,6 @@ def __str__(self) -> str: """ -class FilterNotFoundError(Error): - def __init__(self, identifier: str, kind: str, message: Optional[str] = None, filters: Optional[list[str]] = None): - self.identifier = identifier - self.kind = kind - self.filters = filters or [] - self.message = message or f"{identifier!r} is not a valid filter for {self.kind!r} ({', '.join(self.filters)})." - super().__init__(self.message) - - class InfrahubCheckNotFoundError(Error): def __init__(self, name: str, message: Optional[str] = None): self.message = message or f"The requested InfrahubCheck '{name}' was not found." 
diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index b75ee77e..c3691f46 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -9,7 +9,6 @@ from infrahub_sdk.exceptions import ( Error, FeatureNotSupportedError, - FilterNotFoundError, NodeNotFoundError, UninitializedError, ) @@ -986,26 +985,6 @@ def generate_query_data_init( return data - def validate_filters(self, filters: Optional[dict[str, Any]] = None) -> bool: - if not filters: - return True - - for filter_name in filters.keys(): - found = False - for filter_schema in self._schema.filters: - if filter_name == filter_schema.name: - found = True - break - if not found: - valid_filters = [entry.name for entry in self._schema.filters] - raise FilterNotFoundError( - identifier=filter_name, - kind=self._schema.kind, - filters=valid_filters, - ) - - return True - def extract(self, params: dict[str, str]) -> dict[str, Any]: """Extract some datapoints defined in a flat notation.""" result: dict[str, Any] = {} diff --git a/infrahub_sdk/protocols.py b/infrahub_sdk/protocols.py index 9e66873f..048ec982 100644 --- a/infrahub_sdk/protocols.py +++ b/infrahub_sdk/protocols.py @@ -72,6 +72,10 @@ class CoreArtifactTarget(CoreNode): artifacts: RelationshipManager +class CoreBasePermission(CoreNode): + roles: RelationshipManager + + class CoreCheck(CoreNode): name: StringOptional label: StringOptional @@ -200,6 +204,16 @@ class CoreAccount(LineageOwner, LineageSource, CoreGenericAccount): pass +class CoreAccountGroup(CoreGroup): + roles: RelationshipManager + + +class CoreAccountRole(CoreNode): + name: String + groups: RelationshipManager + permissions: RelationshipManager + + class CoreArtifact(CoreTaskTarget): name: String status: Enum @@ -317,6 +331,11 @@ class CoreGeneratorValidator(CoreValidator): definition: RelatedNode +class CoreGlobalPermission(CoreBasePermission): + name: String + action: Dropdown + + class CoreGraphQLQuery(CoreNode): name: String description: StringOptional @@ -357,6 +376,14 @@ class 
CoreNumberPool(CoreResourcePool, LineageSource): end_range: Integer +class CoreObjectPermission(CoreBasePermission): + branch: String + namespace: String + name: String + action: Enum + decision: Enum + + class CoreObjectThread(CoreThread): object_path: String @@ -490,6 +517,10 @@ class CoreArtifactTargetSync(CoreNodeSync): artifacts: RelationshipManagerSync +class CoreBasePermissionSync(CoreNodeSync): + roles: RelationshipManagerSync + + class CoreCheckSync(CoreNodeSync): name: StringOptional label: StringOptional @@ -618,6 +649,16 @@ class CoreAccountSync(LineageOwnerSync, LineageSourceSync, CoreGenericAccountSyn pass +class CoreAccountGroupSync(CoreGroupSync): + roles: RelationshipManagerSync + + +class CoreAccountRoleSync(CoreNodeSync): + name: String + groups: RelationshipManagerSync + permissions: RelationshipManagerSync + + class CoreArtifactSync(CoreTaskTargetSync): name: String status: Enum @@ -735,6 +776,11 @@ class CoreGeneratorValidatorSync(CoreValidatorSync): definition: RelatedNodeSync +class CoreGlobalPermissionSync(CoreBasePermissionSync): + name: String + action: Dropdown + + class CoreGraphQLQuerySync(CoreNodeSync): name: String description: StringOptional @@ -775,6 +821,14 @@ class CoreNumberPoolSync(CoreResourcePoolSync, LineageSourceSync): end_range: Integer +class CoreObjectPermissionSync(CoreBasePermissionSync): + branch: String + namespace: String + name: String + action: Enum + decision: Enum + + class CoreObjectThreadSync(CoreThreadSync): object_path: String diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index 44f07927..e2520d01 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -1,11 +1,18 @@ from __future__ import annotations -from typing import Any, Optional, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable +if TYPE_CHECKING: + import ipaddress + from infrahub_sdk.schema import MainSchemaTypes + + 
+@runtime_checkable class RelatedNode(Protocol): ... +@runtime_checkable class RelatedNodeSync(Protocol): ... @@ -79,19 +86,19 @@ class DropdownOptional(Attribute): class IPNetwork(Attribute): - value: str + value: Union[ipaddress.IPv4Network, ipaddress.IPv6Network] class IPNetworkOptional(Attribute): - value: Optional[str] + value: Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] class IPHost(Attribute): - value: str + value: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] class IPHostOptional(Attribute): - value: Optional[str] + value: Optional[Union[ipaddress.IPv4Address, ipaddress.IPv6Address]] class HashedPassword(Attribute): @@ -118,29 +125,58 @@ class ListAttributeOptional(Attribute): value: Optional[list[Any]] +@runtime_checkable class CoreNodeBase(Protocol): + _schema: MainSchemaTypes id: str display_label: Optional[str] - hfid: Optional[list[str]] - hfid_str: Optional[str] + @property + def hfid(self) -> Optional[list[str]]: ... + + @property + def hfid_str(self) -> Optional[str]: ... + + def get_human_friendly_id_as_string(self, include_kind: bool = False) -> Optional[str]: ... -class CoreNode(CoreNodeBase, Protocol): def get_kind(self) -> str: ... - async def save(self) -> None: ... + def is_ip_prefix(self) -> bool: ... + + def is_ip_address(self) -> bool: ... + + def is_resource_pool(self) -> bool: ... + + def get_raw_graphql_data(self) -> Optional[dict]: ... + + def extract(self, params: dict[str, str]) -> dict[str, Any]: ... + + +@runtime_checkable +class CoreNode(CoreNodeBase, Protocol): + async def save(self, allow_upsert: bool = False, update_group_context: Optional[bool] = None) -> None: ... + + async def delete(self) -> None: ... async def update(self, do_full_update: bool) -> None: ... + async def create(self, allow_upsert: bool = False) -> None: ... 
-class CoreNodeSync(CoreNodeBase, Protocol): - id: str - display_label: Optional[str] - hfid: Optional[list[str]] - hfid_str: Optional[str] + async def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... - def get_kind(self) -> str: ... + async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... + + +@runtime_checkable +class CoreNodeSync(CoreNodeBase, Protocol): + def save(self, allow_upsert: bool = False, update_group_context: Optional[bool] = None) -> None: ... - def save(self) -> None: ... + def delete(self) -> None: ... def update(self, do_full_update: bool) -> None: ... + + def create(self, allow_upsert: bool = False) -> None: ... + + def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... + + def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema.py index 58dd06be..d7d4fc15 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema.py @@ -17,7 +17,7 @@ from infrahub_sdk.utils import duplicates if TYPE_CHECKING: - from infrahub_sdk.client import InfrahubClient, InfrahubClientSync + from infrahub_sdk.client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] @@ -511,30 +511,47 @@ def _validate_load_schema_response(response: httpx.Response) -> SchemaLoadRespon raise InvalidResponseError(message=f"Invalid response received from server HTTP {response.status_code}") + @staticmethod + def _get_schema_name(schema: Union[type[Union[SchemaType, SchemaTypeSync]], str]) -> str: + if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] + return schema.__name__ # type: ignore[union-attr] + + if isinstance(schema, str): + return schema + + raise ValueError("schema must be a 
protocol or a string") + class InfrahubSchema(InfrahubSchemaBase): def __init__(self, client: InfrahubClient): self.client = client self.cache: dict = defaultdict(lambda: dict) - async def get(self, kind: str, branch: Optional[str] = None, refresh: bool = False) -> MainSchemaTypes: + async def get( + self, + kind: Union[type[Union[SchemaType, SchemaTypeSync]], str], + branch: Optional[str] = None, + refresh: bool = False, + ) -> MainSchemaTypes: branch = branch or self.client.default_branch + kind_str = self._get_schema_name(schema=kind) + if refresh: self.cache[branch] = await self.fetch(branch=branch) - if branch in self.cache and kind in self.cache[branch]: - return self.cache[branch][kind] + if branch in self.cache and kind_str in self.cache[branch]: + return self.cache[branch][kind_str] # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: self.cache[branch] = await self.fetch(branch=branch) - if branch in self.cache and kind in self.cache[branch]: - return self.cache[branch][kind] + if branch in self.cache and kind_str in self.cache[branch]: + return self.cache[branch][kind_str] - raise SchemaNotFoundError(identifier=kind) + raise SchemaNotFoundError(identifier=kind_str) async def all( self, branch: Optional[str] = None, refresh: bool = False, namespaces: Optional[list[str]] = None @@ -760,24 +777,31 @@ def all( return self.cache[branch] - def get(self, kind: str, branch: Optional[str] = None, refresh: bool = False) -> MainSchemaTypes: + def get( + self, + kind: Union[type[Union[SchemaType, SchemaTypeSync]], str], + branch: Optional[str] = None, + refresh: bool = False, + ) -> MainSchemaTypes: branch = branch or self.client.default_branch + kind_str = self._get_schema_name(schema=kind) + if refresh: self.cache[branch] = self.fetch(branch=branch) - if branch in self.cache and kind in self.cache[branch]: - return self.cache[branch][kind] + if branch in self.cache and 
kind_str in self.cache[branch]: + return self.cache[branch][kind_str] # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: self.cache[branch] = self.fetch(branch=branch) - if branch in self.cache and kind in self.cache[branch]: - return self.cache[branch][kind] + if branch in self.cache and kind_str in self.cache[branch]: + return self.cache[branch][kind_str] - raise SchemaNotFoundError(identifier=kind) + raise SchemaNotFoundError(identifier=kind_str) def _get_kind_and_attribute_schema( self, kind: Union[str, InfrahubNodeTypes], attribute: str, branch: Optional[str] = None diff --git a/infrahub_sdk/store.py b/infrahub_sdk/store.py index 1384bbd8..2289ce48 100644 --- a/infrahub_sdk/store.py +++ b/infrahub_sdk/store.py @@ -6,9 +6,20 @@ from infrahub_sdk.exceptions import NodeNotFoundError if TYPE_CHECKING: + from infrahub_sdk.client import SchemaType from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync +def get_schema_name(schema: Optional[Union[str, type[SchemaType]]] = None) -> Optional[str]: + if isinstance(schema, str): + return schema + + if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] + return schema.__name__ # type: ignore[union-attr] + + return None + + class NodeStoreBase: """Internal Store for InfrahubNode objects. 
@@ -20,7 +31,7 @@ def __init__(self) -> None: self._store: dict[str, dict] = defaultdict(dict) self._store_by_hfid: dict[str, Any] = defaultdict(dict) - def _set(self, node: Union[InfrahubNode, InfrahubNodeSync], key: Optional[str] = None) -> None: + def _set(self, node: Union[InfrahubNode, InfrahubNodeSync, SchemaType], key: Optional[str] = None) -> None: hfid = node.get_human_friendly_id_as_string(include_kind=True) if not key and not hfid: @@ -33,18 +44,19 @@ def _set(self, node: Union[InfrahubNode, InfrahubNodeSync], key: Optional[str] = if hfid: self._store_by_hfid[hfid] = node - def _get(self, key: str, kind: Optional[str] = None, raise_when_missing: bool = True): # type: ignore[no-untyped-def] - if kind and kind not in self._store and key not in self._store[kind]: # type: ignore[attr-defined] + def _get(self, key: str, kind: Optional[Union[str, type[SchemaType]]] = None, raise_when_missing: bool = True): # type: ignore[no-untyped-def] + kind_name = get_schema_name(schema=kind) + if kind_name and kind_name not in self._store and key not in self._store[kind_name]: # type: ignore[attr-defined] if not raise_when_missing: return None raise NodeNotFoundError( - node_type=kind, + node_type=kind_name, identifier={"key": [key]}, message="Unable to find the node in the Store", ) - if kind and kind in self._store and key in self._store[kind]: # type: ignore[attr-defined] - return self._store[kind][key] # type: ignore[attr-defined] + if kind_name and kind_name in self._store and key in self._store[kind_name]: # type: ignore[attr-defined] + return self._store[kind_name][key] # type: ignore[attr-defined] for _, item in self._store.items(): # type: ignore[attr-defined] if key in item: @@ -73,14 +85,30 @@ def _get_by_hfid(self, key: str, raise_when_missing: bool = True): # type: igno class NodeStore(NodeStoreBase): @overload - def get(self, key: str, kind: Optional[str] = None, raise_when_missing: Literal[True] = True) -> InfrahubNode: ... 
+ def get(self, key: str, kind: type[SchemaType], raise_when_missing: Literal[True] = True) -> SchemaType: ... @overload def get( - self, key: str, kind: Optional[str] = None, raise_when_missing: Literal[False] = False + self, key: str, kind: type[SchemaType], raise_when_missing: Literal[False] = False + ) -> Optional[SchemaType]: ... + + @overload + def get(self, key: str, kind: type[SchemaType], raise_when_missing: bool = ...) -> SchemaType: ... + + @overload + def get( + self, key: str, kind: Optional[str] = ..., raise_when_missing: Literal[False] = False ) -> Optional[InfrahubNode]: ... - def get(self, key: str, kind: Optional[str] = None, raise_when_missing: bool = True) -> Optional[InfrahubNode]: + @overload + def get(self, key: str, kind: Optional[str] = ..., raise_when_missing: Literal[True] = True) -> InfrahubNode: ... + + @overload + def get(self, key: str, kind: Optional[str] = ..., raise_when_missing: bool = ...) -> InfrahubNode: ... + + def get( + self, key: str, kind: Optional[Union[str, type[SchemaType]]] = None, raise_when_missing: bool = True + ) -> Optional[Union[InfrahubNode, SchemaType]]: return self._get(key=key, kind=kind, raise_when_missing=raise_when_missing) @overload @@ -92,7 +120,7 @@ def get_by_hfid(self, key: str, raise_when_missing: Literal[False] = False) -> O def get_by_hfid(self, key: str, raise_when_missing: bool = True) -> Optional[InfrahubNode]: return self._get_by_hfid(key=key, raise_when_missing=raise_when_missing) - def set(self, node: InfrahubNode, key: Optional[str] = None) -> None: + def set(self, node: Any, key: Optional[str] = None) -> None: return self._set(node=node, key=key) diff --git a/infrahub_sdk/transfer/importer/json.py b/infrahub_sdk/transfer/importer/json.py index 20b1ac2a..f6ceeeda 100644 --- a/infrahub_sdk/transfer/importer/json.py +++ b/infrahub_sdk/transfer/importer/json.py @@ -129,6 +129,8 @@ async def update_optional_relationships(self) -> None: # Check if we are in a many-many relationship, ignore 
importing it if it is if relationship_schema.cardinality == "many": + if relationship_schema.peer not in self.schemas_by_kind: + continue for peer_relationship in self.schemas_by_kind[relationship_schema.peer].relationships: if peer_relationship.cardinality == "many" and peer_relationship.peer == node_kind: ignore = True diff --git a/infrahub_sdk/uuidt.py b/infrahub_sdk/uuidt.py index cd3d553d..d0f6418d 100644 --- a/infrahub_sdk/uuidt.py +++ b/infrahub_sdk/uuidt.py @@ -40,7 +40,7 @@ def __init__( timestamp: Optional[int] = None, hostname: Optional[str] = None, random_chars: Optional[str] = None, - ): + ) -> None: self.namespace = namespace or DEFAULT_NAMESPACE self.timestamp = timestamp or time.time_ns() self.hostname = hostname or HOSTNAME diff --git a/pyproject.toml b/pyproject.toml index 7401bc73..2dc2a983 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "infrahub-sdk" -version = "0.13.1" +version = "0.14.0" description = "Python Client to interact with Infrahub" authors = ["OpsMill "] readme = "README.md" diff --git a/tests/integration/test_export_import.py b/tests/integration/test_export_import.py index d6603816..df2247b6 100644 --- a/tests/integration/test_export_import.py +++ b/tests/integration/test_export_import.py @@ -207,13 +207,14 @@ async def test_step01_export_no_schema(self, client: InfrahubClient, temporary_d assert relationships_file.exists() # Verify that only the admin account has been exported - admin_account_node_dump = ujson.loads(nodes_file.read_text()) + with nodes_file.open() as f: + admin_account_node_dump = ujson.loads(f.readline()) assert admin_account_node_dump assert admin_account_node_dump["kind"] == "CoreAccount" assert ujson.loads(admin_account_node_dump["graphql_json"])["name"]["value"] == "admin" relationships_dump = ujson.loads(relationships_file.read_text()) - assert not relationships_dump + assert relationships_dump async def test_step02_import_no_schema(self, client: InfrahubClient, 
temporary_directory: Path): importer = LineDelimitedJSONImporter(client=client, topological_sorter=InfrahubSchemaTopologicalSorter()) @@ -241,13 +242,14 @@ async def test_step03_export_empty_dataset(self, client: InfrahubClient, tempora assert relationships_file.exists() # Verify that only the admin account has been exported - admin_account_node_dump = ujson.loads(nodes_file.read_text()) + with nodes_file.open() as f: + admin_account_node_dump = ujson.loads(f.readline()) assert admin_account_node_dump assert admin_account_node_dump["kind"] == "CoreAccount" assert ujson.loads(admin_account_node_dump["graphql_json"])["name"]["value"] == "admin" relationships_dump = ujson.loads(relationships_file.read_text()) - assert not relationships_dump + assert relationships_dump async def test_step04_import_empty_dataset(self, client: InfrahubClient, temporary_directory: Path, schema): await client.schema.load(schemas=[schema]) @@ -280,10 +282,10 @@ async def test_step05_export_initial_dataset( with nodes_file.open() as reader: while line := reader.readline(): nodes_dump.append(ujson.loads(line)) - assert len(nodes_dump) == len(initial_dataset) + 1 + assert len(nodes_dump) == len(initial_dataset) + 5 # add number to account for default data relationships_dump = ujson.loads(relationships_file.read_text()) - assert not relationships_dump + assert relationships_dump async def test_step06_import_initial_dataset(self, client: InfrahubClient, temporary_directory: Path, schema): await client.schema.load(schemas=[schema]) @@ -366,6 +368,7 @@ def schema_car_base(self) -> Dict[str, Any]: "optional": True, "peer": "TestingPool", "cardinality": "many", + "identifier": "car__pool", }, { "name": "manufacturer", @@ -490,7 +493,7 @@ async def test_step01_export_initial_dataset( with nodes_file.open() as reader: while line := reader.readline(): nodes_dump.append(ujson.loads(line)) - assert len(nodes_dump) == len(initial_dataset) + 1 + assert len(nodes_dump) == len(initial_dataset) + 5 # add 
number to account for default data # Make sure there are as many relationships as there are in the database relationship_count = 0 @@ -498,7 +501,7 @@ async def test_step01_export_initial_dataset( await node.cars.fetch() relationship_count += len(node.cars.peers) relationships_dump = ujson.loads(relationships_file.read_text()) - assert len(relationships_dump) == relationship_count + assert len(relationships_dump) == relationship_count + 1 # add number to account for default data async def test_step02_import_initial_dataset(self, client: InfrahubClient, temporary_directory: Path, schema): await client.schema.load(schemas=[schema]) @@ -517,7 +520,7 @@ async def test_step02_import_initial_dataset(self, client: InfrahubClient, tempo relationship_count += len(node.cars.peers) relationships_file = temporary_directory / "relationships.json" relationships_dump = ujson.loads(relationships_file.read_text()) - assert len(relationships_dump) == relationship_count + assert len(relationships_dump) == relationship_count + 1 # add number to account for default data async def test_step03_import_initial_dataset_with_existing_data( self, client: InfrahubClient, temporary_directory: Path, initial_dataset @@ -536,7 +539,7 @@ async def test_step03_import_initial_dataset_with_existing_data( relationship_count += len(node.cars.peers) relationships_file = temporary_directory / "relationships.json" relationships_dump = ujson.loads(relationships_file.read_text()) - assert len(relationships_dump) == relationship_count + assert len(relationships_dump) == relationship_count + 1 # add number to account for default data # Cleanup for next tests self.reset_export_directory(temporary_directory) diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index 680bc4ae..d39be706 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -59,10 +59,20 @@ def replace_async_return_annotation(): def replace_annotation(annotation: str) -> str: replacements = { + 
"type[SchemaType]": "type[SchemaTypeSync]", + "SchemaType": "SchemaTypeSync", + "CoreNode": "CoreNodeSync", + "Optional[CoreNode]": "Optional[CoreNodeSync]", + "Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]", + "Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]", + "Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]", + "Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]", "InfrahubClient": "InfrahubClientSync", "InfrahubNode": "InfrahubNodeSync", "list[InfrahubNode]": "list[InfrahubNodeSync]", "Optional[InfrahubNode]": "Optional[InfrahubNodeSync]", + "Optional[type[SchemaType]]": "Optional[type[SchemaTypeSync]]", + "Optional[Union[CoreNode, SchemaType]]": "Optional[Union[CoreNodeSync, SchemaTypeSync]]", } return replacements.get(annotation) or annotation @@ -89,10 +99,20 @@ def replace_sync_return_annotation() -> str: def replace_annotation(annotation: str) -> str: replacements = { + "type[SchemaTypeSync]": "type[SchemaType]", + "SchemaTypeSync": "SchemaType", + "CoreNodeSync": "CoreNode", + "Optional[CoreNodeSync]": "Optional[CoreNode]", + "Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]", + "Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]", + "Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]", + "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]", "InfrahubClientSync": "InfrahubClient", "InfrahubNodeSync": "InfrahubNode", "list[InfrahubNodeSync]": "list[InfrahubNode]", "Optional[InfrahubNodeSync]": "Optional[InfrahubNode]", + "Optional[type[SchemaTypeSync]]": "Optional[type[SchemaType]]", + "Optional[Union[CoreNodeSync, SchemaTypeSync]]": "Optional[Union[CoreNode, SchemaType]]", } return replacements.get(annotation) or annotation diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py 
index eb0f0930..b34ec877 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -4,7 +4,7 @@ from pytest_httpx import HTTPXMock from infrahub_sdk import InfrahubClient, InfrahubClientSync -from infrahub_sdk.exceptions import FilterNotFoundError, NodeNotFoundError +from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync async_client_methods = [method for method in dir(InfrahubClient) if not method.startswith("_")] @@ -314,6 +314,20 @@ async def test_method_get_not_found(httpx_mock: HTTPXMock, clients, mock_query_r clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core") +@pytest.mark.parametrize("client_type", client_types) +async def test_method_get_not_found_none( + httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type +): # pylint: disable=unused-argument + if client_type == "standard": + response = await clients.standard.get( + kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False + ) + else: + response = clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False) + + assert response is None + + @pytest.mark.parametrize("client_type", client_types) async def test_method_get_found_many( httpx_mock: HTTPXMock, @@ -329,19 +343,6 @@ async def test_method_get_found_many( clients.sync.get(kind="CoreRepository", id="bfae43e8-5ebb-456c-a946-bf64e930710a") -@pytest.mark.parametrize("client_type", client_types) -async def test_method_get_invalid_filter(httpx_mock: HTTPXMock, clients, mock_schema_query_01, client_type): # pylint: disable=unused-argument - with pytest.raises(FilterNotFoundError) as excinfo: - if client_type == "standard": - await clients.standard.get(kind="CoreRepository", name__name="infrahub-demo-core") - else: - clients.sync.get(kind="CoreRepository", name__name="infrahub-demo-core") - assert isinstance(excinfo.value.message, str) - assert "'name__name' is not a 
valid filter for 'CoreRepository'" in excinfo.value.message - assert "default_branch__value" in excinfo.value.message - assert "default_branch__value" in excinfo.value.filters - - @pytest.mark.parametrize("client_type", client_types) async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_repository_page1_1, client_type): # pylint: disable=unused-argument if client_type == "standard":