diff --git a/changelog/8.fixed.md b/changelog/8.fixed.md new file mode 100644 index 00000000..a25e1b44 --- /dev/null +++ b/changelog/8.fixed.md @@ -0,0 +1 @@ +Make `infrahubctl transform` command set up the InfrahubTransform class with an InfrahubClient instance \ No newline at end of file diff --git a/infrahub_sdk/ctl/branch.py b/infrahub_sdk/ctl/branch.py index 96b164ec..bc7013df 100644 --- a/infrahub_sdk/ctl/branch.py +++ b/infrahub_sdk/ctl/branch.py @@ -34,7 +34,7 @@ async def list_branch(_: str = CONFIG_PARAM) -> None: logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) - client = await initialize_client() + client = initialize_client() branches = await client.branch.all() table = Table(title="List of all branches") @@ -91,7 +91,7 @@ async def create( logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) - client = await initialize_client() + client = initialize_client() branch = await client.branch.create(branch_name=branch_name, description=description, sync_with_git=sync_with_git) console.print(f"Branch {branch_name!r} created successfully ({branch.id}).") @@ -103,7 +103,7 @@ async def delete(branch_name: str, _: str = CONFIG_PARAM) -> None: logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) - client = await initialize_client() + client = initialize_client() await client.branch.delete(branch_name=branch_name) console.print(f"Branch '{branch_name}' deleted successfully.") @@ -115,7 +115,7 @@ async def rebase(branch_name: str, _: str = CONFIG_PARAM) -> None: logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) - client = await initialize_client() + client = initialize_client() await client.branch.rebase(branch_name=branch_name) console.print(f"Branch '{branch_name}' rebased successfully.") @@ -127,7 +127,7 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None: logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL) - client = await initialize_client() + client = initialize_client() await client.branch.merge(branch_name=branch_name) console.print(f"Branch '{branch_name}' merged successfully.") @@ -137,6 +137,6 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None: async def validate(branch_name: str, _: str = CONFIG_PARAM) -> None: """Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET).""" - client = await initialize_client() + client = initialize_client() await client.branch.validate(branch_name=branch_name) console.print(f"Branch '{branch_name}' is valid.") diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 498e7b0d..712118af 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -153,7 +153,7 @@ async def run( if not hasattr(module, method): raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}") - client = await initialize_client( + client = initialize_client( branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name ) func = getattr(module, method) @@ -191,19 +191,35 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data: def _run_transform( - query: str, + query_name: str, variables: dict[str, Any], - transformer: Callable, + transform_func: Callable, branch: str, debug: bool, repository_config: InfrahubRepositoryConfig, ): + """ + Query GraphQL for the required data then run a transform on that data. + + Args: + query_name: Name of the query to load (e.g. tags_query) + variables: Dictionary of variables used for graphql query + transform_func: The function responsible for transforming data received from graphql + branch: Name of the *infrahub* branch that should be queried for data + debug: Prints debug info to the command line + repository_config: Repository config object. This is used to load the graphql query from the repository. + """ branch = get_branch(branch) try: response = execute_graphql_query( - query=query, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config + query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config ) + + # TODO: response is a dict and can't be printed to the console in this way. + # if debug: + # message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40) + # console.print("\n".join(message)) except QueryNotFoundError as exc: console.print(f"[red]Unable to find query : {exc}") raise typer.Exit(1) from exc @@ -218,10 +234,10 @@ def _run_transform( console.print("[yellow] you can specify a different branch with --branch") raise typer.Abort() - if asyncio.iscoroutinefunction(transformer.func): - output = asyncio.run(transformer(response)) + if asyncio.iscoroutinefunction(transform_func): + output = asyncio.run(transform_func(response)) else: - output = transformer(response) + output = transform_func(response) return output @@ -247,6 +263,7 @@ def render( list_jinja2_transforms(config=repository_config) return + # Load transform config try: transform_config = repository_config.get_jinja2_transform(name=transform_name) except KeyError as exc: @@ -254,16 +271,20 @@ def render( list_jinja2_transforms(config=repository_config) raise typer.Exit(1) from exc - transformer = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict) + # Construct transform function used to transform data returned from the API + transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict) + + # Query GQL and run the transform result = _run_transform( - query=transform_config.query, + query_name=transform_config.query, variables=variables_dict, - transformer=transformer, + transform_func=transform_func, branch=branch, debug=debug, repository_config=repository_config, ) + # Output data if out: write_to_file(Path(out), result) else: @@ -292,31 +313,41 @@ def transform( list_transforms(config=repository_config) return - matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable - - if not matched: + # Load transform config + try: + matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable + if not matched: + raise ValueError(f"{transform_name} does not exist") + except ValueError as exc: console.print(f"[red]Unable to find requested transform: {transform_name}") list_transforms(config=repository_config) - return + raise typer.Exit(1) from exc transform_config = matched[0] + # Get client + client = initialize_client() + + # Get python transform class instance try: - transform_instance = get_transform_class_instance(transform_config=transform_config) + transform = get_transform_class_instance( + transform_config=transform_config, + branch=branch, + client=client, + ) except InfrahubTransformNotFoundError as exc: console.print(f"Unable to load {transform_name} from python_transforms") raise typer.Exit(1) from exc - transformer = functools.partial(transform_instance.transform) - result = _run_transform( - query=transform_instance.query, - variables=variables_dict, - transformer=transformer, - branch=branch, - debug=debug, - repository_config=repository_config, + # Get data + query_str = repository_config.get_query(name=transform.query).load_query() + data = asyncio.run( + transform.client.execute_graphql(query=query_str, variables=variables_dict, branch_name=transform.branch_name) ) + # Run Transform + result = asyncio.run(transform.run(data=data)) + json_string = ujson.dumps(result, indent=2, sort_keys=True) if out: write_to_file(Path(out), json_string) diff --git a/infrahub_sdk/ctl/client.py b/infrahub_sdk/ctl/client.py index 49e32f65..5699ec47 100644 --- a/infrahub_sdk/ctl/client.py +++ b/infrahub_sdk/ctl/client.py @@ -5,7 +5,7 @@ from infrahub_sdk.ctl import config -async def initialize_client( +def initialize_client( branch: Optional[str] = None, identifier: Optional[str] = None, timeout: Optional[int] = None, diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index a4d0de23..a66ad206 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -43,7 +43,7 @@ async def run( if param_key: identifier = param_key[0] - client = await initialize_client() + client = initialize_client() if variables_dict: data = execute_graphql_query( query=generator_config.query, diff --git a/infrahub_sdk/ctl/repository.py b/infrahub_sdk/ctl/repository.py index f3b26a61..61ffeea9 100644 --- a/infrahub_sdk/ctl/repository.py +++ b/infrahub_sdk/ctl/repository.py @@ -88,7 +88,7 @@ async def add( }, } - client = await initialize_client() + client = initialize_client() if username: credential = await client.create(kind="CorePasswordCredential", name=name, username=username, password=password) diff --git a/infrahub_sdk/ctl/schema.py b/infrahub_sdk/ctl/schema.py index 57ecc379..9d7aa84f 100644 --- a/infrahub_sdk/ctl/schema.py +++ b/infrahub_sdk/ctl/schema.py @@ -155,7 +155,7 @@ async def load( schemas_data = load_schemas_from_disk_and_exit(schemas=schemas) schema_definition = "schema" if len(schemas_data) == 1 else "schemas" - client = await initialize_client() + client = initialize_client() validate_schema_content_and_exit(client=client, schemas=schemas_data) start_time = time.time() @@ -204,7 +204,7 @@ async def check( init_logging(debug=debug) schemas_data = load_schemas_from_disk_and_exit(schemas=schemas) - client = await initialize_client() + client = initialize_client() validate_schema_content_and_exit(client=client, schemas=schemas_data) success, response = await client.schema.check(schemas=[item.content for item in schemas_data], branch=branch) diff --git a/infrahub_sdk/ctl/validate.py b/infrahub_sdk/ctl/validate.py index f402b487..b4b593da 100644 --- a/infrahub_sdk/ctl/validate.py +++ b/infrahub_sdk/ctl/validate.py @@ -40,7 +40,7 @@ async def validate_schema(schema: Path, _: str = CONFIG_PARAM) -> None: console.print("[red]Invalid JSON file") raise typer.Exit(1) from exc - client = await initialize_client() + client = initialize_client() try: client.schema.validate(schema_data) diff --git a/infrahub_sdk/transforms.py b/infrahub_sdk/transforms.py index f558ad22..bb277068 100644 --- a/infrahub_sdk/transforms.py +++ b/infrahub_sdk/transforms.py @@ -3,6 +3,7 @@ import asyncio import importlib import os +import warnings from abc import abstractmethod from typing import TYPE_CHECKING, Any, Optional @@ -25,15 +26,20 @@ class InfrahubTransform: query: str timeout: int = 10 - def __init__(self, branch: str = "", root_directory: str = "", server_url: str = ""): + def __init__( + self, + branch: str = "", + root_directory: str = "", + server_url: str = "", + client: Optional[InfrahubClient] = None, + ): self.git: Repo self.branch = branch - self.server_url = server_url or os.environ.get("INFRAHUB_URL", "http://127.0.0.1:8000") self.root_directory = root_directory or os.getcwd() - self.client: InfrahubClient + self._client = client if not self.name: self.name = self.__class__.__name__ @@ -41,17 +47,26 @@ def __init__(self, branch: str = "", root_directory: str = "", server_url: str = if not self.query: raise ValueError("A query must be provided") + @property + def client(self) -> InfrahubClient: + if not self._client: + self._client = InfrahubClient(address=self.server_url) + + return self._client + @classmethod async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform: """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" + warnings.warn( + f"{cls.__class__.__name__}.init has been deprecated and will be removed in Infrahub SDK 0.15.0 or the next major version", + DeprecationWarning, + stacklevel=1, + ) + if client: + kwargs["client"] = client item = cls(*args, **kwargs) - if client: - item.client = client - else: - item.client = InfrahubClient(address=item.server_url) - return item @property @@ -61,7 +76,7 @@ def branch_name(self) -> str: if self.branch: return self.branch - if not self.git: + if not hasattr(self, "git") or not self.git: self.git = Repo(self.root_directory) self.branch = str(self.git.active_branch) @@ -79,10 +94,18 @@ async def collect_data(self) -> dict: async def run(self, data: Optional[dict] = None) -> Any: """Execute the transformation after collecting the data from the GraphQL query. - The result of the check is determined based on the presence or not of ERROR log messages.""" + + The result of the check is determined based on the presence or not of ERROR log messages. + + Args: + data: The data on which to run the transform. Data will be queried from the API if not provided + + Returns: Transformed data + """ if not data: data = await self.collect_data() + unpacked = data.get("data") or data if asyncio.iscoroutinefunction(self.transform): @@ -92,8 +115,20 @@ async def run(self, data: Optional[dict] = None) -> Any: def get_transform_class_instance( - transform_config: InfrahubPythonTransformConfig, search_path: Optional[Path] = None + transform_config: InfrahubPythonTransformConfig, + search_path: Optional[Path] = None, + branch: str = "", + client: Optional[InfrahubClient] = None, ) -> InfrahubTransform: + """Gets an instance of the InfrahubTransform class. + + Args: + transform_config: A config object with information required to find and load the transform. + search_path: The path in which to search for a python file containing the transform. The current directory is + assumed if not speicifed. + branch: Infrahub branch which will be targeted in graphql query used to acquire data for transformation. + client: InfrahubClient used to interact with infrahub API. + """ if transform_config.file_path.is_absolute() or search_path is None: search_location = transform_config.file_path else: @@ -108,7 +143,8 @@ def get_transform_class_instance( transform_class = getattr(module, transform_config.class_name) # Create an instance of the class - transform_instance = transform_class() + transform_instance = transform_class(branch=branch, client=client) + except (FileNotFoundError, AttributeError) as exc: raise InfrahubTransformNotFoundError(name=transform_config.name) from exc diff --git a/tests/fixtures/integration/test_infrahubctl/tags_transform/.infrahub.yml b/tests/fixtures/integration/test_infrahubctl/tags_transform/.infrahub.yml new file mode 100644 index 00000000..e82ff786 --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/tags_transform/.infrahub.yml @@ -0,0 +1,15 @@ +--- +python_transforms: + - name: tags_transform + class_name: TagsTransform + file_path: "tags_transform.py" + +queries: + - name: "tags_query" + file_path: "tags_query.gql" + +jinja2_transforms: + - name: my-jinja2-transform # Unique name for your transform + description: "short description" # (optional) + query: "tags_query" # Name or ID of the GraphQLQuery + template_path: "tags_tpl.j2" # Path to the main Jinja2 template diff --git a/tests/fixtures/integration/test_infrahubctl/tags_transform/__init__.py b/tests/fixtures/integration/test_infrahubctl/tags_transform/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_query.gql b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_query.gql new file mode 100644 index 00000000..b270ecd2 --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_query.gql @@ -0,0 +1,14 @@ +query TagsQuery($tag: String!) { + BuiltinTag(name__value: $tag) { + edges { + node { + name { + value + } + description { + value + } + } + } + } +} \ No newline at end of file diff --git a/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_tpl.j2 b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_tpl.j2 new file mode 100644 index 00000000..851038d7 --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_tpl.j2 @@ -0,0 +1,8 @@ +{% if data.BuiltinTag.edges and data.BuiltinTag.edges is iterable %} +{% for tag in data["BuiltinTag"]["edges"] %} +{% set tag_name = tag.node.name.value %} +{% set tag_description = tag.node.description.value %} +{{ tag_name }} + description: {{ tag_description }} +{% endfor %} +{% endif %} \ No newline at end of file diff --git a/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_transform.py b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_transform.py new file mode 100644 index 00000000..78d59fb6 --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/tags_transform/tags_transform.py @@ -0,0 +1,13 @@ +from infrahub_sdk.transforms import InfrahubTransform + + +class TagsTransform(InfrahubTransform): + query = "tags_query" + url = "my-tags" + + async def transform(self, data): + tag = data["BuiltinTag"]["edges"][0]["node"] + tag_name = tag["name"]["value"] + tag_description = tag["description"]["value"] + + return {"tag_title": tag_name.title(), "bold_description": f"*{tag_description}*".upper()} diff --git a/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_api_return.json b/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_api_return.json new file mode 100644 index 00000000..014a1bcc --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_api_return.json @@ -0,0 +1,18 @@ +{ + "data": { + "BuiltinTag": { + "edges": [ + { + "node": { + "name": { + "value": "red" + }, + "description": { + "value": null + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_output.txt b/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_output.txt new file mode 100644 index 00000000..b4a08de4 --- /dev/null +++ b/tests/fixtures/integration/test_infrahubctl/transform_cmd/case_success_output.txt @@ -0,0 +1,4 @@ +{ + "bold_description": "*NONE*", + "tag_title": "Red" +} diff --git a/tests/integration/test_infrahubctl.py b/tests/integration/test_infrahubctl.py new file mode 100644 index 00000000..a63514f1 --- /dev/null +++ b/tests/integration/test_infrahubctl.py @@ -0,0 +1,124 @@ +"""Integration tests for infrahubctl commands.""" + +import json +import os +import shutil +import tempfile +from pathlib import Path + +import pytest +from git import Repo +from pytest_httpx._httpx_mock import HTTPXMock +from typer.testing import Any, CliRunner + +from infrahub_sdk.ctl.cli_commands import app + +from .utils import change_directory, strip_color + +runner = CliRunner() + + +FIXTURE_BASE_DIR = Path(Path(os.path.abspath(__file__)).parent / ".." / "fixtures" / "integration" / "test_infrahubctl") + + +def read_fixture(file_name: str, fixture_subdir: str = ".") -> Any: + """Read the contents of a fixture.""" + with Path(FIXTURE_BASE_DIR / fixture_subdir / file_name).open("r", encoding="utf-8") as fhd: + fixture_contents = fhd.read() + + return fixture_contents + + +@pytest.fixture +def tags_transform_dir(): + temp_dir = tempfile.mkdtemp() + + try: + fixture_path = Path(FIXTURE_BASE_DIR / "tags_transform") + shutil.copytree(fixture_path, temp_dir, dirs_exist_ok=True) + # Initialize fixture as git repo. This is necessary to run some infrahubctl commands. + with change_directory(temp_dir): + Repo.init(".") + + yield temp_dir + + finally: + shutil.rmtree(temp_dir) + + +# --------------------------------------------------------- +# infrahubctl transform command tests +# --------------------------------------------------------- + + +class TestInfrahubctlTransform: + """Groups the 'infrahubctl transform' test cases.""" + + @staticmethod + def test_transform_not_exist_in_infrahub_yml(tags_transform_dir: str) -> None: + """Case transform is not specified in the infrahub.yml file.""" + transform_name = "not_existing_transform" + with change_directory(tags_transform_dir): + output = runner.invoke(app, ["transform", transform_name, "tag=red"]) + assert f"Unable to find requested transform: {transform_name}" in output.stdout + assert output.exit_code == 1 + + @staticmethod + def test_transform_python_file_not_defined(tags_transform_dir: str) -> None: + """Case transform python file not defined.""" + # Remove transform file + transform_file = Path(Path(tags_transform_dir) / "tags_transform.py") + Path.unlink(transform_file) + + # Run command and make assertions + transform_name = "tags_transform" + with change_directory(tags_transform_dir): + output = runner.invoke(app, ["transform", transform_name, "tag=red"]) + assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert output.exit_code == 1 + + @staticmethod + def test_transform_python_class_not_defined(tags_transform_dir: str) -> None: + """Case transform python class not defined.""" + # Rename transform inside of python file so the class name searched for no longer exists + transform_file = Path(Path(tags_transform_dir) / "tags_transform.py") + with Path.open(transform_file, "r", encoding="utf-8") as fhd: + file_contents = fhd.read() + + with Path.open(transform_file, "w", encoding="utf-8") as fhd: + new_file_contents = file_contents.replace("TagsTransform", "FunTransform") + fhd.write(new_file_contents) + + # Run command and make assertions + transform_name = "tags_transform" + with change_directory(tags_transform_dir): + output = runner.invoke(app, ["transform", transform_name, "tag=red"]) + assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert output.exit_code == 1 + + @staticmethod + def test_gql_query_not_defined(tags_transform_dir: str) -> None: + """Case GraphQL Query is not defined""" + # Remove GraphQL Query file + gql_file = Path(Path(tags_transform_dir) / "tags_query.gql") + Path.unlink(gql_file) + + # Run command and make assertions + with change_directory(tags_transform_dir): + output = runner.invoke(app, ["transform", "tags_transform", "tag=red"]) + assert "FileNotFoundError" in output.stdout + assert output.exit_code == 1 + + @staticmethod + def test_infrahubctl_transform_cmd_success(httpx_mock: HTTPXMock, tags_transform_dir: str) -> None: + """Case infrahubctl transform command executes successfully""" + httpx_mock.add_response( + method="POST", + url="http://mock/graphql/main", + json=json.loads(read_fixture("case_success_api_return.json", "transform_cmd")), + ) + + with change_directory(tags_transform_dir): + output = runner.invoke(app, ["transform", "tags_transform", "tag=red"]) + assert strip_color(output.stdout) == read_fixture("case_success_output.txt", "transform_cmd") + assert output.exit_code == 0 diff --git a/tests/integration/utils.py b/tests/integration/utils.py new file mode 100644 index 00000000..9a8ee8a7 --- /dev/null +++ b/tests/integration/utils.py @@ -0,0 +1,27 @@ +"""Utility functions reused throughout integration tests.""" + +import os +import re +from contextlib import contextmanager +from typing import Generator + + +@contextmanager +def change_directory(new_directory: str) -> Generator[None, None, None]: + """Helper function used to change directories in a with block.""" + # Save the current working directory + original_directory = os.getcwd() + + # Change to the new directory + try: + os.chdir(new_directory) + yield # Yield control back to the with block + + finally: + # Change back to the original directory + os.chdir(original_directory) + + +def strip_color(text: str) -> str: + ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") + return ansi_escape.sub("", text)