
Add docstring to all functions & methods #437


Draft · wants to merge 1 commit into base: stable
6 changes: 6 additions & 0 deletions infrahub_sdk/_importer.py
@@ -20,6 +20,12 @@ def import_module(module_path: Path, import_root: str | None = None, relative_pa
module_path (Path): Absolute path of the module to import.
import_root (Optional[str]): Absolute string path to the current repository.
relative_path (Optional[str]): Relative string path between module_path and import_root.

Returns:
ModuleType: The imported module.

Raises:
ModuleImportError: If the module cannot be imported due to ModuleNotFoundError or SyntaxError.
"""
import_root = import_root or str(module_path.parent)

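For context on the new Returns/Raises sections above, a minimal call-site sketch. The paths are illustrative, and the `ModuleImportError` import location is an assumption, not something shown in this diff.

from pathlib import Path

from infrahub_sdk._importer import import_module
from infrahub_sdk.exceptions import ModuleImportError  # assumed import location

try:
    module = import_module(
        module_path=Path("/repo/checks/my_check.py"),  # illustrative path
        import_root="/repo",
        relative_path="checks",
    )
except ModuleImportError as exc:
    print(f"module import failed: {exc}")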
50 changes: 48 additions & 2 deletions infrahub_sdk/analyzer.py
@@ -18,26 +18,40 @@


class GraphQLQueryVariable(BaseModel):
"""Represents a variable in a GraphQL query."""
name: str
type: str
required: bool = False
default_value: Any | None = None


class GraphQLOperation(BaseModel):
"""Represents a single operation within a GraphQL query."""
name: str | None = None
operation_type: OperationType


class GraphQLQueryAnalyzer:
"""Analyzes GraphQL queries to extract information about operations, variables, and structure."""
def __init__(self, query: str, schema: GraphQLSchema | None = None):
"""Initializes the GraphQLQueryAnalyzer.

Args:
query: The GraphQL query string.
schema: The GraphQL schema.
"""
self.query: str = query
self.schema: GraphQLSchema | None = schema
self.document: DocumentNode = parse(self.query)
self._fields: dict | None = None

@property
def is_valid(self) -> tuple[bool, list[GraphQLError] | None]:
"""Validates the query against the schema if provided.

Returns:
A tuple containing a boolean indicating validity and a list of errors if any.
"""
if self.schema is None:
return False, [GraphQLError("Schema is not provided")]

@@ -49,10 +63,16 @@ def is_valid(self) -> tuple[bool, list[GraphQLError] | None]:

@property
def nbr_queries(self) -> int:
"""Returns the number of definitions in the query document."""
return len(self.document.definitions)

@property
def operations(self) -> list[GraphQLOperation]:
"""Extracts all operations (queries, mutations, subscriptions) from the query.

Returns:
A list of GraphQLOperation objects.
"""
operations = []
for definition in self.document.definitions:
if not isinstance(definition, OperationDefinitionNode):
@@ -66,10 +86,20 @@ def operations(self) -> list[GraphQLOperation]:

@property
def contains_mutation(self) -> bool:
"""Checks if the query contains any mutation operations.

Returns:
True if a mutation is present, False otherwise.
"""
return any(op.operation_type == OperationType.MUTATION for op in self.operations)

@property
def variables(self) -> list[GraphQLQueryVariable]:
"""Extracts all variables defined in the query.

Returns:
A list of GraphQLQueryVariable objects.
"""
response = []
for definition in self.document.definitions:
variable_definitions = getattr(definition, "variable_definitions", None)
@@ -99,16 +129,32 @@ def variables(self) -> list[GraphQLQueryVariable]:
return response

async def calculate_depth(self) -> int:
"""Number of nested levels in the query"""
"""Calculates the maximum depth of nesting in the query's selection sets.

Returns:
The maximum depth of the query.
"""
fields = await self.get_fields()
return calculate_dict_depth(data=fields)

async def calculate_height(self) -> int:
"""Total number of fields requested in the query"""
"""Calculates the total number of fields requested across all operations in the query.

Returns:
The total height (number of fields) of the query.
"""
fields = await self.get_fields()
return calculate_dict_height(data=fields)

async def get_fields(self) -> dict[str, Any]:
"""Extracts all fields requested in the query.

This method parses the document definitions and extracts fields from
OperationDefinitionNode instances.

Returns:
A dictionary representing the fields structure.
"""
if not self._fields:
fields = {}
for definition in self.document.definitions:
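A quick sketch exercising the documented surface with a made-up query and no schema (per the `is_valid` docstring, validation requires one):

import asyncio

from infrahub_sdk.analyzer import GraphQLQueryAnalyzer

QUERY = """
query GetDevice($name: String!) {
    device(name: $name) {
        id
        interfaces { name }
    }
}
"""

analyzer = GraphQLQueryAnalyzer(query=QUERY)
print(analyzer.nbr_queries)                     # 1: a single definition in the document
print(analyzer.contains_mutation)               # False: only a query operation
print([v.name for v in analyzer.variables])     # ["name"]
print(asyncio.run(analyzer.calculate_depth()))  # nesting depth of the selection sets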
37 changes: 37 additions & 0 deletions infrahub_sdk/async_typer.py
@@ -9,8 +9,25 @@


class AsyncTyper(Typer):
"""
A Typer subclass that allows running async functions.

It overrides the `callback` and `command` decorators to wrap async functions
in `asyncio.run`.
"""

@staticmethod
def maybe_run_async(decorator: Callable, func: Callable) -> Any:
"""
Wraps an async function in `asyncio.run` if it's a coroutine function.

Args:
decorator: The decorator to apply (e.g., from `super().command`).
func: The function to potentially wrap.

Returns:
The decorated function, wrapped with `asyncio.run` if it is a coroutine function.
"""
if inspect.iscoroutinefunction(func):

@wraps(func)
@@ -23,9 +40,29 @@ def runner(*args: Any, **kwargs: Any) -> Any:
return func

def callback(self, *args: Any, **kwargs: Any) -> Any:
"""
Overrides the Typer.callback decorator to support async functions.

Args:
*args: Positional arguments for Typer.callback.
**kwargs: Keyword arguments for Typer.callback.

Returns:
A decorator that can handle both sync and async callback functions.
"""
decorator = super().callback(*args, **kwargs)
return partial(self.maybe_run_async, decorator)

def command(self, *args: Any, **kwargs: Any) -> Any:
"""
Overrides the Typer.command decorator to support async functions.

Args:
*args: Positional arguments for Typer.command.
**kwargs: Keyword arguments for Typer.command.

Returns:
A decorator that can handle both sync and async command functions.
"""
decorator = super().command(*args, **kwargs)
return partial(self.maybe_run_async, decorator)
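A minimal sketch of the pattern these docstrings describe; the command name and body are invented for illustration:

from infrahub_sdk.async_typer import AsyncTyper

app = AsyncTyper()

@app.command()
async def fetch(name: str) -> None:
    # maybe_run_async detects the coroutine function and wraps the call
    # in asyncio.run, so Typer invokes it like a regular sync command.
    print(f"fetching {name}")

if __name__ == "__main__":
    app()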
77 changes: 75 additions & 2 deletions infrahub_sdk/batch.py
@@ -12,6 +12,7 @@

@dataclass
class BatchTask:
"""Represents a single asynchronous task in a batch."""
task: Callable[[Any], Awaitable[Any]]
args: tuple[Any, ...]
kwargs: dict[str, Any]
@@ -20,13 +21,24 @@ class BatchTask:

@dataclass
class BatchTaskSync:
"""Represents a single synchronous task in a batch."""
task: Callable[..., Any]
args: tuple[Any, ...]
kwargs: dict[str, Any]
node: InfrahubNodeSync | None = None

def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | None, Any]:
"""Executes the stored task."""
"""Executes the stored synchronous task.

Args:
return_exceptions: If True, exceptions are returned instead of raised.

Returns:
A tuple containing the task's node (if any) and the result or exception.

Raises:
Exception: If `return_exceptions` is False and the task raises an exception.
"""
result = None
try:
result = self.task(*self.args, **self.kwargs)
@@ -41,6 +53,16 @@ def execute(self, return_exceptions: bool = False) -> tuple[InfrahubNodeSync | N
async def execute_batch_task_in_pool(
task: BatchTask, semaphore: asyncio.Semaphore, return_exceptions: bool = False
) -> tuple[InfrahubNode | None, Any]:
"""Executes a BatchTask within a semaphore-controlled pool.

Args:
task: The BatchTask to execute.
semaphore: An asyncio.Semaphore to limit concurrent executions.
return_exceptions: If True, exceptions are returned instead of raised.

Returns:
A tuple containing the task's node (if any) and the result or exception.
"""
async with semaphore:
try:
result = await task.task(*task.args, **task.kwargs)
@@ -53,24 +75,51 @@


class InfrahubBatch:
"""Manages and executes a batch of asynchronous tasks concurrently."""
def __init__(
self,
semaphore: asyncio.Semaphore | None = None,
max_concurrent_execution: int = 5,
return_exceptions: bool = False,
):
"""Initializes the InfrahubBatch.

Args:
semaphore: An asyncio.Semaphore to limit concurrent executions.
If None, a new one is created with `max_concurrent_execution`.
max_concurrent_execution: The maximum number of tasks to run concurrently.
Only used if `semaphore` is None.
return_exceptions: If True, exceptions from tasks are returned instead of raised.
"""
self._tasks: list[BatchTask] = []
self.semaphore = semaphore or asyncio.Semaphore(value=max_concurrent_execution)
self.return_exceptions = return_exceptions

@property
def num_tasks(self) -> int:
"""Returns the number of tasks currently in the batch."""
return len(self._tasks)

def add(self, *args: Any, task: Callable, node: Any | None = None, **kwargs: Any) -> None:
"""Adds a new task to the batch.

Args:
task: The callable to be executed.
node: An optional node associated with this task.
*args: Positional arguments to pass to the task.
**kwargs: Keyword arguments to pass to the task.
"""
self._tasks.append(BatchTask(task=task, node=node, args=args, kwargs=kwargs))

async def execute(self) -> AsyncGenerator:
async def execute(self) -> AsyncGenerator[tuple[InfrahubNode | None, Any], None]:
"""Executes all tasks in the batch concurrently.

Yields:
A tuple containing the task's node (if any) and the result or exception.

Raises:
Exception: If `return_exceptions` is False and a task raises an exception.
"""
tasks = []

for batch_task in self._tasks:
@@ -90,19 +139,43 @@ async def execute(self) -> AsyncGenerator:


class InfrahubBatchSync:
"""Manages and executes a batch of synchronous tasks concurrently using a thread pool."""
def __init__(self, max_concurrent_execution: int = 5, return_exceptions: bool = False):
"""Initializes the InfrahubBatchSync.

Args:
max_concurrent_execution: The maximum number of tasks to run concurrently in the thread pool.
return_exceptions: If True, exceptions from tasks are returned instead of raised.
"""
self._tasks: list[BatchTaskSync] = []
self.max_concurrent_execution = max_concurrent_execution
self.return_exceptions = return_exceptions

@property
def num_tasks(self) -> int:
"""Returns the number of tasks currently in the batch."""
return len(self._tasks)

def add(self, *args: Any, task: Callable[..., Any], node: Any | None = None, **kwargs: Any) -> None:
"""Adds a new synchronous task to the batch.

Args:
task: The callable to be executed.
node: An optional node associated with this task.
*args: Positional arguments to pass to the task.
**kwargs: Keyword arguments to pass to the task.
"""
self._tasks.append(BatchTaskSync(task=task, node=node, args=args, kwargs=kwargs))

def execute(self) -> Generator[tuple[InfrahubNodeSync | None, Any], None, None]:
"""Executes all tasks in the batch concurrently using a ThreadPoolExecutor.

Yields:
A tuple containing the task's node (if any) and the result or exception.

Raises:
Exception: If `return_exceptions` is False and a task raises an exception.
"""
with ThreadPoolExecutor(max_workers=self.max_concurrent_execution) as executor:
futures = [executor.submit(task.execute, return_exceptions=self.return_exceptions) for task in self._tasks]
for future in futures:
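To sanity-check the new `InfrahubBatch` docstrings, a small end-to-end sketch; the task coroutine and values are invented for illustration:

import asyncio

from infrahub_sdk.batch import InfrahubBatch

async def double(value: int) -> int:
    # Stand-in task; any coroutine function works here.
    await asyncio.sleep(0)
    return value * 2

async def main() -> None:
    batch = InfrahubBatch(max_concurrent_execution=2, return_exceptions=True)
    for i in range(5):
        batch.add(i, task=double)  # positional args are forwarded to the task
    async for node, result in batch.execute():
        print(node, result)  # node is None because none was attached via add()

asyncio.run(main())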