From 8b8329cd0a54ac6707b30d632e1afc512606f546 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:26:33 +0000 Subject: [PATCH] Jules was unable to complete the task in time. Please review the work done so far and provide feedback for Jules to continue. --- infrahub_sdk/_importer.py | 6 + infrahub_sdk/analyzer.py | 50 +- infrahub_sdk/async_typer.py | 37 ++ infrahub_sdk/batch.py | 77 ++- infrahub_sdk/branch.py | 179 +++++- infrahub_sdk/checks.py | 117 +++- infrahub_sdk/client.py | 1028 +++++++++++++++++++++++++----- infrahub_sdk/config.py | 43 ++ infrahub_sdk/constants.py | 10 +- infrahub_sdk/context.py | 1 + infrahub_sdk/ctl/check.py | 77 +++ infrahub_sdk/ctl/cli_commands.py | 25 +- infrahub_sdk/ctl/client.py | 45 ++ infrahub_sdk/ctl/config.py | 34 +- infrahub_sdk/ctl/exceptions.py | 8 + infrahub_sdk/ctl/exporter.py | 7 + infrahub_sdk/ctl/generator.py | 28 + infrahub_sdk/ctl/importer.py | 10 + infrahub_sdk/ctl/render.py | 16 + infrahub_sdk/ctl/repository.py | 25 + infrahub_sdk/ctl/schema.py | 64 +- infrahub_sdk/ctl/transform.py | 6 + infrahub_sdk/ctl/utils.py | 98 +++ 23 files changed, 1790 insertions(+), 201 deletions(-) diff --git a/infrahub_sdk/_importer.py b/infrahub_sdk/_importer.py index 27af0970..1e23301f 100644 --- a/infrahub_sdk/_importer.py +++ b/infrahub_sdk/_importer.py @@ -20,6 +20,12 @@ def import_module(module_path: Path, import_root: str | None = None, relative_pa module_path (Path): Absolute path of the module to import. import_root (Optional[str]): Absolute string path to the current repository. relative_path (Optional[str]): Relative string path between module_path and import_root. + + Returns: + ModuleType: The imported module. + + Raises: + ModuleImportError: If the module cannot be imported due to ModuleNotFoundError or SyntaxError. """ import_root = import_root or str(module_path.parent) diff --git a/infrahub_sdk/analyzer.py b/infrahub_sdk/analyzer.py index 5bca02e1..2030111a 100644 --- a/infrahub_sdk/analyzer.py +++ b/infrahub_sdk/analyzer.py @@ -18,6 +18,7 @@ class GraphQLQueryVariable(BaseModel): + """Represents a variable in a GraphQL query.""" name: str type: str required: bool = False @@ -25,12 +26,20 @@ class GraphQLQueryVariable(BaseModel): class GraphQLOperation(BaseModel): + """Represents a single operation within a GraphQL query.""" name: str | None = None operation_type: OperationType class GraphQLQueryAnalyzer: + """Analyzes GraphQL queries to extract information about operations, variables, and structure.""" def __init__(self, query: str, schema: GraphQLSchema | None = None): + """Initializes the GraphQLQueryAnalyzer. + + Args: + query: The GraphQL query string. + schema: The GraphQL schema. + """ self.query: str = query self.schema: GraphQLSchema | None = schema self.document: DocumentNode = parse(self.query) @@ -38,6 +47,11 @@ def __init__(self, query: str, schema: GraphQLSchema | None = None): @property def is_valid(self) -> tuple[bool, list[GraphQLError] | None]: + """Validates the query against the schema if provided. + + Returns: + A tuple containing a boolean indicating validity and a list of errors if any. 
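+
+        Example:
+            Illustrative sketch (the query string and schema object are hypothetical)::
+
+                analyzer = GraphQLQueryAnalyzer(query="query { device { name }}", schema=schema)
+                valid, errors = analyzer.is_valid
+                if not valid:
+                    print(errors)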
+ """ if self.schema is None: return False, [GraphQLError("Schema is not provided")] @@ -49,10 +63,16 @@ def is_valid(self) -> tuple[bool, list[GraphQLError] | None]: @property def nbr_queries(self) -> int: + """Returns the number of definitions in the query document.""" return len(self.document.definitions) @property def operations(self) -> list[GraphQLOperation]: + """Extracts all operations (queries, mutations, subscriptions) from the query. + + Returns: + A list of GraphQLOperation objects. + """ operations = [] for definition in self.document.definitions: if not isinstance(definition, OperationDefinitionNode): @@ -66,10 +86,20 @@ def operations(self) -> list[GraphQLOperation]: @property def contains_mutation(self) -> bool: + """Checks if the query contains any mutation operations. + + Returns: + True if a mutation is present, False otherwise. + """ return any(op.operation_type == OperationType.MUTATION for op in self.operations) @property def variables(self) -> list[GraphQLQueryVariable]: + """Extracts all variables defined in the query. + + Returns: + A list of GraphQLQueryVariable objects. + """ response = [] for definition in self.document.definitions: variable_definitions = getattr(definition, "variable_definitions", None) @@ -99,16 +129,32 @@ def variables(self) -> list[GraphQLQueryVariable]: return response async def calculate_depth(self) -> int: - """Number of nested levels in the query""" + """Calculates the maximum depth of nesting in the query's selection sets. + + Returns: + The maximum depth of the query. + """ fields = await self.get_fields() return calculate_dict_depth(data=fields) async def calculate_height(self) -> int: - """Total number of fields requested in the query""" + """Calculates the total number of fields requested across all operations in the query. + + Returns: + The total height (number of fields) of the query. + """ fields = await self.get_fields() return calculate_dict_height(data=fields) async def get_fields(self) -> dict[str, Any]: + """Extracts all fields requested in the query. + + This method parses the document definitions and extracts fields from + OperationDefinitionNode instances. + + Returns: + A dictionary representing the fields structure. + """ if not self._fields: fields = {} for definition in self.document.definitions: diff --git a/infrahub_sdk/async_typer.py b/infrahub_sdk/async_typer.py index 39014a32..8854f36a 100644 --- a/infrahub_sdk/async_typer.py +++ b/infrahub_sdk/async_typer.py @@ -9,8 +9,25 @@ class AsyncTyper(Typer): + """ + A Typer subclass that allows to run async functions. + + It overrides the `callback` and `command` decorators to wrap async functions + in `asyncio.run`. + """ + @staticmethod def maybe_run_async(decorator: Callable, func: Callable) -> Any: + """ + Wraps an async function in `asyncio.run` if it's a coroutine function. + + Args: + decorator: The decorator to apply (e.g., from `super().command`). + func: The function to potentially wrap. + + Returns: + The decorated function, possibly wrapped to run asyncio. + """ if inspect.iscoroutinefunction(func): @wraps(func) @@ -23,9 +40,29 @@ def runner(*args: Any, **kwargs: Any) -> Any: return func def callback(self, *args: Any, **kwargs: Any) -> Any: + """ + Overrides the Typer.callback decorator to support async functions. + + Args: + *args: Positional arguments for Typer.callback. + **kwargs: Keyword arguments for Typer.callback. + + Returns: + A decorator that can handle both sync and async callback functions. 
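+
+        Example:
+            Illustrative sketch of an async callback on a hypothetical app::
+
+                app = AsyncTyper()
+
+                @app.callback()
+                async def main() -> None:
+                    ...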
+ """ decorator = super().callback(*args, **kwargs) return partial(self.maybe_run_async, decorator) def command(self, *args: Any, **kwargs: Any) -> Any: + """ + Overrides the Typer.command decorator to support async functions. + + Args: + *args: Positional arguments for Typer.command. + **kwargs: Keyword arguments for Typer.command. + + Returns: + A decorator that can handle both sync and async command functions. + """ decorator = super().command(*args, **kwargs) return partial(self.maybe_run_async, decorator) diff --git a/infrahub_sdk/batch.py b/infrahub_sdk/batch.py index 2f1db7f4..40010d2e 100644 --- a/infrahub_sdk/batch.py +++ b/infrahub_sdk/batch.py @@ -12,6 +12,7 @@ @dataclass class BatchTask: + """Represents a single asynchronous task in a batch.""" task: Callable[[Any], Awaitable[Any]] args: tuple[Any, ...] kwargs: dict[str, Any] @@ -20,13 +21,24 @@ class BatchTask: @dataclass class BatchTaskSync: + """Represents a single synchronous task in a batch.""" task: Callable[..., Any] args: tuple[Any, ...] kwargs: dict[str, Any] node: InfrahubNodeSync | None = None def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | None, Any]: - """Executes the stored task.""" + """Executes the stored synchronous task. + + Args: + return_exceptions: If True, exceptions are returned instead of raised. + + Returns: + A tuple containing the task's node (if any) and the result or exception. + + Raises: + Exception: If `return_exceptions` is False and the task raises an exception. + """ result = None try: result = self.task(*self.args, **self.kwargs) @@ -41,6 +53,16 @@ def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | N async def execute_batch_task_in_pool( task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False ) -> tuple[InfrahubNode | None, Any]: + """Executes a BatchTask within a semaphore-controlled pool. + + Args: + task: The BatchTask to execute. + semaphore: An asyncio.Semaphore to limit concurrent executions. + return_exceptions: If True, exceptions are returned instead of raised. + + Returns: + A tuple containing the task's node (if any) and the result or exception. + """ async with semaphore: try: result = await task.task(*task.args, **task.kwargs) @@ -53,24 +75,51 @@ async def execute_batch_task_in_pool( class InfrahubBatch: + """Manages and executes a batch of asynchronous tasks concurrently.""" def __init__( self, semaphore: asyncio.Semaphore | None = None, max_concurrent_execution: int = 5, return_exceptions: bool = False, ): + """Initializes the InfrahubBatch. + + Args: + semaphore: An asyncio.Semaphore to limit concurrent executions. + If None, a new one is created with `max_concurrent_execution`. + max_concurrent_execution: The maximum number of tasks to run concurrently. + Only used if `semaphore` is None. + return_exceptions: If True, exceptions from tasks are returned instead of raised. + """ self._tasks: list[BatchTask] = [] self.semaphore = semaphore or asyncio.Semaphore(value=max_concurrent_execution) self.return_exceptions = return_exceptions @property def num_tasks(self) -> int: + """Returns the number of tasks currently in the batch.""" return len(self._tasks) def add(self, *args: Any, task: Callable, node: Any | None = None, **kwargs: Any) -> None: + """Adds a new task to the batch. + + Args: + task: The callable to be executed. + node: An optional node associated with this task. + *args: Positional arguments to pass to the task. + **kwargs: Keyword arguments to pass to the task. 
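+
+        Example:
+            Illustrative sketch (``client`` and the node ids are hypothetical)::
+
+                batch = await client.create_batch()
+                for node_id in node_ids:
+                    batch.add(task=client.get, kind="CoreSite", id=node_id)
+                async for node, result in batch.execute():
+                    ...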
+        """
         self._tasks.append(BatchTask(task=task, node=node, args=args, kwargs=kwargs))

-    async def execute(self) -> AsyncGenerator:
+    async def execute(self) -> AsyncGenerator[tuple[InfrahubNode | None, Any], None]:
+        """Executes all tasks in the batch concurrently.
+
+        Yields:
+            A tuple containing the task's node (if any) and the result or exception.
+
+        Raises:
+            Exception: If `return_exceptions` is False and a task raises an exception.
+        """
         tasks = []

         for batch_task in self._tasks:
@@ -90,19 +139,43 @@


 class InfrahubBatchSync:
+    """Manages and executes a batch of synchronous tasks concurrently using a thread pool."""
     def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
+        """Initializes the InfrahubBatchSync.
+
+        Args:
+            max_concurrent_execution: The maximum number of tasks to run concurrently in the thread pool.
+            return_exceptions: If True, exceptions from tasks are returned instead of raised.
+        """
         self._tasks: list[BatchTaskSync] = []
         self.max_concurrent_execution = max_concurrent_execution
         self.return_exceptions = return_exceptions

     @property
     def num_tasks(self) -> int:
+        """Returns the number of tasks currently in the batch."""
         return len(self._tasks)

     def add(self, *args: Any, task: Callable[..., Any], node: Any | None = None, **kwargs: Any) -> None:
+        """Adds a new synchronous task to the batch.
+
+        Args:
+            task: The callable to be executed.
+            node: An optional node associated with this task.
+            *args: Positional arguments to pass to the task.
+            **kwargs: Keyword arguments to pass to the task.
+        """
         self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))

     def execute(self) -> Generator[tuple[InfrahubNodeSync | None, Any], None, None]:
+        """Executes all tasks in the batch concurrently using a ThreadPoolExecutor.
+
+        Yields:
+            A tuple containing the task's node (if any) and the result or exception.
+
+        Raises:
+            Exception: If `return_exceptions` is False and a task raises an exception.
+        """
         with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
             futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
             for future in futures:
diff --git a/infrahub_sdk/branch.py b/infrahub_sdk/branch.py
index 9d7de1fb..f30b7f6d 100644
--- a/infrahub_sdk/branch.py
+++ b/infrahub_sdk/branch.py
@@ -15,6 +15,7 @@


 class BranchData(BaseModel):
+    """Represents data associated with a branch."""
     id: str
     name: str
     description: str | None = None
@@ -48,6 +49,7 @@


 class InfraHubBranchManagerBase:
+    """Base class for branch management operations."""
     @classmethod
     def generate_diff_data_url(
         cls,
@@ -57,7 +59,18 @@
         time_from: str | None = None,
         time_to: str | None = None,
     ) -> str:
-        """Generate the URL for the diff_data function."""
+        """Generates the URL for the diff_data function.
+
+        Args:
+            client: The Infrahub client (either sync or async).
+            branch_name: The name of the branch.
+            branch_only: Whether to include only branch data in the diff. Defaults to True.
+            time_from: The start time for the diff (ISO 8601 format).
+            time_to: The end time for the diff (ISO 8601 format).
+
+        Returns:
+            The generated URL string.
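+
+        Example:
+            For ``branch_name="cr1234"`` with default arguments, the generated
+            URL has a form along the lines of (address hypothetical)::
+
+                http://localhost:8000/api/diff/data?branch=cr1234&branch_only=true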
+        """
         url = f"{client.address}/api/diff/data"
         url_params = {}
         url_params["branch"] = branch_name
@@ -71,7 +84,13 @@


 class InfrahubBranchManager(InfraHubBranchManagerBase):
+    """Manages branches in Infrahub (asynchronous operations)."""
     def __init__(self, client: InfrahubClient):
+        """Initializes the asynchronous branch manager.
+
+        Args:
+            client: An instance of InfrahubClient.
+        """
         self.client = client

     @overload
@@ -102,6 +121,19 @@
         wait_until_completion: bool = True,
         background_execution: bool | None = False,
     ) -> BranchData | str:
+        """Creates a new branch.
+
+        Args:
+            branch_name: The name for the new branch.
+            sync_with_git: Whether to synchronize the branch with Git. Defaults to True.
+            description: An optional description for the branch.
+            wait_until_completion: If True (default), waits for the branch creation to complete
+                and returns BranchData. If False, returns a task ID string.
+            background_execution: Deprecated. Use `wait_until_completion=False` instead.
+
+        Returns:
+            BranchData if `wait_until_completion` is True, otherwise a task ID string.
+        """
         if background_execution is not None:
             warnings.warn(
                 "`background_execution` is deprecated, please use `wait_until_completion` instead.",
@@ -131,6 +163,14 @@
         return BranchData(**response["BranchCreate"]["object"])

     async def delete(self, branch_name: str) -> bool:
+        """Deletes a branch.
+
+        Args:
+            branch_name: The name of the branch to delete.
+
+        Returns:
+            True if the deletion was successful, False otherwise.
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -141,6 +181,14 @@
         return response["BranchDelete"]["ok"]

     async def rebase(self, branch_name: str) -> BranchData:
+        """Rebases a branch onto its origin branch.
+
+        Args:
+            branch_name: The name of the branch to rebase.
+
+        Returns:
+            True if the rebase was successful (the mutation's `ok` field).
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -151,6 +199,16 @@
         return response["BranchRebase"]["ok"]

     async def validate(self, branch_name: str) -> BranchData:
+        """Validates a branch.
+
+        Args:
+            branch_name: The name of the branch to validate.
+
+        Returns:
+            True if the branch validation was successful, False otherwise.
+            Note that the mutation's `ok` boolean is returned even though the
+            declared return type is `BranchData`.
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -172,6 +230,14 @@
         return response["BranchValidate"]["ok"]

     async def merge(self, branch_name: str) -> bool:
+        """Merges a branch into its origin branch.
+
+        Args:
+            branch_name: The name of the branch to merge.
+
+        Returns:
+            True if the merge was successful, False otherwise.
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -185,6 +251,11 @@
         return response["BranchMerge"]["ok"]

     async def all(self) -> dict[str, BranchData]:
+        """Retrieves all branches.
+
+        Returns:
+            A dictionary mapping branch names to BranchData objects.
+        """
         query = Query(name="GetAllBranch", query=QUERY_ALL_BRANCHES_DATA)
         data = await self.client.execute_graphql(query=query.render(), tracker="query-branch-all")

@@ -193,6 +264,17 @@
         return branches

     async def get(self, branch_name: str) -> BranchData:
+        """Retrieves a specific branch by name.
+ + Args: + branch_name: The name of the branch to retrieve. + + Returns: + BranchData for the specified branch. + + Raises: + BranchNotFoundError: If the branch with the given name is not found. + """ query = Query(name="GetBranch", query=QUERY_ONE_BRANCH_DATA, variables={"branch_name": str}) data = await self.client.execute_graphql( query=query.render(), @@ -211,6 +293,19 @@ async def diff_data( time_from: str | None = None, time_to: str | None = None, ) -> dict[Any, Any]: + """Retrieves the data differences for a branch. + + This typically involves changes made on the branch compared to its origin. + + Args: + branch_name: The name of the branch. + branch_only: Whether to include only branch data in the diff. Defaults to True. + time_from: The start time for the diff (ISO 8601 format). + time_to: The end time for the diff (ISO 8601 format). + + Returns: + A dictionary representing the diff data. + """ url = self.generate_diff_data_url( client=self.client, branch_name=branch_name, @@ -223,10 +318,21 @@ async def diff_data( class InfrahubBranchManagerSync(InfraHubBranchManagerBase): + """Manages branches in Infrahub (synchronous operations).""" def __init__(self, client: InfrahubClientSync): + """Initializes the synchronous branch manager. + + Args: + client: An instance of InfrahubClientSync. + """ self.client = client def all(self) -> dict[str, BranchData]: + """Retrieves all branches. + + Returns: + A dictionary mapping branch names to BranchData objects. + """ query = Query(name="GetAllBranch", query=QUERY_ALL_BRANCHES_DATA) data = self.client.execute_graphql(query=query.render(), tracker="query-branch-all") @@ -235,6 +341,17 @@ def all(self) -> dict[str, BranchData]: return branches def get(self, branch_name: str) -> BranchData: + """Retrieves a specific branch by name. + + Args: + branch_name: The name of the branch to retrieve. + + Returns: + BranchData for the specified branch. + + Raises: + BranchNotFoundError: If the branch with the given name is not found. + """ query = Query(name="GetBranch", query=QUERY_ONE_BRANCH_DATA, variables={"branch_name": str}) data = self.client.execute_graphql( query=query.render(), @@ -274,6 +391,19 @@ def create( wait_until_completion: bool = True, background_execution: bool | None = False, ) -> BranchData | str: + """Creates a new branch. + + Args: + branch_name: The name for the new branch. + sync_with_git: Whether to synchronize the branch with Git. Defaults to True. + description: An optional description for the branch. + wait_until_completion: If True (default), waits for the branch creation to complete + and returns BranchData. If False, returns a task ID string. + background_execution: Deprecated. Use `wait_until_completion=False` instead. + + Returns: + BranchData if `wait_until_completion` is True, otherwise a task ID string. + """ if background_execution is not None: warnings.warn( "`background_execution` is deprecated, please use `wait_until_completion` instead.", @@ -302,6 +432,14 @@ def create( return BranchData(**response["BranchCreate"]["object"]) def delete(self, branch_name: str) -> bool: + """Deletes a branch. + + Args: + branch_name: The name of the branch to delete. + + Returns: + True if the deletion was successful, False otherwise. + """ input_data = { "data": { "name": branch_name, @@ -318,6 +456,19 @@ def diff_data( time_from: str | None = None, time_to: str | None = None, ) -> dict[Any, Any]: + """Retrieves the data differences for a branch. + + This typically involves changes made on the branch compared to its origin. 
+
+        Args:
+            branch_name: The name of the branch.
+            branch_only: Whether to include only branch data in the diff. Defaults to True.
+            time_from: The start time for the diff (ISO 8601 format).
+            time_to: The end time for the diff (ISO 8601 format).
+
+        Returns:
+            A dictionary representing the diff data.
+        """
         url = self.generate_diff_data_url(
             client=self.client,
             branch_name=branch_name,
@@ -329,6 +480,14 @@
         return decode_json(response=response)

     def merge(self, branch_name: str) -> bool:
+        """Merges a branch into its origin branch.
+
+        Args:
+            branch_name: The name of the branch to merge.
+
+        Returns:
+            True if the merge was successful, False otherwise.
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -340,6 +499,14 @@
         return response["BranchMerge"]["ok"]

     def rebase(self, branch_name: str) -> BranchData:
+        """Rebases a branch onto its origin branch.
+
+        Args:
+            branch_name: The name of the branch to rebase.
+
+        Returns:
+            True if the rebase was successful (the mutation's `ok` field).
+        """
         input_data = {
             "data": {
                 "name": branch_name,
@@ -350,6 +517,16 @@
         return response["BranchRebase"]["ok"]

     def validate(self, branch_name: str) -> BranchData:
+        """Validates a branch.
+
+        Args:
+            branch_name: The name of the branch to validate.
+
+        Returns:
+            True if the branch validation was successful, False otherwise.
+            Note that the mutation's `ok` boolean is returned even though the
+            declared return type is `BranchData`.
+        """
         input_data = {
             "data": {
                 "name": branch_name,
diff --git a/infrahub_sdk/checks.py b/infrahub_sdk/checks.py
index 79511ccb..4a30e18c 100644
--- a/infrahub_sdk/checks.py
+++ b/infrahub_sdk/checks.py
@@ -23,7 +23,10 @@


 class InfrahubCheckInitializer(BaseModel):
-    """Information about the originator of the check."""
+    """Information about the originator of a check run.
+
+    This data is typically provided by the system initiating the check.
+    """

     proposed_change_id: str = Field(
         default="", description="If available the ID of the proposed change that requested the check"
@@ -31,6 +34,14 @@


 class InfrahubCheck:
+    """
+    Base class for defining custom checks to be executed against Infrahub data.
+
+    Attributes:
+        name: The name of the check. Defaults to the class name.
+        query: The GraphQL query string used to fetch data for the check.
+        timeout: Timeout in seconds for the check execution.
+    """
     name: str | None = None
     query: str = ""
     timeout: int = 10
@@ -44,6 +55,19 @@
         params: dict | None = None,
         client: InfrahubClient | None = None,
     ):
+        """
+        Initializes an InfrahubCheck instance.
+
+        Args:
+            branch: The name of the branch to run the check against.
+                If None, it will try to determine the active git branch.
+            root_directory: The root directory of the repository. Defaults to the current working directory.
+            output: If "stdout", logs will be printed to standard output.
+            initializer: Information about the check's originator.
+            params: Parameters to be passed as variables to the GraphQL query.
+            client: An InfrahubClient instance. If not provided, one might be
+                created later or an UninitializedError will be raised when accessed.
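+
+        Example:
+            Illustrative sketch of a check subclass (the query name and data
+            shape are hypothetical)::
+
+                class CheckDeviceName(InfrahubCheck):
+                    query = "device_names"
+
+                    def validate(self, data: dict) -> None:
+                        for edge in data["Device"]["edges"]:
+                            if not edge["node"]["name"]["value"]:
+                                self.log_error("Device without a name")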
+ """ self.git: GitRepoManager | None = None self.initializer = initializer or InfrahubCheckInitializer() @@ -70,6 +94,12 @@ def __str__(self) -> str: @property def client(self) -> InfrahubClient: + """ + The InfrahubClient instance for interacting with the Infrahub API. + + Raises: + UninitializedError: If the client has not been set. + """ if self._client: return self._client @@ -77,11 +107,33 @@ def client(self) -> InfrahubClient: @client.setter def client(self, value: InfrahubClient) -> None: + """ + Sets the InfrahubClient instance. + + Args: + value: The InfrahubClient instance. + """ self._client = value @classmethod async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: Any) -> InfrahubCheck: - """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" + """ + Asynchronously initializes an instance of the check. + + If an existing InfrahubClient client hasn't been provided, one will be created automatically. + + Args: + client: An optional InfrahubClient instance. + *args: Additional arguments to pass to the check's constructor. + **kwargs: Additional keyword arguments to pass to the check's constructor. + + Returns: + An initialized instance of the InfrahubCheck subclass. + + Deprecated: + This method is deprecated and will be removed in version 2.0.0. + Instantiate the class directly and manage the client lifecycle separately. + """ warnings.warn( "InfrahubCheck.init has been deprecated and will be removed in version 2.0.0 of the Infrahub Python SDK", DeprecationWarning, @@ -96,11 +148,21 @@ async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: @property def errors(self) -> list[dict[str, Any]]: + """A list of all error log entries recorded by the check.""" return [log for log in self.logs if log["level"] == "ERROR"] def _write_log_entry( self, message: str, level: str, object_id: str | None = None, object_type: str | None = None ) -> None: + """ + Writes a structured log entry. + + Args: + message: The log message. + level: The log level (e.g., "INFO", "ERROR"). + object_id: Optional ID of the object related to the log entry. + object_type: Optional type of the object related to the log entry. + """ log_message = {"level": level, "message": message, "branch": self.branch_name} if object_id: log_message["object_id"] = object_id @@ -112,13 +174,30 @@ def _write_log_entry( print(ujson.dumps(log_message)) def log_error(self, message: str, object_id: str | None = None, object_type: str | None = None) -> None: + """ + Logs an error message. + + Args: + message: The error message. + object_id: Optional ID of the object related to the error. + object_type: Optional type of the object related to the error. + """ self._write_log_entry(message=message, level="ERROR", object_id=object_id, object_type=object_type) def log_info(self, message: str, object_id: str | None = None, object_type: str | None = None) -> None: + """ + Logs an informational message. + + Args: + message: The informational message. + object_id: Optional ID of the object related to the message. + object_type: Optional type of the object related to the message. 
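+
+        Example:
+            Illustrative call (identifiers are hypothetical)::
+
+                self.log_info("Interface is compliant", object_id=intf_id, object_type="InfraInterface")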
+ """ self._write_log_entry(message=message, level="INFO", object_id=object_id, object_type=object_type) @property def log_entries(self) -> str: + """A formatted string containing all log entries.""" output = "" for log in self.logs: output += "-----------------------\n" @@ -145,16 +224,42 @@ def branch_name(self) -> str: @abstractmethod def validate(self, data: dict) -> None: - """Code to validate the status of this check.""" + """ + Abstract method to be implemented by subclasses to perform the actual validation logic. + + This method should use `log_error` to record any validation failures. + The overall check status (passed/failed) is determined by the presence of error logs. + + Args: + data: The data fetched by the GraphQL query, to be validated. + """ async def collect_data(self) -> dict: - """Query the result of the GraphQL Query defined in self.query and return the result""" + """ + Queries the Infrahub API using the GraphQL query defined in `self.query`. + + Returns: + The data returned by the GraphQL query. + """ return await self.client.query_gql_query(name=self.query, branch_name=self.branch_name, variables=self.params) async def run(self, data: dict | None = None) -> bool: - """Execute the check after collecting the data from the GraphQL query. - The result of the check is determined based on the presence or not of ERROR log messages.""" + """ + Executes the check. + + This involves: + 1. Collecting data using `collect_data()` if not provided. + 2. Running the `validate()` method with the collected data. + 3. Determining the check's success based on whether any errors were logged. + + Args: + data: Optional pre-fetched data to use for validation. If None, + `collect_data()` will be called. + + Returns: + True if the check passed (no errors logged), False otherwise. + """ if not data: data = await self.collect_data() diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index bfea914c..559cfd70 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -73,11 +73,24 @@ class ProcessRelationsNode(TypedDict): class ProcessRelationsNodeSync(TypedDict): + """A dictionary type for results of processing nodes and their relationships (sync version).""" nodes: list[InfrahubNodeSync] related_nodes: list[InfrahubNodeSync] def handle_relogin(func: Callable[..., Coroutine[Any, Any, httpx.Response]]): # type: ignore[no-untyped-def] + """ + Decorator for InfrahubClient methods to handle automatic re-login on expired signature errors. + + If a 401 error with "Expired Signature" message is received, it attempts to + re-login using `client.login(refresh=True)` and then retries the original call. + + Args: + func: The asynchronous client method to wrap. + + Returns: + The wrapped function. + """ @wraps(func) async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Response: response = await func(client, *args, **kwargs) @@ -92,6 +105,18 @@ async def wrapper(client: InfrahubClient, *args: Any, **kwargs: Any) -> httpx.Re def handle_relogin_sync(func: Callable[..., httpx.Response]): # type: ignore[no-untyped-def] + """ + Decorator for InfrahubClientSync methods to handle automatic re-login on expired signature errors. + + If a 401 error with "Expired Signature" message is received, it attempts to + re-login using `client.login(refresh=True)` and then retries the original call. + + Args: + func: The synchronous client method to wrap. + + Returns: + The wrapped function. 
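+
+    Example:
+        Illustrative sketch of how the decorator is applied (the wrapped
+        method and its signature are hypothetical)::
+
+            @handle_relogin_sync
+            def _post(client: InfrahubClientSync, url: str) -> httpx.Response:
+                ...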
+ """ @wraps(func) def wrapper(client: InfrahubClientSync, *args: Any, **kwargs: Any) -> httpx.Response: response = func(client, *args, **kwargs) @@ -113,6 +138,14 @@ def __init__( address: str = "", config: Config | dict[str, Any] | None = None, ): + """ + Initializes the BaseClient. + + Args: + address: The Infrahub server address. Overrides address in config if provided. + config: A Config object or a dictionary to initialize the client's configuration. + If None, a default Config object will be created. + """ self.client = None self.headers = {"content-type": "application/json"} self.access_token: str = "" @@ -146,12 +179,29 @@ def __init__( self._request_context: RequestContext | None = None def _initialize(self) -> None: - """Sets the properties for each version of the client""" + """ + Sets the version-specific properties for the client (async or sync). + To be implemented by subclasses. + """ def _record(self, response: httpx.Response) -> None: + """ + Records the HTTP response using the custom recorder if configured. + + Args: + response: The httpx.Response object to record. + """ self.config.custom_recorder.record(response) def _echo(self, url: str, query: str, variables: dict | None = None) -> None: + """ + Prints the GraphQL query details to stdout if echo_graphql_queries is enabled in config. + + Args: + url: The GraphQL endpoint URL. + query: The GraphQL query string. + variables: Optional dictionary of variables for the query. + """ if self.config.echo_graphql_queries: print(f"URL: {url}") print(f"QUERY:\n{query}") @@ -160,10 +210,17 @@ def _echo(self, url: str, query: str, variables: dict | None = None) -> None: @property def request_context(self) -> RequestContext | None: + """The current request context, if any.""" return self._request_context @request_context.setter def request_context(self, request_context: RequestContext) -> None: + """ + Sets the request context for the client. + + Args: + request_context: The RequestContext object. + """ self._request_context = request_context def start_tracking( @@ -173,6 +230,22 @@ def start_tracking( delete_unused_nodes: bool = False, group_type: str | None = None, ) -> Self: + """ + Switches the client to TRACKING mode and configures the group context. + + In TRACKING mode, changes made via the client can be associated with a group, + allowing for features like automatic cleanup of unused nodes. + + Args: + identifier: A unique identifier for the tracking group. Defaults to `self.identifier` or "python-sdk". + params: Optional parameters to associate with the tracking group. + delete_unused_nodes: If True, nodes associated with this group that are no longer + referenced might be deleted when the context ends. + group_type: An optional type for the group. + + Returns: + The client instance (self). + """ self.mode = InfrahubClientMode.TRACKING identifier = identifier or self.identifier or "python-sdk" self.set_context_properties( @@ -188,6 +261,17 @@ def set_context_properties( reset: bool = True, group_type: str | None = None, ) -> None: + """ + Sets the properties for the group context used in TRACKING mode. + + Args: + identifier: A unique identifier for the tracking group. + params: Optional parameters to associate with the tracking group. + delete_unused_nodes: If True, nodes associated with this group that are no longer + referenced might be deleted when the context ends. + reset: If True (default), initializes a new group context. + group_type: An optional type for the group. 
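+
+        Example:
+            Illustrative call (identifier and params are hypothetical)::
+
+                client.set_context_properties(identifier="my-generator", params={"site": "atl1"})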
+ """ if reset: if isinstance(self, InfrahubClient): self.group_context = InfrahubGroupContext(self) @@ -202,6 +286,16 @@ def _graphql_url( branch_name: str | None = None, at: str | Timestamp | None = None, ) -> str: + """ + Constructs the GraphQL API URL for a given branch and optional timestamp. + + Args: + branch_name: The name of the branch. If None, the base GraphQL URL is returned. + at: An optional timestamp or ISO 8601 string to query at a specific point in time. + + Returns: + The constructed GraphQL URL. + """ url = f"{self.config.address}/graphql" if branch_name: url += f"/{branch_name}" @@ -222,6 +316,19 @@ def _build_ip_address_allocation_query( address_type: str | None = None, data: dict[str, Any] | None = None, ) -> Mutation: + """ + Builds a GraphQL mutation for allocating an IP address from a resource pool. + + Args: + resource_pool_id: The ID of the CoreIPAddressPool. + identifier: Optional identifier for idempotent allocation. + prefix_length: Optional prefix length for the allocated address. + address_type: Optional type/kind of the IP address to allocate. + data: Optional dictionary of additional data to set on the allocated IP address. + + Returns: + A Mutation object for the IP address allocation. + """ input_data: dict[str, Any] = {"id": resource_pool_id} if identifier: @@ -249,6 +356,23 @@ def _build_ip_prefix_allocation_query( prefix_type: str | None = None, data: dict[str, Any] | None = None, ) -> Mutation: + """ + Builds a GraphQL mutation for allocating an IP prefix from a resource pool. + + Args: + resource_pool_id: The ID of the CoreIPPrefixPool. + identifier: Optional identifier for idempotent allocation. + prefix_length: Optional length of the prefix to allocate. + member_type: Optional member type for the prefix ("prefix" or "address"). + prefix_type: Optional type/kind of the IP prefix to allocate. + data: Optional dictionary of additional data to set on the allocated IP prefix. + + Returns: + A Mutation object for the IP prefix allocation. + + Raises: + ValueError: If `member_type` is provided and is not "prefix" or "address". + """ input_data: dict[str, Any] = {"id": resource_pool_id} if identifier: @@ -273,11 +397,18 @@ def _build_ip_prefix_allocation_query( class InfrahubClient(BaseClient): - """GraphQL Client to interact with Infrahub.""" + """ + Asynchronous GraphQL Client to interact with an Infrahub instance. + + This client provides methods for CRUD operations on Infrahub nodes, + branch management, schema introspection, and other Infrahub-specific functionalities. + It uses `httpx` for asynchronous HTTP requests. + """ group_context: InfrahubGroupContext def _initialize(self) -> None: + """Initializes asynchronous client-specific components.""" self.schema = InfrahubSchema(self) self.branch = InfrahubBranchManager(self) self.object_store = ObjectStore(self) @@ -288,18 +419,33 @@ def _initialize(self) -> None: self.group_context = InfrahubGroupContext(self) async def get_version(self) -> str: - """Return the Infrahub version.""" + """ + Retrieves the version of the connected Infrahub instance. + + Returns: + A string representing the Infrahub server version. + """ response = await self.execute_graphql(query="query { InfrahubInfo { version }}") version = response.get("InfrahubInfo", {}).get("version", "") return version async def get_user(self) -> dict: - """Return user information""" + """ + Retrieves information about the currently authenticated user. + + Returns: + A dictionary containing user profile information. 
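+
+        Example:
+            Illustrative sketch::
+
+                user = await client.get_user()
+                groups = user["AccountProfile"]["member_of_groups"]["edges"]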
+ """ user_info = await self.execute_graphql(query=QUERY_USER) return user_info async def get_user_permissions(self) -> dict: - """Return user permissions""" + """ + Retrieves the permissions of the currently authenticated user. + + Returns: + A dictionary representing the user's permissions. + """ user_info = await self.get_user() return get_user_permissions(user_info["AccountProfile"]["member_of_groups"]["edges"]) @@ -329,6 +475,25 @@ async def create( timeout: int | None = None, **kwargs: Any, ) -> InfrahubNode | SchemaType: + """ + Creates a new Infrahub node. + + Args: + kind: The kind of the node to create (e.g., "CoreSite") or its type (e.g., CoreSite). + data: A dictionary of data to initialize the node with. + Can be used instead of or in addition to kwargs. + branch: The branch on which to create the node. Defaults to the client's default branch. + timeout: Optional timeout in seconds for the schema retrieval. + **kwargs: Attributes and their values to set on the new node. + + Returns: + An `InfrahubNode` instance (or a typed subclass if `kind` was a type) + representing the newly created node. It is not yet saved to Infrahub. + Call `.save()` on the returned node to persist it. + + Raises: + ValueError: If neither `data` nor `kwargs` are provided. + """ branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch, timeout=timeout) @@ -339,6 +504,16 @@ async def create( return InfrahubNode(client=self, schema=schema, branch=branch, data=data or kwargs) async def delete(self, kind: str | type[SchemaType], id: str, branch: str | None = None) -> None: + """ + Deletes an Infrahub node by its ID. + + Note: This performs an immediate deletion request to the server. + + Args: + kind: The kind of the node to delete or its type. + id: The ID of the node to delete. + branch: The branch from which to delete the node. Defaults to the client's default branch. + """ branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -476,6 +651,35 @@ async def get( property: bool = False, **kwargs: Any, ) -> InfrahubNode | SchemaType | None: + """ + Retrieves a single Infrahub node by its ID, HFID, or other unique attributes. + + Args: + kind: The kind of the node to retrieve (e.g., "CoreSite") or its type (e.g., CoreSite). + raise_when_missing: If True (default), raises `NodeNotFoundError` if the node isn't found. + If False, returns None when not found. + at: Optional timestamp to retrieve the node state at a specific time. + branch: The branch to retrieve the node from. Defaults to the client's default branch. + timeout: Optional timeout in seconds for the GraphQL request. + id: The UUID of the node. + hfid: A list of Human-Friendly IDs to search for. + include: List of specific attributes or relationships to include in the response. + exclude: List of attributes or relationships to exclude from the response. + populate_store: If True (default), the retrieved node is added/updated in the client's NodeStore. + fragment: If True, uses GraphQL fragments (useful for generic schema types). + prefetch_relationships: If True, attempts to prefetch data for related nodes. + property: If True, indicates that a property field is being queried directly. + **kwargs: Additional filter criteria (attribute=value pairs) to find the node. + + Returns: + An `InfrahubNode` (or its typed subclass) if found, or None if `raise_when_missing` is False. + + Raises: + NodeNotFoundError: If `raise_when_missing` is True and no node matches the criteria. 
+ IndexError: If more than one node matches the criteria. + ValueError: If no filter criteria (id, hfid, or kwargs) are provided, or if filtering + by HFID is attempted on a node kind that doesn't support it. + """ branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -527,19 +731,23 @@ async def _process_nodes_and_relationships( prefetch_relationships: bool, timeout: int | None = None, ) -> ProcessRelationsNode: - """Processes InfrahubNode and their Relationships from the GraphQL query response. + """ + Processes InfrahubNode objects and their relationships from a GraphQL query response. + + This is an internal helper method. Args: - response (dict[str, Any]): The response from the GraphQL query. - schema_kind (str): The kind of schema being queried. - branch (str): The branch name. - prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. + response: The raw dictionary response from a GraphQL query. + schema_kind: The `kind` of the primary nodes being processed from the response. + branch: The branch name these nodes belong to. + prefetch_relationships: If True, additionally processes and fetches related nodes. + timeout: Optional timeout for fetching related node schemas. Returns: - ProcessRelationsNodeSync: A TypedDict containing two lists: - - 'nodes': A list of InfrahubNode objects representing the nodes processed. - - 'related_nodes': A list of InfrahubNode objects representing the related nodes + A ProcessRelationsNode TypedDict containing: + - 'nodes': A list of processed `InfrahubNode` objects. + - 'related_nodes': A list of `InfrahubNode` objects that are related to the primary nodes + (populated if `prefetch_relationships` is True). """ nodes: list[InfrahubNode] = [] @@ -565,7 +773,20 @@ async def count( partial_match: bool = False, **kwargs: Any, ) -> int: - """Return the number of nodes of a given kind.""" + """ + Counts the number of nodes of a given kind that match the specified filters. + + Args: + kind: The kind of the node (e.g., "CoreSite") or its type (e.g., CoreSite). + at: Optional timestamp to count nodes at a specific time. + branch: The branch to count nodes in. Defaults to the client's default branch. + timeout: Optional timeout in seconds for the GraphQL request. + partial_match: If True, allows partial matching for string filters. + **kwargs: Filter criteria (attribute=value pairs) for counting nodes. + + Returns: + The number of matching nodes. + """ filters: dict[str, Any] = dict(kwargs) if partial_match: @@ -644,25 +865,29 @@ async def all( parallel: bool = False, order: Order | None = None, ) -> list[InfrahubNode] | list[SchemaType]: - """Retrieve all nodes of a given kind + """ + Retrieves all nodes of a given kind. + + This is a convenience method that calls `filters()` without any specific filter arguments. Args: - kind (str): kind of the nodes to query - at (Timestamp, optional): Time of the query. Defaults to Now. - branch (str, optional): Name of the branch to query from. Defaults to default_branch. - populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. - offset (int, optional): The offset for pagination. - limit (int, optional): The limit for pagination. 
- include (list[str], optional): List of attributes or relationships to include in the query. - exclude (list[str], optional): List of attributes or relationships to exclude from the query. - fragment (bool, optional): Flag to use GraphQL fragments for generic schemas. - prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data. - parallel (bool, optional): Whether to use parallel processing for the query. - order (Order, optional): Ordering related options. Setting `disable=True` enhances performances. + kind: The kind of the nodes to query (e.g., "CoreSite") or its type (e.g., CoreSite). + at: Optional timestamp to query nodes at a specific time. + branch: The branch to query from. Defaults to the client's default branch. + timeout: Optional timeout in seconds for GraphQL requests. + populate_store: If True (default), retrieved nodes are added/updated in the client's NodeStore. + offset: Optional offset for pagination. + limit: Optional limit for pagination. + include: List of specific attributes or relationships to include in the response. + exclude: List of attributes or relationships to exclude from the response. + fragment: If True, uses GraphQL fragments (useful for generic schema types). + prefetch_relationships: If True, attempts to prefetch data for related nodes. + property: If True, indicates that property fields are being queried directly. + parallel: If True, fetches pages in parallel (can be faster but consumes more resources). + order: Optional `Order` object to specify sorting. Disabling order enhances performance. Returns: - list[InfrahubNode]: List of Nodes + A list of `InfrahubNode` objects (or their typed subclasses). """ return await self.filters( kind=kind, @@ -742,27 +967,31 @@ async def filters( order: Order | None = None, **kwargs: Any, ) -> list[InfrahubNode] | list[SchemaType]: - """Retrieve nodes of a given kind based on provided filters. + """ + Retrieves nodes of a given kind based on provided filters and pagination options. Args: - kind (str): kind of the nodes to query - at (Timestamp, optional): Time of the query. Defaults to Now. - branch (str, optional): Name of the branch to query from. Defaults to default_branch. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. - populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. - offset (int, optional): The offset for pagination. - limit (int, optional): The limit for pagination. - include (list[str], optional): List of attributes or relationships to include in the query. - exclude (list[str], optional): List of attributes or relationships to exclude from the query. - fragment (bool, optional): Flag to use GraphQL fragments for generic schemas. - prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data. - partial_match (bool, optional): Allow partial match of filter criteria for the query. - parallel (bool, optional): Whether to use parallel processing for the query. - order (Order, optional): Ordering related options. Setting `disable=True` enhances performances. - **kwargs (Any): Additional filter criteria for the query. + kind: The kind of the nodes to query (e.g., "CoreSite") or its type (e.g., CoreSite). + at: Optional timestamp to query nodes at a specific time. + branch: The branch to query from. Defaults to the client's default branch. + timeout: Optional timeout in seconds for GraphQL requests. 
+ populate_store: If True (default), retrieved nodes are added/updated in the client's NodeStore. + offset: Optional offset for pagination. + limit: Optional limit for pagination. If set, `parallel` processing might be less effective + if limit is smaller than pagination_size. + include: List of specific attributes or relationships to include in the response. + exclude: List of attributes or relationships to exclude from the response. + fragment: If True, uses GraphQL fragments (useful for generic schema types). + prefetch_relationships: If True, attempts to prefetch data for related nodes. + partial_match: If True, allows partial matching for string filters. + property: If True, indicates that property fields are being queried directly. + parallel: If True, fetches pages in parallel (can be faster but consumes more resources). + Not recommended if `limit` is set to a small value. + order: Optional `Order` object to specify sorting. Disabling order enhances performance. + **kwargs: Additional filter criteria (attribute=value pairs) for the query. Returns: - list[InfrahubNodeSync]: List of Nodes that match the given filters. + A list of `InfrahubNode` objects (or their typed subclasses) that match the filters. """ branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -857,7 +1086,19 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: return nodes def clone(self, branch: str | None = None) -> InfrahubClient: - """Return a cloned version of the client using the same configuration""" + """ + Creates a new `InfrahubClient` instance with a cloned configuration. + + This is useful for creating a client for a different branch while retaining + the original client's settings (address, credentials, etc.). + + Args: + branch: Optional new default branch name for the cloned client. + If None, the current client's default branch is used. + + Returns: + A new `InfrahubClient` instance. + """ return InfrahubClient(config=self.config.clone(branch=branch)) async def execute_graphql( @@ -870,21 +1111,31 @@ async def execute_graphql( raise_for_error: bool = True, tracker: str | None = None, ) -> dict: - """Execute a GraphQL query (or mutation). - If retry_on_failure is True, the query will retry until the server becomes reacheable. + """ + Executes a raw GraphQL query or mutation. + + If `retry_on_failure` is True in the client config, the query will be retried + if the server is unreachable. Args: - query (_type_): GraphQL Query to execute, can be a query or a mutation - variables (dict, optional): Variables to pass along with the GraphQL query. Defaults to None. - branch_name (str, optional): Name of the branch on which the query will be executed. Defaults to None. - at (str, optional): Time when the query should be executed. Defaults to None. - timeout (int, optional): Timeout in second for the query. Defaults to None. - raise_for_error (bool, optional): Flag to indicate that we need to raise an exception if the response has some errors. Defaults to True. - Raises: - GraphQLError: _description_ + query: The GraphQL query or mutation string. + variables: Optional dictionary of variables for the query. + branch_name: The branch to execute against. Defaults to the client's default branch. + at: Optional timestamp to execute the query at a specific time. + timeout: Optional timeout in seconds for this specific request. + raise_for_error: If True (default), raises `GraphQLError` if the response contains errors. 
+ tracker: Optional tracker string to include in request headers for debugging/logging. Returns: - _type_: _description_ + A dictionary containing the "data" part of the GraphQL response. + + Raises: + ServerNotReachableError: If the server cannot be reached after retries (if enabled). + httpx.HTTPStatusError: For HTTP errors (e.g., 401, 403, 404) if not handled otherwise. + AuthenticationError: For 401/403 errors specifically. + URLNotFoundError: For 404 errors. + GraphQLError: If `raise_for_error` is True and the GraphQL response contains errors. + Error: If an unexpected situation occurs where the response object isn't initialized. """ branch_name = branch_name or self.default_branch @@ -983,6 +1234,21 @@ async def _get(self, url: str, headers: dict | None = None, timeout: int | None async def _request( self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: dict | None = None ) -> httpx.Response: + """ + Internal method to make an HTTP request using the configured requester. + + Also handles recording the response. + + Args: + url: The URL for the request. + method: The HTTP method (GET, POST, etc.). + headers: Dictionary of request headers. + timeout: Request timeout in seconds. + payload: Optional request payload (typically for POST/PUT). + + Returns: + An `httpx.Response` object. + """ response = await self._request_method(url=url, method=method, headers=headers, timeout=timeout, payload=payload) self._record(response) return response @@ -990,6 +1256,25 @@ async def _request( async def _default_request_method( self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: dict | None = None ) -> httpx.Response: + """ + The default asynchronous HTTP request method using httpx.AsyncClient. + + Handles proxy configuration and TLS verification settings. + + Args: + url: The URL for the request. + method: The HTTP method. + headers: Request headers. + timeout: Request timeout in seconds. + payload: Optional request payload. + + Returns: + An `httpx.Response` object. + + Raises: + ServerNotReachableError: If a network error occurs. + ServerNotResponsiveError: If a read timeout occurs. + """ params: dict[str, Any] = {} if payload: params["json"] = payload @@ -1023,6 +1308,15 @@ async def _default_request_method( return response async def refresh_login(self) -> None: + """ + Refreshes the authentication access token using the stored refresh token. + + Updates `self.access_token` and the "Authorization" header. + This method is called automatically by decorated request methods if a token expires. + + Raises: + httpx.HTTPStatusError: If the refresh request itself fails (e.g., invalid refresh token). + """ if not self.refresh_token: return @@ -1040,6 +1334,24 @@ async def refresh_login(self) -> None: self.headers["Authorization"] = f"Bearer {self.access_token}" async def login(self, refresh: bool = False) -> None: + """ + Logs into Infrahub using username/password or refreshes an existing session. + + If password authentication is not configured, this method does nothing. + If an access token already exists and `refresh` is False, it does nothing. + If `refresh` is True and a refresh token exists, it attempts `refresh_login()`. + Otherwise, it performs a full login with username and password. + + Updates `self.access_token`, `self.refresh_token`, and the "Authorization" header. + + Args: + refresh: If True, forces an attempt to refresh the token if one exists. 
+ + Raises: + AuthenticationError: If login fails due to authentication issues (e.g., bad credentials + during initial login, or non-401 error during refresh). + httpx.HTTPStatusError: For other HTTP errors during the login process. + """ if not self.config.password_authentication: return @@ -1055,7 +1367,7 @@ async def login(self, refresh: bool = False) -> None: # Other status codes indicate other errors if exc.response.status_code != 401: response = exc.response.json() - errors = response.get("errors") + errors = response.get("errors", []) messages = [error.get("message") for error in errors] raise AuthenticationError(" | ".join(messages)) from exc @@ -1087,6 +1399,27 @@ async def query_gql_query( tracker: str | None = None, raise_for_error: bool = True, ) -> dict: + """ + Executes a pre-defined GraphQL query stored on the Infrahub server by its name. + + Args: + name: The name of the stored GraphQL query. + variables: Optional dictionary of variables for the query. + update_group: If True, associates this query with the current tracking group (if active). + subscribers: Optional list of subscriber identifiers. + params: Optional dictionary of additional URL parameters. + branch_name: The branch to execute against. Defaults to client's default. + at: Optional timestamp to execute at a specific time. + timeout: Optional timeout for this request. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises an exception on HTTP or GraphQL errors. + + Returns: + A dictionary containing the query's response data. + + Raises: + httpx.HTTPStatusError: For HTTP errors if `raise_for_error` is True. + """ url = f"{self.address}/api/query/{name}" url_params = copy.deepcopy(params or {}) headers = copy.copy(self.headers or {}) @@ -1143,6 +1476,18 @@ async def get_diff_summary( tracker: str | None = None, raise_for_error: bool = True, ) -> list[NodeDiff]: + """ + Retrieves a summary of differences (diffs) for a given branch. + + Args: + branch: The name of the branch to get the diff summary for. + timeout: Optional timeout for the GraphQL request. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises an exception on HTTP or GraphQL errors. + + Returns: + A list of `NodeDiff` objects representing the changes on the branch. + """ query = get_diff_summary_query() response = await self.execute_graphql( query=query, @@ -1267,20 +1612,30 @@ async def allocate_next_ip_address( tracker: str | None = None, raise_for_error: bool = True, ) -> CoreNode | SchemaType | None: - """Allocate a new IP address by using the provided resource pool. + """ + Allocates the next available IP address from a specified CoreIPAddressPool. Args: - resource_pool (InfrahubNode): Node corresponding to the pool to allocate resources from. - identifier (str, optional): Value to perform idempotent allocation, the same resource will be returned for a given identifier. - prefix_length (int, optional): Length of the prefix to set on the address to allocate. - address_type (str, optional): Kind of the address to allocate. - data (dict, optional): A key/value map to use to set attributes values on the allocated address. - branch (str, optional): Name of the branch to allocate from. Defaults to default_branch. - timeout (int, optional): Flag to indicate whether to populate the store with the retrieved nodes. - tracker (str, optional): The offset for pagination. - raise_for_error (bool, optional): The limit for pagination. 
+ resource_pool: The `CoreIPAddressPool` node from which to allocate. + kind: Optional specific type of `CoreIPAddress` to expect (e.g., a custom subclass). + identifier: Optional identifier for idempotent allocation. If provided, subsequent calls + with the same identifier will return the same allocated address. + prefix_length: Optional desired prefix length for the allocated IP address. + address_type: Optional specific kind of IP address to allocate if the pool supports multiple. + data: Optional dictionary of attributes to set on the newly allocated IP address node. + branch: The branch on which to perform the allocation. Defaults to the client's default branch. + timeout: Optional timeout for the GraphQL request. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises an exception on HTTP or GraphQL errors. + If False and allocation fails, returns None. + Returns: - InfrahubNode: Node corresponding to the allocated resource. + The allocated `CoreIPAddress` node (or its typed subclass if `kind` was specified), + or None if allocation failed and `raise_for_error` is False. + + Raises: + ValueError: If `resource_pool` is not a "CoreIPAddressPool". + GraphQLError: If allocation fails and `raise_for_error` is True. """ if resource_pool.get_kind() != "CoreIPAddressPool": raise ValueError("resource_pool is not an IP address pool") @@ -1418,21 +1773,30 @@ async def allocate_next_ip_prefix( tracker: str | None = None, raise_for_error: bool = True, ) -> CoreNode | SchemaType | None: - """Allocate a new IP prefix by using the provided resource pool. + """ + Allocates the next available IP prefix from a specified CoreIPPrefixPool. Args: - resource_pool: Node corresponding to the pool to allocate resources from. - identifier: Value to perform idempotent allocation, the same resource will be returned for a given identifier. - prefix_length: Length of the prefix to allocate. - member_type: Member type of the prefix to allocate. - prefix_type: Kind of the prefix to allocate. - data: A key/value map to use to set attributes values on the allocated prefix. - branch: Name of the branch to allocate from. Defaults to default_branch. - timeout: Flag to indicate whether to populate the store with the retrieved nodes. - tracker: The offset for pagination. - raise_for_error: The limit for pagination. + resource_pool: The `CoreIPPrefixPool` node from which to allocate. + kind: Optional specific type of `CoreIPPrefix` to expect (e.g., a custom subclass). + identifier: Optional identifier for idempotent allocation. + prefix_length: Optional desired length of the prefix to allocate. + member_type: Optional member type for the prefix (e.g., "prefix", "address"). + prefix_type: Optional specific kind of IP prefix to allocate if the pool supports multiple. + data: Optional dictionary of attributes to set on the newly allocated IP prefix node. + branch: The branch on which to perform the allocation. Defaults to the client's default branch. + timeout: Optional timeout for the GraphQL request. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises an exception on HTTP or GraphQL errors. + If False and allocation fails, returns None. + Returns: - InfrahubNode: Node corresponding to the allocated resource. + The allocated `CoreIPPrefix` node (or its typed subclass if `kind` was specified), + or None if allocation failed and `raise_for_error` is False. + + Raises: + ValueError: If `resource_pool` is not a "CoreIPPrefixPool". 
+            GraphQLError: If allocation fails and `raise_for_error` is True.
         """
         if resource_pool.get_kind() != "CoreIPPrefixPool":
             raise ValueError("resource_pool is not an IP prefix pool")
@@ -1458,11 +1822,31 @@ async def allocate_next_ip_prefix(
         return None

     async def create_batch(self, return_exceptions: bool = False) -> InfrahubBatch:
+        """
+        Creates an `InfrahubBatch` instance for managing concurrent asynchronous tasks.
+
+        Args:
+            return_exceptions: If True, exceptions from tasks in the batch will be returned
+                as results instead of being raised.
+
+        Returns:
+            An `InfrahubBatch` instance.
+        """
         return InfrahubBatch(semaphore=self.concurrent_execution_limit, return_exceptions=return_exceptions)

     async def get_list_repositories(
         self, branches: dict[str, BranchData] | None = None, kind: str = "CoreGenericRepository"
     ) -> dict[str, RepositoryData]:
+        """
+        Retrieves a list of repositories and their branch information.
+
+        Args:
+            branches: Optional dictionary of branch data. If None, all branches are fetched.
+            kind: The kind of repository node to list (defaults to "CoreGenericRepository").
+
+        Returns:
+            A dictionary where keys are repository names and values are `RepositoryData` objects.
+        """
         branches = branches or await self.branch.all()
         batch = await self.create_batch()
@@ -1501,6 +1885,18 @@ async def get_list_repositories(
     async def repository_update_commit(
         self, branch_name: str, repository_id: str, commit: str, is_read_only: bool = False
     ) -> bool:
+        """
+        Updates the commit SHA for a specific repository on a given branch.
+
+        Args:
+            branch_name: The name of the branch where the repository's commit will be updated.
+            repository_id: The ID of the repository node to update.
+            commit: The new commit SHA.
+            is_read_only: If True, uses a read-only mutation (e.g., for dry runs or checks).
+
+        Returns:
+            True if the operation was successful (the GraphQL mutation returned ok).
+        """
         variables = {"repository_id": str(repository_id), "commit": str(commit)}
         await self.execute_graphql(
             query=get_commit_update_mutation(is_read_only=is_read_only),
@@ -1512,6 +1908,7 @@ async def repository_update_commit(
         return True

     async def __aenter__(self) -> Self:
+        """Enters an asynchronous context, returning the client instance."""
         return self

     async def __aexit__(
@@ -1520,6 +1917,13 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
+        """
+        Exits an asynchronous context.
+
+        If the client was in TRACKING mode and no exception occurred,
+        it finalizes the group context (e.g., by calling `update_group()`).
+        Resets client mode to DEFAULT.
+        """
         if exc_type is None and self.mode == InfrahubClientMode.TRACKING:
             await self.group_context.update_group()

@@ -1527,6 +1931,13 @@ class InfrahubClientSync(BaseClient):

 class InfrahubClientSync(BaseClient):
+    """
+    Synchronous GraphQL Client to interact with an Infrahub instance.
+
+    This client provides methods for CRUD operations on Infrahub nodes,
+    branch management, schema introspection, and other Infrahub-specific functionalities.
+    It uses `httpx` for synchronous HTTP requests.
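+
+    Example:
+        A minimal sketch (the address is illustrative):
+
+            client = InfrahubClientSync(config=Config(address="http://localhost:8000"))
+            print(client.get_version())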
+ """ schema: InfrahubSchemaSync branch: InfrahubBranchManagerSync object_store: ObjectStoreSync @@ -1535,6 +1946,7 @@ class InfrahubClientSync(BaseClient): group_context: InfrahubGroupContextSync def _initialize(self) -> None: + """Initializes synchronous client-specific components.""" self.schema = InfrahubSchemaSync(self) self.branch = InfrahubBranchManagerSync(self) self.object_store = ObjectStoreSync(self) @@ -1544,18 +1956,33 @@ def _initialize(self) -> None: self.group_context = InfrahubGroupContextSync(self) def get_version(self) -> str: - """Return the Infrahub version.""" + """ + Retrieves the version of the connected Infrahub instance. + + Returns: + A string representing the Infrahub server version. + """ response = self.execute_graphql(query="query { InfrahubInfo { version }}") version = response.get("InfrahubInfo", {}).get("version", "") return version def get_user(self) -> dict: - """Return user information""" + """ + Retrieves information about the currently authenticated user. + + Returns: + A dictionary containing user profile information. + """ user_info = self.execute_graphql(query=QUERY_USER) return user_info def get_user_permissions(self) -> dict: - """Return user permissions""" + """ + Retrieves the permissions of the currently authenticated user. + + Returns: + A dictionary representing the user's permissions. + """ user_info = self.get_user() return get_user_permissions(user_info["AccountProfile"]["member_of_groups"]["edges"]) @@ -1585,6 +2012,23 @@ def create( timeout: int | None = None, **kwargs: Any, ) -> InfrahubNodeSync | SchemaTypeSync: + """ + Creates a new Infrahub node (synchronous version). + + Args: + kind: The kind of the node to create (e.g., "CoreSite") or its type (e.g., CoreSite). + data: A dictionary of data to initialize the node with. + branch: The branch on which to create the node. Defaults to the client's default branch. + timeout: Optional timeout for schema retrieval. + **kwargs: Attributes and their values to set on the new node. + + Returns: + An `InfrahubNodeSync` instance (or a typed subclass) representing the new node. + It is not yet saved. Call `.save()` on the returned node to persist it. + + Raises: + ValueError: If neither `data` nor `kwargs` are provided. + """ branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch, timeout=timeout) @@ -1594,6 +2038,16 @@ def create( return InfrahubNodeSync(client=self, schema=schema, branch=branch, data=data or kwargs) def delete(self, kind: str | type[SchemaTypeSync], id: str, branch: str | None = None) -> None: + """ + Deletes an Infrahub node by its ID (synchronous version). + + Note: This performs an immediate deletion request to the server. + + Args: + kind: The kind of the node to delete or its type. + id: The ID of the node to delete. + branch: The branch from which to delete the node. Defaults to the client's default branch. + """ branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -1601,7 +2055,15 @@ def delete(self, kind: str | type[SchemaTypeSync], id: str, branch: str | None = node.delete() def clone(self, branch: str | None = None) -> InfrahubClientSync: - """Return a cloned version of the client using the same configuration""" + """ + Creates a new `InfrahubClientSync` instance with a cloned configuration. + + Args: + branch: Optional new default branch name for the cloned client. + + Returns: + A new `InfrahubClientSync` instance. 
+ """ return InfrahubClientSync(config=self.config.clone(branch=branch)) def execute_graphql( @@ -1614,21 +2076,30 @@ def execute_graphql( raise_for_error: bool = True, tracker: str | None = None, ) -> dict: - """Execute a GraphQL query (or mutation). - If retry_on_failure is True, the query will retry until the server becomes reacheable. + """ + Executes a raw GraphQL query or mutation (synchronous version). + + If `retry_on_failure` is True in config, retries if the server is unreachable. Args: - query (str): GraphQL Query to execute, can be a query or a mutation - variables (dict, optional): Variables to pass along with the GraphQL query. Defaults to None. - branch_name (str, optional): Name of the branch on which the query will be executed. Defaults to None. - at (str, optional): Time when the query should be executed. Defaults to None. - timeout (int, optional): Timeout in second for the query. Defaults to None. - raise_for_error (bool, optional): Flag to indicate that we need to raise an exception if the response has some errors. Defaults to True. - Raises: - GraphQLError: When an error occurs during the execution of the GraphQL query or mutation. + query: The GraphQL query or mutation string. + variables: Optional dictionary of variables. + branch_name: Branch to execute against. Defaults to client's default. + at: Optional timestamp for point-in-time query. + timeout: Optional request timeout in seconds. + raise_for_error: If True (default), raises `GraphQLError` on response errors. + tracker: Optional tracker string for request headers. Returns: - dict: The result of the GraphQL query or mutation. + A dictionary containing the "data" part of the GraphQL response. + + Raises: + ServerNotReachableError: If server unreachable after retries. + httpx.HTTPStatusError: For HTTP errors if not handled otherwise. + AuthenticationError: For 401/403 errors. + URLNotFoundError: For 404 errors. + GraphQLError: If `raise_for_error` is True and response has errors. + Error: If response object isn't initialized unexpectedly. """ branch_name = branch_name or self.default_branch @@ -1688,14 +2159,27 @@ def execute_graphql( def count( self, - kind: str | type[SchemaType], + kind: str | type[SchemaTypeSync], # Corrected type hint at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, partial_match: bool = False, **kwargs: Any, ) -> int: - """Return the number of nodes of a given kind.""" + """ + Counts nodes of a given kind matching filters (synchronous version). + + Args: + kind: The kind of the node or its type (e.g., CoreSiteSync). + at: Optional timestamp for point-in-time count. + branch: Branch to count in. Defaults to client's default. + timeout: Optional request timeout. + partial_match: If True, allows partial string matching. + **kwargs: Filter criteria (attribute=value). + + Returns: + The number of matching nodes. + """ filters: dict[str, Any] = dict(kwargs) if partial_match: @@ -1774,25 +2258,29 @@ def all( parallel: bool = False, order: Order | None = None, ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: - """Retrieve all nodes of a given kind + """ + Retrieves all nodes of a given kind (synchronous version). + + Calls `filters()` without specific filter arguments. Args: - kind (str): kind of the nodes to query - at (Timestamp, optional): Time of the query. Defaults to Now. - branch (str, optional): Name of the branch to query from. Defaults to default_branch. - timeout (int, optional): Overrides default timeout used when querying the graphql API. 
Specified in seconds. - populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. - offset (int, optional): The offset for pagination. - limit (int, optional): The limit for pagination. - include (list[str], optional): List of attributes or relationships to include in the query. - exclude (list[str], optional): List of attributes or relationships to exclude from the query. - fragment (bool, optional): Flag to use GraphQL fragments for generic schemas. - prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data. - parallel (bool, optional): Whether to use parallel processing for the query. - order (Order, optional): Ordering related options. Setting `disable=True` enhances performances. + kind: Node kind (e.g., "CoreSite") or type (e.g., CoreSiteSync). + at: Optional timestamp for point-in-time query. + branch: Branch to query. Defaults to client's default. + timeout: Optional request timeout. + populate_store: If True (default), updates client's NodeStore. + offset: Optional pagination offset. + limit: Optional pagination limit. + include: Specific attributes/relationships to include. + exclude: Attributes/relationships to exclude. + fragment: If True, uses GraphQL fragments. + prefetch_relationships: If True, prefetches related node data. + property: If True, indicates direct property field query. + parallel: If True, fetches pages in parallel (thread pool). + order: Optional `Order` object for sorting. Returns: - list[InfrahubNodeSync]: List of Nodes + A list of `InfrahubNodeSync` objects (or typed subclasses). """ return self.filters( kind=kind, @@ -1819,19 +2307,20 @@ def _process_nodes_and_relationships( prefetch_relationships: bool, timeout: int | None = None, ) -> ProcessRelationsNodeSync: - """Processes InfrahubNodeSync and their Relationships from the GraphQL query response. + """ + Processes InfrahubNodeSync objects and relationships from a GraphQL response (synchronous version). + + Internal helper method. Args: - response (dict[str, Any]): The response from the GraphQL query. - schema_kind (str): The kind of schema being queried. - branch (str): The branch name. - prefetch_relationships (bool): Flag to indicate whether to prefetch relationship data. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. + response: Raw dictionary response from GraphQL. + schema_kind: `kind` of primary nodes being processed. + branch: Branch name for these nodes. + prefetch_relationships: If True, processes and fetches related nodes. + timeout: Optional timeout for fetching related node schemas. Returns: - ProcessRelationsNodeSync: A TypedDict containing two lists: - - 'nodes': A list of InfrahubNodeSync objects representing the nodes processed. - - 'related_nodes': A list of InfrahubNodeSync objects representing the related nodes + ProcessRelationsNodeSync TypedDict with 'nodes' and 'related_nodes' lists. """ nodes: list[InfrahubNodeSync] = [] @@ -1907,27 +2396,29 @@ def filters( order: Order | None = None, **kwargs: Any, ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: - """Retrieve nodes of a given kind based on provided filters. + """ + Retrieves nodes based on filters and pagination (synchronous version). Args: - kind (str): kind of the nodes to query - at (Timestamp, optional): Time of the query. Defaults to Now. - branch (str, optional): Name of the branch to query from. Defaults to default_branch. 
- timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. - populate_store (bool, optional): Flag to indicate whether to populate the store with the retrieved nodes. - offset (int, optional): The offset for pagination. - limit (int, optional): The limit for pagination. - include (list[str], optional): List of attributes or relationships to include in the query. - exclude (list[str], optional): List of attributes or relationships to exclude from the query. - fragment (bool, optional): Flag to use GraphQL fragments for generic schemas. - prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data. - partial_match (bool, optional): Allow partial match of filter criteria for the query. - parallel (bool, optional): Whether to use parallel processing for the query. - order (Order, optional): Ordering related options. Setting `disable=True` enhances performances. - **kwargs (Any): Additional filter criteria for the query. + kind: Node kind (e.g., "CoreSite") or type (e.g., CoreSiteSync). + at: Optional timestamp for point-in-time query. + branch: Branch to query. Defaults to client's default. + timeout: Optional request timeout. + populate_store: If True (default), updates client's NodeStore. + offset: Optional pagination offset. + limit: Optional pagination limit. + include: Specific attributes/relationships to include. + exclude: Attributes/relationships to exclude. + fragment: If True, uses GraphQL fragments. + prefetch_relationships: If True, prefetches related node data. + partial_match: If True, allows partial string matching. + property: If True, indicates direct property field query. + parallel: If True, fetches pages in parallel (thread pool). + order: Optional `Order` object for sorting. + **kwargs: Filter criteria (attribute=value). Returns: - list[InfrahubNodeSync]: List of Nodes that match the given filters. + List of `InfrahubNodeSync` objects (or typed subclasses) matching filters. """ branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -2153,6 +2644,34 @@ def get( property: bool = False, **kwargs: Any, ) -> InfrahubNodeSync | SchemaTypeSync | None: + """ + Retrieves a single node by ID, HFID, or attributes (synchronous version). + + Args: + kind: Node kind (e.g., "CoreSite") or type (e.g., CoreSiteSync). + raise_when_missing: If True (default), raises `NodeNotFoundError`. + If False, returns None if not found. + at: Optional timestamp for point-in-time query. + branch: Branch to query. Defaults to client's default. + timeout: Optional request timeout. + id: UUID of the node. + hfid: List of Human-Friendly IDs. + include: Specific attributes/relationships to include. + exclude: Attributes/relationships to exclude. + populate_store: If True (default), updates client's NodeStore. + fragment: If True, uses GraphQL fragments. + prefetch_relationships: If True, prefetches related node data. + property: If True, indicates direct property field query. + **kwargs: Additional filter criteria (attribute=value). + + Returns: + `InfrahubNodeSync` (or typed subclass) if found, or None. + + Raises: + NodeNotFoundError: If `raise_when_missing` and node not found. + IndexError: If multiple nodes match. + ValueError: If no filters provided or HFID used incorrectly. 
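+
+        Example:
+            A sketch (kind and filter are illustrative):
+
+                site = client.get(kind="CoreSite", name__value="atl1", raise_when_missing=False)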
+ """ branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -2197,10 +2716,18 @@ def get( return results[0] def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync: - """Create a batch to execute multiple queries concurrently. + """ + Creates an `InfrahubBatchSync` for managing concurrent synchronous tasks using a thread pool. + + Note: Due to the nature of thread pools, execution order of tasks within the batch + is not guaranteed. Avoid using for operations with strong interdependencies. + + Args: + return_exceptions: If True, exceptions from tasks are returned as results + instead of being raised. - Executing the batch will be performed using a thread pool, meaning it cannot guarantee the execution order. It is not recommended to use such - batch to manipulate objects that depend on each others. + Returns: + An `InfrahubBatchSync` instance. """ return InfrahubBatchSync( max_concurrent_execution=self.max_concurrent_execution, return_exceptions=return_exceptions @@ -2209,6 +2736,15 @@ def create_batch(self, return_exceptions: bool = False) -> InfrahubBatchSync: def get_list_repositories( self, branches: dict[str, BranchData] | None = None, kind: str = "CoreGenericRepository" ) -> dict[str, RepositoryData]: + """ + Retrieves a list of repositories and their branch information. + + Note: This method is deprecated in the async client and not implemented + in the sync client. + + Raises: + NotImplementedError + """ raise NotImplementedError( "This method is deprecated in the async client and won't be implemented in the sync client." ) @@ -2226,6 +2762,27 @@ def query_gql_query( tracker: str | None = None, raise_for_error: bool = True, ) -> dict: + """ + Executes a pre-defined GraphQL query stored on Infrahub by name (synchronous version). + + Args: + name: Name of the stored GraphQL query. + variables: Optional dictionary of variables. + update_group: If True, associates query with current tracking group. + subscribers: Optional list of subscriber identifiers. + params: Optional dictionary of additional URL parameters. + branch_name: Branch to execute against. Defaults to client's default. + at: Optional timestamp for point-in-time query. + timeout: Optional request timeout. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises on HTTP/GraphQL errors. + + Returns: + Dictionary containing query response data. + + Raises: + httpx.HTTPStatusError: For HTTP errors if `raise_for_error` is True. + """ url = f"{self.address}/api/query/{name}" url_params = copy.deepcopy(params or {}) headers = copy.copy(self.headers or {}) @@ -2281,6 +2838,18 @@ def get_diff_summary( tracker: str | None = None, raise_for_error: bool = True, ) -> list[NodeDiff]: + """ + Retrieves a diff summary for a branch (synchronous version). + + Args: + branch: Name of the branch. + timeout: Optional request timeout. + tracker: Optional tracker string for request headers. + raise_for_error: If True (default), raises on HTTP/GraphQL errors. + + Returns: + List of `NodeDiff` objects representing changes. + """ query = get_diff_summary_query() response = self.execute_graphql( query=query, @@ -2405,20 +2974,28 @@ def allocate_next_ip_address( tracker: str | None = None, raise_for_error: bool = True, ) -> CoreNodeSync | SchemaTypeSync | None: - """Allocate a new IP address by using the provided resource pool. + """ + Allocates next IP address from a pool (synchronous version). 
Args: - resource_pool (InfrahubNodeSync): Node corresponding to the pool to allocate resources from. - identifier (str, optional): Value to perform idempotent allocation, the same resource will be returned for a given identifier. - prefix_length (int, optional): Length of the prefix to set on the address to allocate. - address_type (str, optional): Kind of the address to allocate. - data (dict, optional): A key/value map to use to set attributes values on the allocated address. - branch (str, optional): Name of the branch to allocate from. Defaults to default_branch. - timeout (int, optional): Flag to indicate whether to populate the store with the retrieved nodes. - tracker (str, optional): The offset for pagination. - raise_for_error (bool, optional): The limit for pagination. + resource_pool: `CoreIPAddressPool` node. + kind: Optional specific type of `CoreIPAddress` to expect. + identifier: Optional identifier for idempotent allocation. + prefix_length: Optional desired prefix length. + address_type: Optional specific kind of IP address if pool supports multiple. + data: Optional attributes for the new IP address node. + branch: Branch for allocation. Defaults to client's default. + timeout: Optional request timeout. + tracker: Optional tracker string. + raise_for_error: If True (default), raises on errors. + If False, returns None on allocation failure. + Returns: - InfrahubNodeSync: Node corresponding to the allocated resource. + Allocated `CoreIPAddress` node (or typed subclass), or None. + + Raises: + ValueError: If `resource_pool` is not "CoreIPAddressPool". + GraphQLError: If allocation fails and `raise_for_error` is True. """ if resource_pool.get_kind() != "CoreIPAddressPool": raise ValueError("resource_pool is not an IP address pool") @@ -2552,21 +3129,29 @@ def allocate_next_ip_prefix( tracker: str | None = None, raise_for_error: bool = True, ) -> CoreNodeSync | SchemaTypeSync | None: - """Allocate a new IP prefix by using the provided resource pool. + """ + Allocates next IP prefix from a pool (synchronous version). Args: - resource_pool (InfrahubNodeSync): Node corresponding to the pool to allocate resources from. - identifier (str, optional): Value to perform idempotent allocation, the same resource will be returned for a given identifier. - size (int, optional): Length of the prefix to allocate. - member_type (str, optional): Member type of the prefix to allocate. - prefix_type (str, optional): Kind of the prefix to allocate. - data (dict, optional): A key/value map to use to set attributes values on the allocated prefix. - branch (str, optional): Name of the branch to allocate from. Defaults to default_branch. - timeout (int, optional): Flag to indicate whether to populate the store with the retrieved nodes. - tracker (str, optional): The offset for pagination. - raise_for_error (bool, optional): The limit for pagination. + resource_pool: `CoreIPPrefixPool` node. + kind: Optional specific type of `CoreIPPrefix` to expect. + identifier: Optional identifier for idempotent allocation. + prefix_length: Optional desired prefix length. + member_type: Optional member type (e.g., "prefix", "address"). + prefix_type: Optional specific kind of IP prefix if pool supports multiple. + data: Optional attributes for the new IP prefix node. + branch: Branch for allocation. Defaults to client's default. + timeout: Optional request timeout. + tracker: Optional tracker string. + raise_for_error: If True (default), raises on errors. + If False, returns None on allocation failure. 
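+
+        Example:
+            A sketch (assumes `pool` is a CoreIPPrefixPool node fetched earlier):
+
+                prefix = client.allocate_next_ip_prefix(resource_pool=pool, prefix_length=24, identifier="site-atl1")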
+ Returns: - InfrahubNodeSync: Node corresponding to the allocated resource. + Allocated `CoreIPPrefix` node (or typed subclass), or None. + + Raises: + ValueError: If `resource_pool` is not "CoreIPPrefixPool". + GraphQLError: If allocation fails and `raise_for_error` is True. """ if resource_pool.get_kind() != "CoreIPPrefixPool": raise ValueError("resource_pool is not an IP prefix pool") @@ -2594,17 +3179,35 @@ def allocate_next_ip_prefix( def repository_update_commit( self, branch_name: str, repository_id: str, commit: str, is_read_only: bool = False ) -> bool: + """ + Updates the commit SHA for a repository on a branch. + + Note: This method is deprecated in the async client and not implemented + in the sync client. + + Raises: + NotImplementedError + """ raise NotImplementedError( "This method is deprecated in the async client and won't be implemented in the sync client." ) @handle_relogin_sync def _get(self, url: str, headers: dict | None = None, timeout: int | None = None) -> httpx.Response: - """Execute a HTTP GET with HTTPX. + """ + Executes an HTTP GET request with login handling (synchronous version). + + Args: + url: The URL for the GET request. + headers: Optional request headers. + timeout: Optional request timeout. + + Returns: + An `httpx.Response` object. Raises: - ServerNotReachableError if we are not able to connect to the server - ServerNotResponsiveError if the server didnd't respond before the timeout expired + ServerNotReachableError: If the server is not reachable. + ServerNotResponsiveError: If the server times out. """ self.login() @@ -2616,11 +3219,21 @@ def _get(self, url: str, headers: dict | None = None, timeout: int | None = None @handle_relogin_sync def _post(self, url: str, payload: dict, headers: dict | None = None, timeout: int | None = None) -> httpx.Response: - """Execute a HTTP POST with HTTPX. + """ + Executes an HTTP POST request with login handling (synchronous version). + + Args: + url: The URL for the POST request. + payload: The request payload. + headers: Optional request headers. + timeout: Optional request timeout. + + Returns: + An `httpx.Response` object. Raises: - ServerNotReachableError if we are not able to connect to the server - ServerNotResponsiveError if the server didnd't respond before the timeout expired + ServerNotReachableError: If the server is not reachable. + ServerNotResponsiveError: If the server times out. """ self.login() @@ -2635,6 +3248,21 @@ def _post(self, url: str, payload: dict, headers: dict | None = None, timeout: i def _request( self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: dict | None = None ) -> httpx.Response: + """ + Internal method for making HTTP requests (synchronous version). + + Uses the configured synchronous requester and records the response. + + Args: + url: Request URL. + method: HTTP method. + headers: Request headers. + timeout: Request timeout. + payload: Optional request payload. + + Returns: + An `httpx.Response` object. + """ response = self._request_method(url=url, method=method, headers=headers, timeout=timeout, payload=payload) self._record(response) return response @@ -2642,6 +3270,25 @@ def _request( def _default_request_method( self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: dict | None = None ) -> httpx.Response: + """ + Default synchronous HTTP request method using `httpx.Client`. + + Handles proxy and TLS settings. + + Args: + url: Request URL. + method: HTTP method. + headers: Request headers. 
+ timeout: Request timeout. + payload: Optional request payload. + + Returns: + An `httpx.Response` object. + + Raises: + ServerNotReachableError: If a network error occurs. + ServerNotResponsiveError: If a read timeout occurs. + """ params: dict[str, Any] = {} if payload: params["json"] = payload @@ -2675,6 +3322,15 @@ def _default_request_method( return response def refresh_login(self) -> None: + """ + Refreshes authentication token (synchronous version). + + Updates `self.access_token` and "Authorization" header. + Called automatically by decorated request methods on token expiry. + + Raises: + httpx.HTTPStatusError: If refresh request fails. + """ if not self.refresh_token: return @@ -2692,6 +3348,20 @@ def refresh_login(self) -> None: self.headers["Authorization"] = f"Bearer {self.access_token}" def login(self, refresh: bool = False) -> None: + """ + Logs into Infrahub or refreshes session (synchronous version). + + Performs full login if no token or `refresh` is False. + Attempts token refresh if `refresh` is True and refresh token exists. + Updates `self.access_token`, `self.refresh_token`, and "Authorization" header. + + Args: + refresh: If True, attempts to refresh token if available. + + Raises: + AuthenticationError: On authentication failure. + httpx.HTTPStatusError: For other HTTP errors. + """ if not self.config.password_authentication: return @@ -2707,7 +3377,7 @@ def login(self, refresh: bool = False) -> None: # Other status codes indicate other errors if exc.response.status_code != 401: response = exc.response.json() - errors = response.get("errors") + errors = response.get("errors", []) messages = [error.get("message") for error in errors] raise AuthenticationError(" | ".join(messages)) from exc @@ -2727,6 +3397,7 @@ def login(self, refresh: bool = False) -> None: self.headers["Authorization"] = f"Bearer {self.access_token}" def __enter__(self) -> Self: + """Enters a synchronous context, returning the client instance.""" return self def __exit__( @@ -2735,6 +3406,13 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: + """ + Exits a synchronous context. + + If client was in TRACKING mode and no exception occurred, + finalizes group context (e.g., calls `update_group()`). + Resets client mode to DEFAULT. + """ if exc_type is None and self.mode == InfrahubClientMode.TRACKING: self.group_context.update_group() diff --git a/infrahub_sdk/config.py b/infrahub_sdk/config.py index b0a2402a..20c95a7f 100644 --- a/infrahub_sdk/config.py +++ b/infrahub_sdk/config.py @@ -15,6 +15,7 @@ class ProxyMountsConfig(BaseSettings): + """Configuration for HTTP/HTTPS proxy mounts.""" model_config = SettingsConfigDict(populate_by_name=True) http: str | None = Field( default=None, @@ -31,10 +32,17 @@ class ProxyMountsConfig(BaseSettings): @property def is_set(self) -> bool: + """True if either HTTP or HTTPS proxy is configured.""" return self.http is not None or self.https is not None class ConfigBase(BaseSettings): + """ + Base configuration settings for the Infrahub client. + + These settings can be sourced from environment variables (with INFRAHUB_ prefix) + or direct initialization. 
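+
+    Example:
+        A sketch (values are illustrative; the same settings can come from
+        INFRAHUB_ADDRESS / INFRAHUB_API_TOKEN environment variables):
+
+            config = Config(address="http://localhost:8000", api_token="my-token")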
+ """ model_config = SettingsConfigDict(env_prefix="INFRAHUB_", validate_assignment=True) address: str = Field(default="http://localhost:8000", description="The URL to use when connecting to Infrahub.") api_token: str | None = Field(default=None, description="API token for authentication against Infrahub.") @@ -82,6 +90,7 @@ class ConfigBase(BaseSettings): @model_validator(mode="before") @classmethod def validate_credentials_input(cls, values: dict[str, Any]) -> dict[str, Any]: + """Ensures that if 'username' is provided, 'password' is also provided, and vice-versa.""" has_username = "username" in values has_password = "password" in values if has_username != has_password: @@ -91,6 +100,7 @@ def validate_credentials_input(cls, values: dict[str, Any]) -> dict[str, Any]: @model_validator(mode="before") @classmethod def set_transport(cls, values: dict[str, Any]) -> dict[str, Any]: + """Sets requester and sync_requester if transport is JSONPlayback.""" if values.get("transport") == RequesterTransport.JSON: playback = JSONPlayback() if "requester" not in values: @@ -103,6 +113,7 @@ def set_transport(cls, values: dict[str, Any]) -> dict[str, Any]: @model_validator(mode="before") @classmethod def validate_mix_authentication_schemes(cls, values: dict[str, Any]) -> dict[str, Any]: + """Ensures that password authentication and API token authentication are not mixed.""" if values.get("password") and values.get("api_token"): raise ValueError("Unable to combine password with token based authentication") return values @@ -110,6 +121,7 @@ def validate_mix_authentication_schemes(cls, values: dict[str, Any]) -> dict[str @field_validator("address") @classmethod def validate_address(cls, value: str) -> str: + """Validates and normalizes the Infrahub server address.""" if is_valid_url(value): return value.rstrip("/") @@ -117,12 +129,19 @@ def validate_address(cls, value: str) -> str: @model_validator(mode="after") def validate_proxy_config(self) -> Self: + """Ensures that 'proxy' and 'proxy_mounts' are not used simultaneously.""" if self.proxy and self.proxy_mounts.is_set: raise ValueError("'proxy' and 'proxy_mounts' are mutually exclusive") return self @property def default_infrahub_branch(self) -> str: + """ + Determines the default Infrahub branch to use. + + If `default_branch_from_git` is True, it attempts to get the current git branch. + Otherwise, it uses `default_branch`. + """ branch: str | None = None if not self.default_branch_from_git: branch = self.default_branch @@ -131,10 +150,17 @@ def default_infrahub_branch(self) -> str: @property def password_authentication(self) -> bool: + """True if username (and thus password) is configured, indicating password authentication is intended.""" return bool(self.username) class Config(ConfigBase): + """ + Main configuration object for the Infrahub client. + + Inherits from `ConfigBase` and adds settings for recorders, custom requesters, + and logging. + """ recorder: RecorderType = Field(default=RecorderType.NONE, description="Select builtin recorder for later replay.") custom_recorder: Recorder = Field( default_factory=NoRecorder.default, description="Provides a way to record responses from the Infrahub API" @@ -145,6 +171,13 @@ class Config(ConfigBase): @property def logger(self) -> InfrahubLoggers: + """ + Provides the configured logger instance. + + This property allows for type hinting and usage of the logger, + even if a custom logger (like structlog) with a different class + structure is provided. 
+ """ # We expect the log to adhere to the definitions defined by the InfrahubLoggers object # When using structlog the logger doesn't expose the expected methods by looking at the # object to pydantic rejects them. This is a workaround to allow structlog to be used @@ -154,6 +187,7 @@ def logger(self) -> InfrahubLoggers: @model_validator(mode="before") @classmethod def set_custom_recorder(cls, values: dict[str, Any]) -> dict[str, Any]: + """Sets the `custom_recorder` based on the `recorder` type if not already set.""" if values.get("recorder") == RecorderType.NONE and "custom_recorder" not in values: values["custom_recorder"] = NoRecorder() elif values.get("recorder") == RecorderType.JSON and "custom_recorder" not in values: @@ -161,6 +195,15 @@ def set_custom_recorder(cls, values: dict[str, Any]) -> dict[str, Any]: return values def clone(self, branch: str | None = None) -> Config: + """ + Creates a deep copy of the current configuration, optionally overriding the default branch. + + Args: + branch: If provided, sets the `default_branch` in the cloned configuration. + + Returns: + A new `Config` instance with copied settings. + """ config: dict[str, Any] = { "default_branch": branch or self.default_branch, "recorder": self.recorder, diff --git a/infrahub_sdk/constants.py b/infrahub_sdk/constants.py index 04dd6b95..73c73d94 100644 --- a/infrahub_sdk/constants.py +++ b/infrahub_sdk/constants.py @@ -2,6 +2,14 @@ class InfrahubClientMode(str, enum.Enum): + """ + Defines the operational modes for the Infrahub client. + + Attributes: + DEFAULT: Standard operational mode. + TRACKING: Mode where client operations can be tracked as part of a group, + often used for idempotent operations or cleanup. + """ DEFAULT = "default" TRACKING = "tracking" - # IDEMPOTENT = "idempotent" + # IDEMPOTENT = "idempotent" # This mode seems to be commented out. diff --git a/infrahub_sdk/context.py b/infrahub_sdk/context.py index 201a9ef9..eb4f5ee4 100644 --- a/infrahub_sdk/context.py +++ b/infrahub_sdk/context.py @@ -4,6 +4,7 @@ class ContextAccount(BaseModel): + """Represents account information within a request context.""" id: str = Field(..., description="The ID of the account") diff --git a/infrahub_sdk/ctl/check.py b/infrahub_sdk/ctl/check.py index 0626d884..f19cb29f 100644 --- a/infrahub_sdk/ctl/check.py +++ b/infrahub_sdk/ctl/check.py @@ -29,6 +29,14 @@ @dataclass class CheckModule: + """ + Represents a check module loaded from a repository configuration. + + Attributes: + name: The name of the check module (usually derived from the file name). + check_class: The loaded class that implements the InfrahubCheck logic. + definition: The configuration definition of the check from the repository. + """ name: str check_class: type[InfrahubCheck] definition: InfrahubCheckDefinitionConfig @@ -91,6 +99,21 @@ async def run_check( branch: str | None = None, params: dict | None = None, ) -> bool: + """ + Runs a single InfrahubCheck instance. + + Args: + check_module: The CheckModule to run. + client: Initialized InfrahubClient. + format_json: If True, logs will be output in JSON format to stdout. + path: The root directory path for the check. + repository_config: The repository configuration. + branch: Optional branch name to run the check against. + params: Optional parameters to pass to the check's GraphQL query. + + Returns: + True if the check passed, False otherwise. 
+ """ module_name = check_module.name output = "stdout" if format_json else None log = logging.getLogger("infrahub") @@ -137,6 +160,25 @@ async def run_targeted_check( variables: dict[str, str], branch: str | None = None, ) -> bool: + """ + Runs a check that is targeted against specific members of a group. + + If `variables` are provided, the check runs once with those variables. + Otherwise, it discovers members of the target group defined in `check_module.definition` + and runs the check for each member. + + Args: + check_module: The CheckModule to run, which includes target definitions. + client: Initialized InfrahubClient. + format_json: If True, logs will be output in JSON format. + path: Root directory path for the check. + repository_config: The repository configuration. + variables: Specific variables to run the check with, bypassing target discovery. + branch: Optional branch name. + + Returns: + True if all runs of the check passed, False otherwise. + """ filters = {} param_value = list(check_module.definition.parameters.values()) if param_value: @@ -189,6 +231,23 @@ async def run_checks( repository_config: InfrahubRepositoryConfig, branch: str | None = None, ) -> None: + """ + Asynchronously runs a list of check modules. + + It initializes a client and then iterates through `check_modules`, + running either `run_targeted_check` or `run_check` based on whether + the check definition has targets. + + Exits with status code 1 if any check fails. + + Args: + check_modules: A list of CheckModule instances to execute. + format_json: If True, logs output in JSON format. + path: Root directory path for the checks. + variables: Variables to pass to checks (can override targeted check discovery). + repository_config: The repository configuration. + branch: Optional branch name to run checks against. + """ log = logging.getLogger("infrahub") check_summary: list[bool] = [] @@ -227,6 +286,18 @@ async def run_checks( def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[CheckModule]: + """ + Loads check classes from their file paths based on check definitions. + + Args: + check_definitions: A list of InfrahubCheckDefinitionConfig objects. + + Returns: + A list of CheckModule objects, each containing the loaded class and its definition. + + Raises: + typer.Exit: If a module cannot be imported or a class cannot be loaded. + """ modules = [] for check_definition in check_definitions: module_name = check_definition.file_path.stem @@ -245,6 +316,12 @@ def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[ def list_checks(repository_config: InfrahubRepositoryConfig) -> None: + """ + Prints a list of available checks defined in the repository configuration. + + Args: + repository_config: The loaded repository configuration. + """ console.print(f"Python checks defined in repository: {len(repository_config.check_definitions)}") for check in repository_config.check_definitions: diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 605743fa..07265a10 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -198,15 +198,26 @@ async def _run_transform( repository_config: InfrahubRepositoryConfig, ) -> Any: """ - Query GraphQL for the required data then run a transform on that data. + Queries GraphQL for data and then applies a transformation function to it. + + This internal helper function is used by commands like `render` and `transform`. 
+ It fetches data using a specified GraphQL query and variables, then passes + the response to the provided `transform_func`. Args: - query_name: Name of the query to load (e.g. tags_query) - variables: Dictionary of variables used for graphql query - transform_func: The function responsible for transforming data received from graphql - branch: Name of the *infrahub* branch that should be queried for data - debug: Prints debug info to the command line - repository_config: Repository config object. This is used to load the graphql query from the repository. + query_name: The name of the GraphQL query to execute (must be defined in `repository_config`). + variables: A dictionary of variables to pass to the GraphQL query. + transform_func: A callable (sync or async) that takes the GraphQL response data + and returns the transformed data. + branch: The Infrahub branch name to query against. + debug: If True, enables debug logging/output (currently affects GraphQL query execution). + repository_config: The repository configuration containing query definitions. + + Returns: + The result of applying `transform_func` to the GraphQL query response. + + Raises: + typer.Exit: If the specified query is not found or if GraphQL errors occur. """ branch = get_branch(branch) diff --git a/infrahub_sdk/ctl/client.py b/infrahub_sdk/ctl/client.py index 3932b8b1..11ca92d0 100644 --- a/infrahub_sdk/ctl/client.py +++ b/infrahub_sdk/ctl/client.py @@ -14,6 +14,21 @@ def initialize_client( max_concurrent_execution: int | None = None, retry_on_failure: bool | None = None, ) -> InfrahubClient: + """ + Initializes and returns an asynchronous InfrahubClient. + + Uses global CLI configuration settings and allows overriding specific parameters. + + Args: + branch: Optional default branch for the client. + identifier: Optional identifier for tracking client operations. + timeout: Optional request timeout in seconds. + max_concurrent_execution: Optional limit for concurrent operations in batch mode. + retry_on_failure: Optional flag to enable/disable retries on failure. + + Returns: + An initialized InfrahubClient instance. + """ return InfrahubClient( config=_define_config( branch=branch, @@ -32,6 +47,21 @@ def initialize_client_sync( max_concurrent_execution: int | None = None, retry_on_failure: bool | None = None, ) -> InfrahubClientSync: + """ + Initializes and returns a synchronous InfrahubClientSync. + + Uses global CLI configuration settings and allows overriding specific parameters. + + Args: + branch: Optional default branch for the client. + identifier: Optional identifier for tracking client operations. + timeout: Optional request timeout in seconds. + max_concurrent_execution: Optional limit for concurrent operations in batch mode. + retry_on_failure: Optional flag to enable/disable retries on failure. + + Returns: + An initialized InfrahubClientSync instance. + """ return InfrahubClientSync( config=_define_config( branch=branch, @@ -50,6 +80,21 @@ def _define_config( max_concurrent_execution: int | None = None, retry_on_failure: bool | None = None, ) -> Config: + """ + Internal helper to construct a Config object for client initialization. + + Prioritizes explicitly passed arguments, then falls back to global CLI settings. + + Args: + branch: Default branch. + identifier: Tracker identifier. + timeout: Request timeout. + max_concurrent_execution: Max concurrent tasks for batch operations. + retry_on_failure: Whether to retry on failure. + + Returns: + A Config object. 
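+
+    Example:
+        A sketch (mirrors how initialize_client / initialize_client_sync call it):
+
+            client_config = _define_config(branch="main", timeout=30)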
+ """ client_config: dict[str, Any] = { "address": config.SETTINGS.active.server_address, "insert_tracker": True, diff --git a/infrahub_sdk/ctl/config.py b/infrahub_sdk/ctl/config.py index 9d3b6488..2361f46d 100644 --- a/infrahub_sdk/ctl/config.py +++ b/infrahub_sdk/ctl/config.py @@ -25,15 +25,32 @@ class Settings(BaseSettings): @field_validator("server_address") @classmethod def cleanup_server_address(cls, v: str) -> str: + """Removes trailing slashes from the server_address.""" return v.rstrip("/") class ConfiguredSettings: + """ + Manages the loading and access of Infrahub CLI settings. + + This class ensures that settings are loaded (e.g., from a TOML file or environment variables) + before they are accessed, providing a single point of truth for configuration. + """ def __init__(self) -> None: + """Initializes ConfiguredSettings with no settings loaded yet.""" self._settings: Settings | None = None @property def active(self) -> Settings: + """ + Provides the currently active Settings instance. + + Raises: + typer.Abort: If settings have not been loaded before access. + + Returns: + The loaded Settings object. + """ if self._settings: return self._settings @@ -41,10 +58,21 @@ def active(self) -> Settings: raise typer.Abort() def load(self, config_file: str | Path = "infrahubctl.toml", config_data: dict | None = None) -> None: - """Load configuration. + """ + Loads configuration settings. + + The method attempts to load settings from `config_data` if provided. + If not, it tries to load from the specified `config_file`. + If neither is successful or available, it falls back to default Pydantic settings + (which can include environment variables). - Configuration is loaded from a config file in toml format that contains the settings, - or from a dictionary of those settings passed in as "config_data" + Once settings are successfully loaded, subsequent calls to `load` will do nothing. + + Args: + config_file: Path to the TOML configuration file. + Defaults to "infrahubctl.toml". + config_data: A dictionary containing configuration settings. + If provided, this takes precedence over `config_file`. """ if self._settings: diff --git a/infrahub_sdk/ctl/exceptions.py b/infrahub_sdk/ctl/exceptions.py index fc764f3b..8081f7ee 100644 --- a/infrahub_sdk/ctl/exceptions.py +++ b/infrahub_sdk/ctl/exceptions.py @@ -3,6 +3,14 @@ class Error(Exception): class QueryNotFoundError(Error): + """Exception raised when a GraphQL query is not found in the repository.""" def __init__(self, name: str, message: str = ""): + """ + Initializes QueryNotFoundError. + + Args: + name: The name of the query that was not found. + message: Optional custom message. If not provided, a default message is generated. + """ self.message = message or f"The requested query '{name}' was not found." super().__init__(self.message) diff --git a/infrahub_sdk/ctl/exporter.py b/infrahub_sdk/ctl/exporter.py index ae5e5d18..6b85a3da 100644 --- a/infrahub_sdk/ctl/exporter.py +++ b/infrahub_sdk/ctl/exporter.py @@ -12,6 +12,13 @@ def directory_name_with_timestamp() -> str: + """ + Generates a directory name string prefixed with "infrahubexport-" + and appended with the current timestamp in YYYYMMDD-HHMMSS format. + + Returns: + A string suitable for a directory name. 
+ """ right_now = datetime.now(timezone.utc).astimezone() timestamp = right_now.strftime("%Y%m%d-%H%M%S") return f"infrahubexport-{timestamp}" diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index 49019196..92abaad5 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -25,6 +25,28 @@ async def run( branch: str | None = None, variables: Optional[list[str]] = None, ) -> None: + """ + Runs a specified generator script. + + This function initializes logging, loads repository and generator configurations, + and then executes the generator. If `list_available` is True or `generator_name` + is not provided, it lists available generators instead. + + The generator can be run with specific variables or against targets defined + in its configuration (members of a CoreGroup). + + Args: + generator_name: The name of the generator to run. + path: The root directory path (currently unused, marked with noqa: ARG001). + debug: If True, enables debug logging. + list_available: If True, lists available generators and exits. + branch: Optional branch name to run the generator against. + variables: Optional list of "key=value" strings to pass as variables + to the generator's query. If provided, target discovery is skipped. + + Raises: + typer.Exit: If the generator class cannot be loaded or other critical errors occur. + """ init_logging(debug=debug) repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) @@ -108,6 +130,12 @@ async def run( def list_generators(repository_config: InfrahubRepositoryConfig) -> None: + """ + Prints a list of available generators defined in the repository configuration. + + Args: + repository_config: The loaded repository configuration. + """ console = Console() console.print(f"Generators defined in repository: {len(repository_config.generator_definitions)}") diff --git a/infrahub_sdk/ctl/importer.py b/infrahub_sdk/ctl/importer.py index e9181c8b..8e0fd4cb 100644 --- a/infrahub_sdk/ctl/importer.py +++ b/infrahub_sdk/ctl/importer.py @@ -15,6 +15,16 @@ def local_directory() -> Path: + """ + Returns the current working directory as a resolved Path object. + + This is used as a default for Typer options requiring a directory path. + The comment about documentation generation suggests it might be to ensure + Path().resolve() is called at runtime rather than module load time. + + Returns: + A Path object representing the current absolute directory. + """ # We use a function here to avoid failure when generating the documentation due to directory name return Path().resolve() diff --git a/infrahub_sdk/ctl/render.py b/infrahub_sdk/ctl/render.py index cb1c962e..85bd5c5d 100644 --- a/infrahub_sdk/ctl/render.py +++ b/infrahub_sdk/ctl/render.py @@ -10,6 +10,12 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None: + """ + Prints a list of available Jinja2 transforms defined in the repository configuration. + + Args: + config: The loaded repository configuration. + """ console = Console() console.print(f"Jinja2 transforms defined in repository: {len(config.jinja2_transforms)}") @@ -18,6 +24,16 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None: def print_template_errors(error: JinjaTemplateError, console: Console) -> None: + """ + Prints formatted error messages to the console for JinjaTemplateError exceptions. + + It handles specific subtypes of JinjaTemplateError (NotFoundError, UndefinedError, SyntaxError) + to provide more detailed information. 
+ + Args: + error: The JinjaTemplateError instance. + console: The Rich Console object for printing. + """ if isinstance(error, JinjaTemplateNotFoundError): console.print("[red]An error occurred while rendering the jinja template") console.print("") diff --git a/infrahub_sdk/ctl/repository.py b/infrahub_sdk/ctl/repository.py index 98e394bf..94d537ac 100644 --- a/infrahub_sdk/ctl/repository.py +++ b/infrahub_sdk/ctl/repository.py @@ -23,6 +23,18 @@ def get_repository_config(repo_config_file: Path) -> InfrahubRepositoryConfig: + """ + Loads, validates, and returns the repository configuration from a YAML file. + + Args: + repo_config_file: Path to the repository configuration file (e.g., .infrahub.yml). + + Returns: + An InfrahubRepositoryConfig object. + + Raises: + typer.Exit: If the file is not found, not valid YAML, or fails Pydantic validation. + """ try: config_file_data = load_repository_config_file(repo_config_file) except FileNotFoundError as exc: @@ -45,6 +57,19 @@ def get_repository_config(repo_config_file: Path) -> InfrahubRepositoryConfig: def load_repository_config_file(repo_config_file: Path) -> dict: + """ + Reads a YAML file from the given path and loads it into a Python dictionary. + + Args: + repo_config_file: Path to the YAML configuration file. + + Returns: + A dictionary representing the content of the YAML file. + + Raises: + FileNotFoundError: If the `repo_config_file` does not exist (raised by `read_file`). + FileNotValidError: If the file content is not valid YAML. + """ yaml_data = read_file(file_path=repo_config_file) try: diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py index 6e9ff994..40d4211f 100644 --- a/infrahub_sdk/ctl/schema.py +++ b/infrahub_sdk/ctl/schema.py @@ -33,6 +33,15 @@ def callback() -> None: def validate_schema_content_and_exit(client: InfrahubClient, schemas: list[SchemaFile]) -> None: + """ + Validates the content of schema files using the client's schema validation. + + If any schema is invalid, prints error details to the console and exits the program. + + Args: + client: An initialized InfrahubClient. + schemas: A list of SchemaFile objects whose content will be validated. + """ has_error: bool = False for schema_file in schemas: try: @@ -49,6 +58,18 @@ def validate_schema_content_and_exit(client: InfrahubClient, schemas: list[Schem def display_schema_load_errors(response: dict[str, Any], schemas_data: list[dict]) -> None: + """ + Displays detailed error messages when schema loading fails. + + Parses the error response from the Infrahub API and attempts to pinpoint + the location of errors within the provided schema data, printing them + in a user-friendly format. + + Args: + response: The error response dictionary from the Infrahub API. + schemas_data: A list of dictionaries, where each dictionary is the parsed + content of a schema file (used to find node names for errors). + """ console.print("[red]Unable to load the schema:") if "detail" not in response: handle_non_detail_errors(response=response) @@ -84,21 +105,56 @@ def display_schema_load_errors(response: dict[str, Any], schemas_data: list[dict def handle_non_detail_errors(response: dict[str, Any]) -> None: + """ + Handles and prints generic error messages from an API response + when a detailed error structure (like `response["detail"]`) is not available. + + Args: + response: The error response dictionary from the API. 
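+
+    Example:
+        A sketch with a hand-built payload:
+
+            handle_non_detail_errors(response={"errors": [{"message": "kind not found"}]})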
+ """ if "error" in response: console.print(f" {response.get('error')}") elif "errors" in response: - for error in response.get("errors"): - console.print(f" {error.get('message')}") + for error_item in response.get("errors", []): # Ensure errors is treated as a list + if isinstance(error_item, dict): + console.print(f" {error_item.get('message')}") + else: + console.print(f" {error_item}") # Handle cases where error is just a string else: console.print(f" '{response}'") def valid_error_path(loc_path: list[Any]) -> bool: + """ + Checks if an error location path from Pydantic validation is valid for schema errors. + + A valid path typically looks like: `['body', 'schemas', , 'nodes', , ]`. + + Args: + loc_path: The location path list from a Pydantic validation error. + + Returns: + True if the path structure is recognized for schema errors, False otherwise. + """ return len(loc_path) >= 6 and loc_path[0] == "body" and loc_path[1] == "schemas" -def get_node(schemas_data: list[dict], schema_index: int, node_index: int) -> dict | None: - if schema_index < len(schemas_data) and node_index < len(schemas_data[schema_index].content["nodes"]): +def get_node(schemas_data: list[SchemaFile], schema_index: int, node_index: int) -> dict | None: # Corrected type hint for schemas_data + """ + Retrieves a specific node definition from a list of parsed schema file contents. + + Args: + schemas_data: A list of SchemaFile objects, where each object's `content` + attribute holds the parsed schema data (e.g., from YAML). + schema_index: The index of the schema file in `schemas_data`. + node_index: The index of the node within the specified schema file's "nodes" list. + + Returns: + A dictionary representing the node definition if found, otherwise None. + """ + if schema_index < len(schemas_data) and schemas_data[schema_index].content and \ + "nodes" in schemas_data[schema_index].content and \ + node_index < len(schemas_data[schema_index].content["nodes"]): return schemas_data[schema_index].content["nodes"][node_index] return None diff --git a/infrahub_sdk/ctl/transform.py b/infrahub_sdk/ctl/transform.py index 1cda0940..7cbb2677 100644 --- a/infrahub_sdk/ctl/transform.py +++ b/infrahub_sdk/ctl/transform.py @@ -4,6 +4,12 @@ def list_transforms(config: InfrahubRepositoryConfig) -> None: + """ + Prints a list of available Python transforms defined in the repository configuration. + + Args: + config: The loaded repository configuration. + """ console = Console() console.print(f"Python transforms defined in repository: {len(config.python_transforms)}") diff --git a/infrahub_sdk/ctl/utils.py b/infrahub_sdk/ctl/utils.py index 63d7dfb8..3e696e5b 100644 --- a/infrahub_sdk/ctl/utils.py +++ b/infrahub_sdk/ctl/utils.py @@ -40,6 +40,15 @@ def init_logging(debug: bool = False) -> None: + """ + Initializes basic logging for CLI operations. + + Sets log levels for Infrahub SDK and HTTPX/HTTPCore libraries to minimize noise. + Configures a RichHandler for console output. + + Args: + debug: If True, sets the root logger level to DEBUG, otherwise INFO. + """ logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("httpcore").setLevel(logging.ERROR) @@ -113,6 +122,19 @@ def execute_graphql_query( branch: str | None = None, debug: bool = False, ) -> dict: + """ + Executes a GraphQL query using the synchronous Infrahub client. + + Args: + query: The name of the query (as defined in `repository_config`) or the query string itself. 
+ variables_dict: A dictionary of variables for the GraphQL query. + repository_config: The repository configuration containing query definitions. + branch: Optional branch name to execute the query against. + debug: If True, prints the GraphQL response to the console. + + Returns: + A dictionary containing the GraphQL query response. + """ console = Console() query_object = repository_config.get_query(name=query) query_str = query_object.load_query() @@ -135,8 +157,16 @@ def execute_graphql_query( def print_graphql_errors(console: Console, errors: list) -> None: + """ + Prints GraphQL errors to the console with rich formatting. + + Args: + console: The Rich Console object for printing. + errors: A list of error objects, typically from a GraphQLError exception. + """ if not isinstance(errors, list): console.print(f"[red]{escape(str(errors))}") + return # Ensure function exits if errors is not a list for error in errors: if isinstance(error, dict) and "message" in error and "path" in error: @@ -146,6 +176,17 @@ def print_graphql_errors(console: Console, errors: list) -> None: def parse_cli_vars(variables: Optional[list[str]]) -> dict[str, str]: + """ + Parses a list of "key=value" strings into a dictionary. + + Args: + variables: An optional list of strings, where each string is expected + to be in "key=value" format. + + Returns: + A dictionary of parsed key-value pairs. Returns an empty dictionary + if `variables` is None or empty. + """ if not variables: return {} @@ -153,6 +194,19 @@ def parse_cli_vars(variables: Optional[list[str]]) -> dict[str, str]: def find_graphql_query(name: str, directory: str | Path = ".") -> str: + """ + Searches for a GraphQL query file (.gql) by its stem name within a directory. + + Args: + name: The stem name of the query file (without the .gql extension). + directory: The directory to search in. Defaults to the current directory. + + Returns: + The content of the found query file as a string. + + Raises: + QueryNotFoundError: If no .gql file with the given stem name is found. + """ if isinstance(directory, str): directory = Path(directory) @@ -165,6 +219,15 @@ def find_graphql_query(name: str, directory: str | Path = ".") -> str: def render_action_rich(value: str) -> str: + """ + Formats an action string (created, updated, deleted) with Rich markup for colored output. + + Args: + value: The action string. + + Returns: + A Rich-formatted string with color based on the action. + """ if value == "created": return f"[green]{value.upper()}[/green]" if value == "updated": @@ -184,6 +247,21 @@ def get_fixtures_dir() -> Path: def load_yamlfile_from_disk_and_exit( paths: list[Path], file_type: type[YamlFileVar], console: Console ) -> list[YamlFileVar]: + """ + Loads YAML files of a specific type from disk and exits on validation errors. + + Args: + paths: A list of Path objects pointing to the YAML files. + file_type: The specific YamlFile subclass to use for loading and validation + (e.g., SchemaFile, ObjectFile). + console: The Rich Console object for printing error messages. + + Returns: + A sorted list of loaded and validated YamlFileVar objects. + + Raises: + typer.Exit: If any file is not found, invalid YAML, or fails content validation. + """ has_error = False try: data_files = file_type.load_from_disk(paths=paths) @@ -204,6 +282,15 @@ def load_yamlfile_from_disk_and_exit( def display_object_validate_format_success(file: ObjectFile, console: Console) -> None: + """ + Prints a success message to the console for a validated object file. 
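+
+    The message is rendered in green using Rich markup.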
+ + Distinguishes between single-document and multi-document YAML files in the message. + + Args: + file: The validated ObjectFile. + console: The Rich Console object for printing. + """ if file.multiple_documents: console.print(f"[green] File '{file.location}' [{file.document_position}] is Valid!") else: @@ -211,6 +298,17 @@ def display_object_validate_format_success(file: ObjectFile, console: Console) - def display_object_validate_format_error(file: ObjectFile, error: ValidationError, console: Console) -> None: + """ + Prints detailed error messages to the console for an object file that failed validation. + + Distinguishes between single-document and multi-document YAML files and lists + all specific validation error messages. + + Args: + file: The ObjectFile that failed validation. + error: The Pydantic ValidationError object. + console: The Rich Console object for printing. + """ if file.multiple_documents: console.print(f"[red] File '{file.location}' [{file.document_position}] is not valid!") else: