Skip to content

Fix infrahub/#8 #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 17, 2024
1 change: 1 addition & 0 deletions changelog/8.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make `infrahubctl transform` command set up the InfrahubTransform class with an InfrahubClient instance
12 changes: 6 additions & 6 deletions infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def list_branch(_: str = CONFIG_PARAM) -> None:

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

client = await initialize_client()
client = initialize_client()
branches = await client.branch.all()

table = Table(title="List of all branches")
Expand Down Expand Up @@ -91,7 +91,7 @@ async def create(

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

client = await initialize_client()
client = initialize_client()
branch = await client.branch.create(branch_name=branch_name, description=description, sync_with_git=sync_with_git)
console.print(f"Branch {branch_name!r} created successfully ({branch.id}).")

Expand All @@ -103,7 +103,7 @@ async def delete(branch_name: str, _: str = CONFIG_PARAM) -> None:

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

client = await initialize_client()
client = initialize_client()
await client.branch.delete(branch_name=branch_name)
console.print(f"Branch '{branch_name}' deleted successfully.")

Expand All @@ -115,7 +115,7 @@ async def rebase(branch_name: str, _: str = CONFIG_PARAM) -> None:

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

client = await initialize_client()
client = initialize_client()
await client.branch.rebase(branch_name=branch_name)
console.print(f"Branch '{branch_name}' rebased successfully.")

Expand All @@ -127,7 +127,7 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:

logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)

client = await initialize_client()
client = initialize_client()
await client.branch.merge(branch_name=branch_name)
console.print(f"Branch '{branch_name}' merged successfully.")

Expand All @@ -137,6 +137,6 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:
async def validate(branch_name: str, _: str = CONFIG_PARAM) -> None:
"""Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET)."""

client = await initialize_client()
client = initialize_client()
await client.branch.validate(branch_name=branch_name)
console.print(f"Branch '{branch_name}' is valid.")
85 changes: 61 additions & 24 deletions infrahub_sdk/ctl/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from infrahub_sdk.ctl.validate import app as validate_app
from infrahub_sdk.exceptions import GraphQLError, InfrahubTransformNotFoundError
from infrahub_sdk.jinja2 import identify_faulty_jinja_code
from infrahub_sdk.schema import AttributeSchema, GenericSchema, InfrahubRepositoryConfig, NodeSchema, RelationshipSchema
from infrahub_sdk.schema import (
AttributeSchema,
GenericSchema,
InfrahubRepositoryConfig,
NodeSchema,
RelationshipSchema,
)
from infrahub_sdk.transforms import get_transform_class_instance
from infrahub_sdk.utils import get_branch, write_to_file

Expand Down Expand Up @@ -149,7 +155,7 @@ async def run(
if not hasattr(module, method):
raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")

client = await initialize_client(
client = initialize_client(
branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
)
func = getattr(module, method)
Expand Down Expand Up @@ -187,19 +193,35 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:


def _run_transform(
query: str,
query_name: str,
variables: dict[str, Any],
transformer: Callable,
transform_func: Callable,
branch: str,
debug: bool,
repository_config: InfrahubRepositoryConfig,
):
"""
Query GraphQL for the required data then run a transform on that data.

Args:
query_name: Name of the query to load (e.g. tags_query)
variables: Dictionary of variables used for graphql query
transformer_func: The function responsible for transforming data received from graphql
branch: Name of the *infrahub* branch that should be queried for data
debug: Prints debug info to the command line
repository_config: Repository config object. This is used to load the graphql query from the repository.
"""
branch = get_branch(branch)

try:
response = execute_graphql_query(
query=query, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
)

# TODO: response is a dict and can't be printed to the console in this way.
# if debug:
# message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
# console.print("\n".join(message))
except QueryNotFoundError as exc:
console.print(f"[red]Unable to find query : {exc}")
raise typer.Exit(1) from exc
Expand All @@ -214,10 +236,10 @@ def _run_transform(
console.print("[yellow] you can specify a different branch with --branch")
raise typer.Abort()

if asyncio.iscoroutinefunction(transformer.func):
output = asyncio.run(transformer(response))
if asyncio.iscoroutinefunction(transform_func):
output = asyncio.run(transform_func(response))
else:
output = transformer(response)
output = transform_func(response)
return output


Expand All @@ -243,23 +265,28 @@ def render(
list_jinja2_transforms(config=repository_config)
return

# Load transform config
try:
transform_config = repository_config.get_jinja2_transform(name=transform_name)
except KeyError as exc:
console.print(f'[red]Unable to find "{transform_name}" in {config.INFRAHUB_REPO_CONFIG_FILE}')
list_jinja2_transforms(config=repository_config)
raise typer.Exit(1) from exc

transformer = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
# Construct transform function used to transform data returned from the API
transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)

# Query GQL and run the transform
result = _run_transform(
query=transform_config.query,
query_name=transform_config.query,
variables=variables_dict,
transformer=transformer,
transform_func=transform_func,
branch=branch,
debug=debug,
repository_config=repository_config,
)

# Output data
if out:
write_to_file(Path(out), result)
else:
Expand Down Expand Up @@ -288,31 +315,41 @@ def transform(
list_transforms(config=repository_config)
return

matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable

if not matched:
# Load transform config
try:
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
if not matched:
raise ValueError(f"{transform_name} does not exist")
except ValueError as exc:
console.print(f"[red]Unable to find requested transform: {transform_name}")
list_transforms(config=repository_config)
return
raise typer.Exit(1) from exc

transform_config = matched[0]

# Get client
client = initialize_client()

# Get python transform class instance
try:
transform_instance = get_transform_class_instance(transform_config=transform_config)
transform = get_transform_class_instance(
transform_config=transform_config,
branch=branch,
client=client,
)
except InfrahubTransformNotFoundError as exc:
console.print(f"Unable to load {transform_name} from python_transforms")
raise typer.Exit(1) from exc

transformer = functools.partial(transform_instance.transform)
result = _run_transform(
query=transform_instance.query,
variables=variables_dict,
transformer=transformer,
branch=branch,
debug=debug,
repository_config=repository_config,
# Get data
query_str = repository_config.get_query(name=transform.query).load_query()
data = asyncio.run(
transform.client.execute_graphql(query=query_str, variables=variables_dict, branch_name=transform.branch_name)
)

# Run Transform
result = asyncio.run(transform.run(data=data))

json_string = ujson.dumps(result, indent=2, sort_keys=True)
if out:
write_to_file(Path(out), json_string)
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from infrahub_sdk.ctl import config


async def initialize_client(
def initialize_client(
branch: Optional[str] = None,
identifier: Optional[str] = None,
timeout: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def run(
if param_key:
identifier = param_key[0]

client = await initialize_client()
client = initialize_client()
if variables_dict:
data = execute_graphql_query(
query=generator_config.query,
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def add(
},
}

client = await initialize_client()
client = initialize_client()

if username:
credential = await client.create(kind="CorePasswordCredential", name=name, username=username, password=password)
Expand Down
4 changes: 2 additions & 2 deletions infrahub_sdk/ctl/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def load(

schemas_data = load_schemas_from_disk_and_exit(schemas=schemas)
schema_definition = "schema" if len(schemas_data) == 1 else "schemas"
client = await initialize_client()
client = initialize_client()
validate_schema_content_and_exit(client=client, schemas=schemas_data)

start_time = time.time()
Expand Down Expand Up @@ -204,7 +204,7 @@ async def check(
init_logging(debug=debug)

schemas_data = load_schemas_from_disk_and_exit(schemas=schemas)
client = await initialize_client()
client = initialize_client()
validate_schema_content_and_exit(client=client, schemas=schemas_data)

success, response = await client.schema.check(schemas=[item.content for item in schemas_data], branch=branch)
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def validate_schema(schema: Path, _: str = CONFIG_PARAM) -> None:
console.print("[red]Invalid JSON file")
raise typer.Exit(1) from exc

client = await initialize_client()
client = initialize_client()

try:
client.schema.validate(schema_data)
Expand Down
60 changes: 48 additions & 12 deletions infrahub_sdk/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import importlib
import os
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional

Expand All @@ -25,33 +26,47 @@ class InfrahubTransform:
query: str
timeout: int = 10

def __init__(self, branch: str = "", root_directory: str = "", server_url: str = ""):
def __init__(
self,
branch: str = "",
root_directory: str = "",
server_url: str = "",
client: Optional[InfrahubClient] = None,
):
self.git: Repo

self.branch = branch

self.server_url = server_url or os.environ.get("INFRAHUB_URL", "http://127.0.0.1:8000")
self.root_directory = root_directory or os.getcwd()

self.client: InfrahubClient
self._client = client

if not self.name:
self.name = self.__class__.__name__

if not self.query:
raise ValueError("A query must be provided")

@property
def client(self) -> InfrahubClient:
if not self._client:
self._client = InfrahubClient(address=self.server_url)

return self._client

@classmethod
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
warnings.warn(
f"{cls.__class__.__name__}.init has been deprecated and will be removed in Infrahub SDK 0.15.0 or the next major version",
DeprecationWarning,
stacklevel=1,
)
if client:
kwargs["client"] = client

item = cls(*args, **kwargs)

if client:
item.client = client
else:
item.client = InfrahubClient(address=item.server_url)

return item

@property
Expand All @@ -61,7 +76,7 @@ def branch_name(self) -> str:
if self.branch:
return self.branch

if not self.git:
if not hasattr(self, "git") or not self.git:
self.git = Repo(self.root_directory)

self.branch = str(self.git.active_branch)
Expand All @@ -79,10 +94,18 @@ async def collect_data(self) -> dict:

async def run(self, data: Optional[dict] = None) -> Any:
"""Execute the transformation after collecting the data from the GraphQL query.
The result of the check is determined based on the presence or not of ERROR log messages."""

The result of the check is determined based on the presence or not of ERROR log messages.

Args:
data: The data on which to run the transform. Data will be queried from the API if not provided

Returns: Transformed data
"""

if not data:
data = await self.collect_data()

unpacked = data.get("data") or data

if asyncio.iscoroutinefunction(self.transform):
Expand All @@ -92,8 +115,20 @@ async def run(self, data: Optional[dict] = None) -> Any:


def get_transform_class_instance(
transform_config: InfrahubPythonTransformConfig, search_path: Optional[Path] = None
transform_config: InfrahubPythonTransformConfig,
search_path: Optional[Path] = None,
branch: str = "",
client: Optional[InfrahubClient] = None,
) -> InfrahubTransform:
"""Gets an instance of the InfrahubTransform class.

Args:
transform_config: A config object with information required to find and load the transform.
search_path: The path in which to search for a python file containing the transform. The current directory is
assumed if not speicifed.
branch: Infrahub branch which will be targeted in graphql query used to acquire data for transformation.
client: InfrahubClient used to interact with infrahub API.
"""
if transform_config.file_path.is_absolute() or search_path is None:
search_location = transform_config.file_path
else:
Expand All @@ -108,7 +143,8 @@ def get_transform_class_instance(
transform_class = getattr(module, transform_config.class_name)

# Create an instance of the class
transform_instance = transform_class()
transform_instance = transform_class(branch=branch, client=client)

except (FileNotFoundError, AttributeError) as exc:
raise InfrahubTransformNotFoundError(name=transform_config.name) from exc

Expand Down