From e797e10dad836b768edf9add9d364826bf464f85 Mon Sep 17 00:00:00 2001 From: Phillip Simonds Date: Thu, 17 Oct 2024 14:59:15 -0700 Subject: [PATCH 1/2] Add support for specific timout per request per issue #25 --- changelog/25.added.md | 1 + infrahub_sdk/client.py | 74 ++++++++++++++++++++++--- infrahub_sdk/node.py | 121 +++++++++++++++++++++++++++++------------ infrahub_sdk/schema.py | 22 +++++--- 4 files changed, 167 insertions(+), 51 deletions(-) create mode 100644 changelog/25.added.md diff --git a/changelog/25.added.md b/changelog/25.added.md new file mode 100644 index 00000000..5ea6dad6 --- /dev/null +++ b/changelog/25.added.md @@ -0,0 +1 @@ +Add support for specific timeout per request on InfrahubClient and InfrahubNode function calls. diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 29310ac3..8cb54a33 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -324,6 +324,7 @@ async def get( raise_when_missing: Literal[False], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -341,6 +342,7 @@ async def get( raise_when_missing: Literal[True], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -358,6 +360,7 @@ async def get( raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -375,6 +378,7 @@ async def get( raise_when_missing: Literal[False], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -392,6 +396,7 @@ async def get( raise_when_missing: Literal[True], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -409,6 +414,7 @@ async def get( raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -425,6 +431,7 @@ async def get( raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, id: Optional[str] = None, hfid: Optional[list[str]] = None, include: Optional[list[str]] = None, @@ -458,6 +465,7 @@ async def get( kind=kind, at=at, branch=branch, + timeout=timeout, populate_store=populate_store, include=include, exclude=exclude, @@ -476,7 +484,12 @@ async def get( return results[0] async def _process_nodes_and_relationships( - self, response: dict[str, Any], schema_kind: str, branch: str, prefetch_relationships: bool + self, + response: dict[str, Any], + schema_kind: str, + branch: str, + prefetch_relationships: bool, + timeout: Optional[int] = None, ) -> ProcessRelationsNode: """Processes InfrahubNode and their Relationships from the GraphQL query response. @@ -485,6 +498,7 @@ async def _process_nodes_and_relationships( schema_kind (str): The kind of schema being queried. branch (str): The branch name. prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data. 
+ timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. Returns: ProcessRelationsNodeSync: A TypedDict containing two lists: @@ -496,11 +510,13 @@ async def _process_nodes_and_relationships( related_nodes: list[InfrahubNode] = [] for item in response.get(schema_kind, {}).get("edges", []): - node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item) + node = await InfrahubNode.from_graphql(client=self, branch=branch, data=item, timeout=timeout) nodes.append(node) if prefetch_relationships: - await node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes) + await node._process_relationships( + node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout + ) return ProcessRelationsNode(nodes=nodes, related_nodes=related_nodes) @@ -510,6 +526,7 @@ async def all( kind: type[SchemaType], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -525,6 +542,7 @@ async def all( kind: str, at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -539,6 +557,7 @@ async def all( kind: Union[str, type[SchemaType]], at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, populate_store: bool = False, offset: Optional[int] = None, limit: Optional[int] = None, @@ -554,6 +573,7 @@ async def all( at (Timestamp, optional): Time of the query. Defaults to Now. branch (str, optional): Name of the branch to query from. Defaults to default_branch. populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. offset (int, optional): The offset for pagination. limit (int, optional): The limit for pagination. include (list[str], optional): List of attributes or relationships to include in the query. @@ -568,6 +588,7 @@ async def all( kind=kind, at=at, branch=branch, + timeout=timeout, populate_store=populate_store, offset=offset, limit=limit, @@ -583,6 +604,7 @@ async def filters( kind: type[SchemaType], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -600,6 +622,7 @@ async def filters( kind: str, at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -616,6 +639,7 @@ async def filters( kind: Union[str, type[SchemaType]], at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, populate_store: bool = False, offset: Optional[int] = None, limit: Optional[int] = None, @@ -632,6 +656,7 @@ async def filters( kind (str): kind of the nodes to query at (Timestamp, optional): Time of the query. Defaults to Now. branch (str, optional): Name of the branch to query from. Defaults to default_branch. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. offset (int, optional): The offset for pagination. 
limit (int, optional): The limit for pagination. @@ -679,10 +704,15 @@ async def filters( branch_name=branch, at=at, tracker=f"query-{str(schema.kind).lower()}-page{page_number}", + timeout=timeout, ) process_result: ProcessRelationsNode = await self._process_nodes_and_relationships( - response=response, schema_kind=schema.kind, branch=branch, prefetch_relationships=prefetch_relationships + response=response, + schema_kind=schema.kind, + branch=branch, + prefetch_relationships=prefetch_relationships, + timeout=timeout, ) nodes.extend(process_result["nodes"]) related_nodes.extend(process_result["related_nodes"]) @@ -1509,6 +1539,7 @@ def all( kind: type[SchemaTypeSync], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -1524,6 +1555,7 @@ def all( kind: str, at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -1538,6 +1570,7 @@ def all( kind: Union[str, type[SchemaTypeSync]], at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, populate_store: bool = False, offset: Optional[int] = None, limit: Optional[int] = None, @@ -1552,6 +1585,7 @@ def all( kind (str): kind of the nodes to query at (Timestamp, optional): Time of the query. Defaults to Now. branch (str, optional): Name of the branch to query from. Defaults to default_branch. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. offset (int, optional): The offset for pagination. limit (int, optional): The limit for pagination. @@ -1567,6 +1601,7 @@ def all( kind=kind, at=at, branch=branch, + timeout=timeout, populate_store=populate_store, offset=offset, limit=limit, @@ -1577,7 +1612,12 @@ def all( ) def _process_nodes_and_relationships( - self, response: dict[str, Any], schema_kind: str, branch: str, prefetch_relationships: bool + self, + response: dict[str, Any], + schema_kind: str, + branch: str, + prefetch_relationships: bool, + timeout: Optional[int] = None, ) -> ProcessRelationsNodeSync: """Processes InfrahubNodeSync and their Relationships from the GraphQL query response. @@ -1586,6 +1626,7 @@ def _process_nodes_and_relationships( schema_kind (str): The kind of schema being queried. branch (str): The branch name. prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. 
Returns: ProcessRelationsNodeSync: A TypedDict containing two lists: @@ -1597,11 +1638,11 @@ def _process_nodes_and_relationships( related_nodes: list[InfrahubNodeSync] = [] for item in response.get(schema_kind, {}).get("edges", []): - node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item) + node = InfrahubNodeSync.from_graphql(client=self, branch=branch, data=item, timeout=timeout) nodes.append(node) if prefetch_relationships: - node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes) + node._process_relationships(node_data=item, branch=branch, related_nodes=related_nodes, timeout=timeout) return ProcessRelationsNodeSync(nodes=nodes, related_nodes=related_nodes) @@ -1611,6 +1652,7 @@ def filters( kind: type[SchemaTypeSync], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -1628,6 +1670,7 @@ def filters( kind: str, at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., populate_store: bool = ..., offset: Optional[int] = ..., limit: Optional[int] = ..., @@ -1644,6 +1687,7 @@ def filters( kind: Union[str, type[SchemaTypeSync]], at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, populate_store: bool = False, offset: Optional[int] = None, limit: Optional[int] = None, @@ -1660,6 +1704,7 @@ def filters( kind (str): kind of the nodes to query at (Timestamp, optional): Time of the query. Defaults to Now. branch (str, optional): Name of the branch to query from. Defaults to default_branch. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. offset (int, optional): The offset for pagination. limit (int, optional): The limit for pagination. 
@@ -1706,11 +1751,16 @@ def filters( query=query.render(), branch_name=branch, at=at, + timeout=timeout, tracker=f"query-{str(schema.kind).lower()}-page{page_number}", ) process_result: ProcessRelationsNodeSync = self._process_nodes_and_relationships( - response=response, schema_kind=schema.kind, branch=branch, prefetch_relationships=prefetch_relationships + response=response, + schema_kind=schema.kind, + branch=branch, + prefetch_relationships=prefetch_relationships, + timeout=timeout, ) nodes.extend(process_result["nodes"]) related_nodes.extend(process_result["related_nodes"]) @@ -1739,6 +1789,7 @@ def get( raise_when_missing: Literal[False], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1756,6 +1807,7 @@ def get( raise_when_missing: Literal[True], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1773,6 +1825,7 @@ def get( raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1790,6 +1843,7 @@ def get( raise_when_missing: Literal[False], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1807,6 +1861,7 @@ def get( raise_when_missing: Literal[True], at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1824,6 +1879,7 @@ def get( raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., + timeout: Optional[int] = ..., id: Optional[str] = ..., hfid: Optional[list[str]] = ..., include: Optional[list[str]] = ..., @@ -1840,6 +1896,7 @@ def get( raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, id: Optional[str] = None, hfid: Optional[list[str]] = None, include: Optional[list[str]] = None, @@ -1873,6 +1930,7 @@ def get( kind=kind, at=at, branch=branch, + timeout=timeout, populate_store=populate_store, include=include, exclude=exclude, diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index 04daaf50..18abfd39 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -314,11 +314,13 @@ def __init__( self._client = client super().__init__(branch=branch, schema=schema, data=data, name=name) - async def fetch(self) -> None: + async def fetch(self, timeout: Optional[int] = None) -> None: if not self.id or not self.typename: raise Error("Unable to fetch the peer, id and/or typename are not defined") - self._peer = await self._client.get(kind=self.typename, id=self.id, populate_store=True, branch=self._branch) + self._peer = await self._client.get( + kind=self.typename, id=self.id, populate_store=True, branch=self._branch, timeout=timeout + ) @property def peer(self) -> InfrahubNode: @@ -359,11 +361,13 @@ def __init__( self._client = client super().__init__(branch=branch, schema=schema, data=data, name=name) - def fetch(self) -> None: + def fetch(self, timeout: Optional[int] = None) -> None: if not self.id or not self.typename: raise Error("Unable to fetch the peer, 
id and/or typename are not defined") - self._peer = self._client.get(kind=self.typename, id=self.id, populate_store=True, branch=self._branch) + self._peer = self._client.get( + kind=self.typename, id=self.id, populate_store=True, branch=self._branch, timeout=timeout + ) @property def peer(self) -> InfrahubNodeSync: @@ -1045,13 +1049,18 @@ def __init__( @classmethod async def from_graphql( - cls, client: InfrahubClient, branch: str, data: dict, schema: Optional[MainSchemaTypes] = None + cls, + client: InfrahubClient, + branch: str, + data: dict, + schema: Optional[MainSchemaTypes] = None, + timeout: Optional[int] = None, ) -> Self: if not schema: node_kind = data.get("__typename", None) or data.get("node", {}).get("__typename", None) if not node_kind: raise ValueError("Unable to determine the type of the node, __typename not present in data") - schema = await client.schema.get(kind=node_kind, branch=branch) + schema = await client.schema.get(kind=node_kind, branch=branch, timeout=timeout) return cls(client=client, schema=schema, branch=branch, data=cls._strip_alias(data)) @@ -1104,7 +1113,7 @@ async def artifact_fetch(self, name: str) -> Union[str, dict[str, Any]]: content = await self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined] return content - async def delete(self) -> None: + async def delete(self, timeout: Optional[int] = None) -> None: input_data = {"data": {"id": self.id}} mutation_query = {"ok": None} query = Mutation( @@ -1115,14 +1124,17 @@ async def delete(self) -> None: await self._client.execute_graphql( query=query.render(), branch_name=self._branch, + timeout=timeout, tracker=f"mutation-{str(self._schema.kind).lower()}-delete", ) - async def save(self, allow_upsert: bool = False, update_group_context: Optional[bool] = None) -> None: + async def save( + self, allow_upsert: bool = False, update_group_context: Optional[bool] = None, timeout: Optional[int] = None + ) -> None: if self._existing is False or allow_upsert is True: - await self.create(allow_upsert=allow_upsert) + await self.create(allow_upsert=allow_upsert, timeout=timeout) else: - await self.update() + await self.update(timeout=timeout) if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context = True @@ -1297,7 +1309,9 @@ def _generate_mutation_query(self) -> dict[str, Any]: return query_result - async def _process_mutation_result(self, mutation_name: str, response: dict[str, Any]) -> None: + async def _process_mutation_result( + self, mutation_name: str, response: dict[str, Any], timeout: Optional[int] = None + ) -> None: object_response: dict[str, Any] = response[mutation_name]["object"] self.id = object_response["id"] self._existing = True @@ -1324,10 +1338,10 @@ async def _process_mutation_result(self, mutation_name: str, response: dict[str, related_node = RelatedNode( client=self._client, branch=self._branch, schema=rel.schema, data=allocated_resource ) - await related_node.fetch() + await related_node.fetch(timeout=timeout) setattr(self, rel_name, related_node) - async def create(self, allow_upsert: bool = False) -> None: + async def create(self, allow_upsert: bool = False, timeout: Optional[int] = None) -> None: mutation_query = self._generate_mutation_query() if allow_upsert: @@ -1345,11 +1359,15 @@ async def create(self, allow_upsert: bool = False) -> None: variables=input_data["mutation_variables"], ) response = await self._client.execute_graphql( - query=query.render(), branch_name=self._branch, 
tracker=tracker, variables=input_data["variables"] + query=query.render(), + branch_name=self._branch, + tracker=tracker, + variables=input_data["variables"], + timeout=timeout, ) - await self._process_mutation_result(mutation_name=mutation_name, response=response) + await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) - async def update(self, do_full_update: bool = False) -> None: + async def update(self, do_full_update: bool = False, timeout: Optional[int] = None) -> None: input_data = self._generate_input_data(exclude_unmodified=not do_full_update) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" @@ -1363,13 +1381,14 @@ async def update(self, do_full_update: bool = False) -> None: response = await self._client.execute_graphql( query=query.render(), branch_name=self._branch, + timeout=timeout, tracker=f"mutation-{str(self._schema.kind).lower()}-update", variables=input_data["variables"], ) - await self._process_mutation_result(mutation_name=mutation_name, response=response) + await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) async def _process_relationships( - self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNode] + self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNode], timeout: Optional[int] = None ) -> None: """Processes the Relationships of a InfrahubNode and add Related Nodes to a list. @@ -1377,19 +1396,24 @@ async def _process_relationships( node_data (dict[str, Any]): The item from the GraphQL response corresponding to the node. branch (str): The branch name. related_nodes (list[InfrahubNode]): The list to which related nodes will be appended. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. 
""" for rel_name in self._relationships: rel = getattr(self, rel_name) if rel and isinstance(rel, RelatedNode): relation = node_data["node"].get(rel_name) if relation.get("node", None): - related_node = await InfrahubNode.from_graphql(client=self._client, branch=branch, data=relation) + related_node = await InfrahubNode.from_graphql( + client=self._client, branch=branch, data=relation, timeout=timeout + ) related_nodes.append(related_node) elif rel and isinstance(rel, RelationshipManager): peers = node_data["node"].get(rel_name) if peers: for peer in peers["edges"]: - related_node = await InfrahubNode.from_graphql(client=self._client, branch=branch, data=peer) + related_node = await InfrahubNode.from_graphql( + client=self._client, branch=branch, data=peer, timeout=timeout + ) related_nodes.append(related_node) async def get_pool_allocated_resources(self, resource: InfrahubNode) -> list[InfrahubNode]: @@ -1522,13 +1546,18 @@ def __init__( @classmethod def from_graphql( - cls, client: InfrahubClientSync, branch: str, data: dict, schema: Optional[MainSchemaTypes] = None + cls, + client: InfrahubClientSync, + branch: str, + data: dict, + schema: Optional[MainSchemaTypes] = None, + timeout: Optional[int] = None, ) -> Self: if not schema: node_kind = data.get("__typename", None) or data.get("node", {}).get("__typename", None) if not node_kind: raise ValueError("Unable to determine the type of the node, __typename not present in data") - schema = client.schema.get(kind=node_kind, branch=branch) + schema = client.schema.get(kind=node_kind, branch=branch, timeout=timeout) return cls(client=client, schema=schema, branch=branch, data=cls._strip_alias(data)) @@ -1578,7 +1607,7 @@ def artifact_fetch(self, name: str) -> Union[str, dict[str, Any]]: content = self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined] return content - def delete(self) -> None: + def delete(self, timeout: Optional[int] = None) -> None: input_data = {"data": {"id": self.id}} mutation_query = {"ok": None} query = Mutation( @@ -1590,13 +1619,16 @@ def delete(self) -> None: query=query.render(), branch_name=self._branch, tracker=f"mutation-{str(self._schema.kind).lower()}-delete", + timeout=timeout, ) - def save(self, allow_upsert: bool = False, update_group_context: Optional[bool] = None) -> None: + def save( + self, allow_upsert: bool = False, update_group_context: Optional[bool] = None, timeout: Optional[int] = None + ) -> None: if self._existing is False or allow_upsert is True: - self.create(allow_upsert=allow_upsert) + self.create(allow_upsert=allow_upsert, timeout=timeout) else: - self.update() + self.update(timeout=timeout) if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context = True @@ -1770,7 +1802,9 @@ def _generate_mutation_query(self) -> dict[str, Any]: return query_result - def _process_mutation_result(self, mutation_name: str, response: dict[str, Any]) -> None: + def _process_mutation_result( + self, mutation_name: str, response: dict[str, Any], timeout: Optional[int] = None + ) -> None: object_response: dict[str, Any] = response[mutation_name]["object"] self.id = object_response["id"] self._existing = True @@ -1797,10 +1831,10 @@ def _process_mutation_result(self, mutation_name: str, response: dict[str, Any]) related_node = RelatedNodeSync( client=self._client, branch=self._branch, schema=rel.schema, data=allocated_resource ) - related_node.fetch() + related_node.fetch(timeout=timeout) setattr(self, rel_name, 
related_node) - def create(self, allow_upsert: bool = False) -> None: + def create(self, allow_upsert: bool = False, timeout: Optional[int] = None) -> None: mutation_query = self._generate_mutation_query() if allow_upsert: @@ -1819,11 +1853,15 @@ def create(self, allow_upsert: bool = False) -> None: ) response = self._client.execute_graphql( - query=query.render(), branch_name=self._branch, tracker=tracker, variables=input_data["variables"] + query=query.render(), + branch_name=self._branch, + tracker=tracker, + variables=input_data["variables"], + timeout=timeout, ) - self._process_mutation_result(mutation_name=mutation_name, response=response) + self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) - def update(self, do_full_update: bool = False) -> None: + def update(self, do_full_update: bool = False, timeout: Optional[int] = None) -> None: input_data = self._generate_input_data(exclude_unmodified=not do_full_update) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" @@ -1840,11 +1878,16 @@ def update(self, do_full_update: bool = False) -> None: branch_name=self._branch, tracker=f"mutation-{str(self._schema.kind).lower()}-update", variables=input_data["variables"], + timeout=timeout, ) - self._process_mutation_result(mutation_name=mutation_name, response=response) + self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) def _process_relationships( - self, node_data: dict[str, Any], branch: str, related_nodes: list[InfrahubNodeSync] + self, + node_data: dict[str, Any], + branch: str, + related_nodes: list[InfrahubNodeSync], + timeout: Optional[int] = None, ) -> None: """Processes the Relationships of a InfrahubNodeSync and add Related Nodes to a list. @@ -1852,19 +1895,25 @@ def _process_relationships( node_data (dict[str, Any]): The item from the GraphQL response corresponding to the node. branch (str): The branch name. related_nodes (list[InfrahubNodeSync]): The list to which related nodes will be appended. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. 
+ """ for rel_name in self._relationships: rel = getattr(self, rel_name) if rel and isinstance(rel, RelatedNodeSync): relation = node_data["node"].get(rel_name) if relation.get("node", None): - related_node = InfrahubNodeSync.from_graphql(client=self._client, branch=branch, data=relation) + related_node = InfrahubNodeSync.from_graphql( + client=self._client, branch=branch, data=relation, timeout=timeout + ) related_nodes.append(related_node) elif rel and isinstance(rel, RelationshipManagerSync): peers = node_data["node"].get(rel_name) if peers: for peer in peers["edges"]: - related_node = InfrahubNodeSync.from_graphql(client=self._client, branch=branch, data=peer) + related_node = InfrahubNodeSync.from_graphql( + client=self._client, branch=branch, data=peer, timeout=timeout + ) related_nodes.append(related_node) def get_pool_allocated_resources(self, resource: InfrahubNodeSync) -> list[InfrahubNodeSync]: diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema.py index d7d4fc15..41682ab6 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema.py @@ -532,13 +532,14 @@ async def get( kind: Union[type[Union[SchemaType, SchemaTypeSync]], str], branch: Optional[str] = None, refresh: bool = False, + timeout: Optional[int] = None, ) -> MainSchemaTypes: branch = branch or self.client.default_branch kind_str = self._get_schema_name(schema=kind) if refresh: - self.cache[branch] = await self.fetch(branch=branch) + self.cache[branch] = await self.fetch(branch=branch, timeout=timeout) if branch in self.cache and kind_str in self.cache[branch]: return self.cache[branch][kind_str] @@ -546,7 +547,7 @@ async def get( # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = await self.fetch(branch=branch) + self.cache[branch] = await self.fetch(branch=branch, timeout=timeout) if branch in self.cache and kind_str in self.cache[branch]: return self.cache[branch][kind_str] @@ -715,11 +716,14 @@ async def add_dropdown_option( dropdown_optional_args=dropdown_optional_args, ) - async def fetch(self, branch: str, namespaces: Optional[list[str]] = None) -> MutableMapping[str, MainSchemaTypes]: + async def fetch( + self, branch: str, namespaces: Optional[list[str]] = None, timeout: Optional[int] = None + ) -> MutableMapping[str, MainSchemaTypes]: """Fetch the schema from the server for a given branch. Args: branch (str): Name of the branch to fetch the schema for. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. 
Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind @@ -730,7 +734,7 @@ async def fetch(self, branch: str, namespaces: Optional[list[str]] = None) -> Mu query_params = urlencode(url_parts) url = f"{self.client.address}/api/schema?{query_params}" - response = await self.client._get(url=url) + response = await self.client._get(url=url, timeout=timeout) response.raise_for_status() data: MutableMapping[str, Any] = response.json() @@ -782,6 +786,7 @@ def get( kind: Union[type[Union[SchemaType, SchemaTypeSync]], str], branch: Optional[str] = None, refresh: bool = False, + timeout: Optional[int] = None, ) -> MainSchemaTypes: branch = branch or self.client.default_branch @@ -796,7 +801,7 @@ def get( # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = self.fetch(branch=branch) + self.cache[branch] = self.fetch(branch=branch, timeout=timeout) if branch in self.cache and kind_str in self.cache[branch]: return self.cache[branch][kind_str] @@ -915,11 +920,14 @@ def add_dropdown_option( dropdown_optional_args=dropdown_optional_args, ) - def fetch(self, branch: str, namespaces: Optional[list[str]] = None) -> MutableMapping[str, MainSchemaTypes]: + def fetch( + self, branch: str, namespaces: Optional[list[str]] = None, timeout: Optional[int] = None + ) -> MutableMapping[str, MainSchemaTypes]: """Fetch the schema from the server for a given branch. Args: branch (str): Name of the branch to fetch the schema for. + timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind @@ -930,7 +938,7 @@ def fetch(self, branch: str, namespaces: Optional[list[str]] = None) -> MutableM query_params = urlencode(url_parts) url = f"{self.client.address}/api/schema?{query_params}" - response = self.client._get(url=url) + response = self.client._get(url=url, timeout=timeout) response.raise_for_status() data: MutableMapping[str, Any] = response.json() From ada9ac4b29464bb9be88f7f87c6ea456d1345843 Mon Sep 17 00:00:00 2001 From: Phillip Simonds Date: Thu, 17 Oct 2024 15:10:48 -0700 Subject: [PATCH 2/2] Add timeout param to InfrahubClient create method. --- infrahub_sdk/client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index ccffe1e6..9d7dbeb2 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -299,11 +299,12 @@ async def create( kind: Union[str, type[SchemaType]], data: Optional[dict] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, **kwargs: Any, ) -> Union[InfrahubNode, SchemaType]: branch = branch or self.default_branch - schema = await self.schema.get(kind=kind, branch=branch) + schema = await self.schema.get(kind=kind, branch=branch, timeout=timeout) if not data and not kwargs: raise ValueError("Either data or a list of keywords but be provided") @@ -1430,10 +1431,11 @@ def create( kind: Union[str, type[SchemaTypeSync]], data: Optional[dict] = None, branch: Optional[str] = None, + timeout: Optional[int] = None, **kwargs: Any, ) -> Union[InfrahubNodeSync, SchemaTypeSync]: branch = branch or self.default_branch - schema = self.schema.get(kind=kind, branch=branch) + schema = self.schema.get(kind=kind, branch=branch, timeout=timeout) if not data and not kwargs: raise ValueError("Either data or a list of keywords but be provided")
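
A short usage sketch of the per-request `timeout` keyword these commits introduce. The client construction via `Config(address=...)`, the `InfraDevice` kind, and the `name` attribute are placeholders for illustration and may differ in your environment; only the `timeout=` arguments reflect the API added by this patch.

    import asyncio

    from infrahub_sdk import Config, InfrahubClient


    async def main() -> None:
        # Placeholder address; point this at a reachable Infrahub instance.
        client = InfrahubClient(config=Config(address="http://localhost:8000"))

        # Override the client-wide default timeout (in seconds) for this query only.
        devices = await client.all(kind="InfraDevice", timeout=120)

        if devices:
            # The same keyword is available on get()/filters().
            device = await client.get(kind="InfraDevice", id=devices[0].id, timeout=30)

            # Mutations issued through a node accept the override as well.
            device.name.value = "edge-01"  # hypothetical attribute on the placeholder kind
            await device.save(allow_upsert=True, timeout=60)


    if __name__ == "__main__":
        asyncio.run(main())

The synchronous API mirrors the same keyword, e.g. `InfrahubClientSync(...).all(kind="InfraDevice", timeout=120)` and `InfrahubNodeSync.save(timeout=...)`; when the argument is omitted, the client's configured default timeout continues to apply.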