Skip to content

Fix infrahub/#8 #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 17, 2024
1 change: 1 addition & 0 deletions changelog/4143.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use InfrahubClient to communicate with infrahub API for 'infrahubctl render' and 'infrahubctl transform' commands.
92 changes: 71 additions & 21 deletions infrahub_sdk/ctl/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from infrahub_sdk import __version__ as sdk_version
from infrahub_sdk import protocols as sdk_protocols
from infrahub_sdk.async_typer import AsyncTyper
from infrahub_sdk.client import InfrahubClient
from infrahub_sdk.ctl import config
from infrahub_sdk.ctl.branch import app as branch_app
from infrahub_sdk.ctl.check import run as run_check
Expand All @@ -28,11 +29,18 @@
from infrahub_sdk.ctl.repository import get_repository_config
from infrahub_sdk.ctl.schema import app as schema
from infrahub_sdk.ctl.transform import list_transforms
from infrahub_sdk.ctl.utils import catch_exception, execute_graphql_query, parse_cli_vars
from infrahub_sdk.ctl.utils import catch_exception, parse_cli_vars
from infrahub_sdk.ctl.validate import app as validate_app
from infrahub_sdk.exceptions import GraphQLError, InfrahubTransformNotFoundError
from infrahub_sdk.jinja2 import identify_faulty_jinja_code
from infrahub_sdk.schema import AttributeSchema, GenericSchema, InfrahubRepositoryConfig, NodeSchema, RelationshipSchema
from infrahub_sdk.schema import (
AttributeSchema,
GenericSchema,
InfrahubRepositoryConfig,
InfrahubRepositoryGraphQLConfig,
NodeSchema,
RelationshipSchema,
)
from infrahub_sdk.transforms import get_transform_class_instance
from infrahub_sdk.utils import get_branch, write_to_file

Expand Down Expand Up @@ -187,19 +195,34 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:


def _run_transform(
query: str,
query_name: str,
client: InfrahubClient,
variables: dict[str, Any],
transformer: Callable,
transform_func: Callable,
branch: str,
debug: bool,
repository_config: InfrahubRepositoryConfig,
):
"""
Query GraphQL for the required data then run a transform on that data.

Args:
query_name: Name of the query to load.
client: InfrahubClient object used to execute a graphql query against the infrahub API
variables: Dictionary of variables used for graphql query
transform_func: A function used to transform the return from the graphql query into a different form
branch: Name of the *infrahub* branch that should be queried for data
debug: Prints debug info to the command line
repository_config: Repository config object. This is used to load the graphql query from the repository.
"""
branch = get_branch(branch)
query_str = repository_config.get_query(name=query_name).load_query()

try:
response = execute_graphql_query(
query=query, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
)
response = client.execute_graphql(query=query_str, variables=variables, branch_name=branch)
if debug:
message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
console.print("\n".join(message))
except QueryNotFoundError as exc:
console.print(f"[red]Unable to find query : {exc}")
raise typer.Exit(1) from exc
Expand All @@ -214,10 +237,10 @@ def _run_transform(
console.print("[yellow] you can specify a different branch with --branch")
raise typer.Abort()

if asyncio.iscoroutinefunction(transformer.func):
output = asyncio.run(transformer(response))
if asyncio.iscoroutinefunction(transform_func):
output = asyncio.run(transform_func(response))
else:
output = transformer(response)
output = transform_func(response)
return output


Expand All @@ -243,23 +266,38 @@ def render(
list_jinja2_transforms(config=repository_config)
return

# Load transform config
try:
transform_config = repository_config.get_jinja2_transform(name=transform_name)
except KeyError as exc:
console.print(f'[red]Unable to find "{transform_name}" in {config.INFRAHUB_REPO_CONFIG_FILE}')
list_jinja2_transforms(config=repository_config)
raise typer.Exit(1) from exc

transformer = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
# Load query config object and add to repository config
query_config_obj = InfrahubRepositoryGraphQLConfig(
name=transform_config.query, file_path=Path(transform_config.query + ".gql")
)
repository_config.queries.append(query_config_obj)

# Get client used to make call to API
client = initialize_client_sync()

# Construct transform function used to transform data returned from the API
transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)

# Query GQL and run the transform
result = _run_transform(
query=transform_config.query,
query_name=transform_config.query,
client=client,
variables=variables_dict,
transformer=transformer,
transform_func=transform_func,
branch=branch,
debug=debug,
repository_config=repository_config,
)

# Output data
if out:
write_to_file(Path(out), result)
else:
Expand Down Expand Up @@ -288,26 +326,38 @@ def transform(
list_transforms(config=repository_config)
return

matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable

if not matched:
# Load transform config
try:
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
if not matched:
raise ValueError(f"{transform_name} does not exist")
except ValueError as exc:
console.print(f"[red]Unable to find requested transform: {transform_name}")
list_transforms(config=repository_config)
return
raise typer.Exit(1) from exc

transform_config = matched[0]

# Get Infrahub Client
client = initialize_client_sync()

# Get python transform class instance
try:
transform_instance = get_transform_class_instance(transform_config=transform_config)
transform = get_transform_class_instance(transform_config=transform_config, branch=branch, client=client)
except InfrahubTransformNotFoundError as exc:
console.print(f"Unable to load {transform_name} from python_transforms")
raise typer.Exit(1) from exc

transformer = functools.partial(transform_instance.transform)
# Load query config
query_config_obj = InfrahubRepositoryGraphQLConfig(name=transform.query, file_path=Path(transform.query + ".gql"))
repository_config.queries.append(query_config_obj)

# Run Transformer
result = _run_transform(
query=transform_instance.query,
query_name=transform.query,
client=transform.client,
variables=variables_dict,
transformer=transformer,
transform_func=transform.transform,
branch=branch,
debug=debug,
repository_config=repository_config,
Expand Down
17 changes: 15 additions & 2 deletions infrahub_sdk/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,20 @@ async def run(self, data: Optional[dict] = None) -> Any:


def get_transform_class_instance(
transform_config: InfrahubPythonTransformConfig, search_path: Optional[Path] = None
transform_config: InfrahubPythonTransformConfig,
search_path: Optional[Path] = None,
client: Optional[InfrahubClient] = None,
branch: str = "",
) -> InfrahubTransform:
"""Gets an uninstantiated InfrahubTransform class.

Args:
transform_config: A config object with information required to find and load the transform.
search_path: The path in which to search for a python file containing the transform. The current directory is
assumed if not speicifed.
client: The infrahub client used to interact with infrahub's API.
branch: git branch in which t
"""
if transform_config.file_path.is_absolute() or search_path is None:
search_location = transform_config.file_path
else:
Expand All @@ -108,7 +120,8 @@ def get_transform_class_instance(
transform_class = getattr(module, transform_config.class_name)

# Create an instance of the class
transform_instance = transform_class()
transform_instance = asyncio.run(transform_class.init(client=client, branch=branch))

except (FileNotFoundError, AttributeError) as exc:
raise InfrahubTransformNotFoundError(name=transform_config.name) from exc

Expand Down