Skip to content

Commit 43879f6

Browse files
authored
Merge pull request #88 from opsmill/pog-merge-back-stable
Merge back stable into develop with resolved conflicts
2 parents b889934 + 37adb41 commit 43879f6

File tree

21 files changed

+341
-50
lines changed

21 files changed

+341
-50
lines changed

changelog/8.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make `infrahubctl transform` command set up the InfrahubTransform class with an InfrahubClient instance

infrahub_sdk/ctl/branch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def list_branch(_: str = CONFIG_PARAM) -> None:
3434

3535
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
3636

37-
client = await initialize_client()
37+
client = initialize_client()
3838
branches = await client.branch.all()
3939

4040
table = Table(title="List of all branches")
@@ -91,7 +91,7 @@ async def create(
9191

9292
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
9393

94-
client = await initialize_client()
94+
client = initialize_client()
9595
branch = await client.branch.create(branch_name=branch_name, description=description, sync_with_git=sync_with_git)
9696
console.print(f"Branch {branch_name!r} created successfully ({branch.id}).")
9797

@@ -103,7 +103,7 @@ async def delete(branch_name: str, _: str = CONFIG_PARAM) -> None:
103103

104104
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
105105

106-
client = await initialize_client()
106+
client = initialize_client()
107107
await client.branch.delete(branch_name=branch_name)
108108
console.print(f"Branch '{branch_name}' deleted successfully.")
109109

@@ -115,7 +115,7 @@ async def rebase(branch_name: str, _: str = CONFIG_PARAM) -> None:
115115

116116
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
117117

118-
client = await initialize_client()
118+
client = initialize_client()
119119
await client.branch.rebase(branch_name=branch_name)
120120
console.print(f"Branch '{branch_name}' rebased successfully.")
121121

@@ -127,7 +127,7 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:
127127

128128
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
129129

130-
client = await initialize_client()
130+
client = initialize_client()
131131
await client.branch.merge(branch_name=branch_name)
132132
console.print(f"Branch '{branch_name}' merged successfully.")
133133

@@ -137,6 +137,6 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:
137137
async def validate(branch_name: str, _: str = CONFIG_PARAM) -> None:
138138
"""Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET)."""
139139

140-
client = await initialize_client()
140+
client = initialize_client()
141141
await client.branch.validate(branch_name=branch_name)
142142
console.print(f"Branch '{branch_name}' is valid.")

infrahub_sdk/ctl/check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ async def run_checks(
196196
log = logging.getLogger("infrahub")
197197

198198
check_summary: list[bool] = []
199-
client = await initialize_client()
199+
client = initialize_client()
200200
for check_module in check_modules:
201201
if check_module.definition.targets:
202202
result = await run_targeted_check(

infrahub_sdk/ctl/cli_commands.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def run(
163163
if not hasattr(module, method):
164164
raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")
165165

166-
client = await initialize_client(
166+
client = initialize_client(
167167
branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
168168
)
169169
func = getattr(module, method)
@@ -201,19 +201,35 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:
201201

202202

203203
def _run_transform(
204-
query: str,
204+
query_name: str,
205205
variables: dict[str, Any],
206-
transformer: Callable,
206+
transform_func: Callable,
207207
branch: str,
208208
debug: bool,
209209
repository_config: InfrahubRepositoryConfig,
210210
):
211+
"""
212+
Query GraphQL for the required data then run a transform on that data.
213+
214+
Args:
215+
query_name: Name of the query to load (e.g. tags_query)
216+
variables: Dictionary of variables used for graphql query
217+
transform_func: The function responsible for transforming data received from graphql
218+
branch: Name of the *infrahub* branch that should be queried for data
219+
debug: Prints debug info to the command line
220+
repository_config: Repository config object. This is used to load the graphql query from the repository.
221+
"""
211222
branch = get_branch(branch)
212223

213224
try:
214225
response = execute_graphql_query(
215-
query=query, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
226+
query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
216227
)
228+
229+
# TODO: response is a dict and can't be printed to the console in this way.
230+
# if debug:
231+
# message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
232+
# console.print("\n".join(message))
217233
except QueryNotFoundError as exc:
218234
console.print(f"[red]Unable to find query : {exc}")
219235
raise typer.Exit(1) from exc
@@ -228,10 +244,10 @@ def _run_transform(
228244
console.print("[yellow] you can specify a different branch with --branch")
229245
raise typer.Abort()
230246

231-
if asyncio.iscoroutinefunction(transformer.func):
232-
output = asyncio.run(transformer(response))
247+
if asyncio.iscoroutinefunction(transform_func):
248+
output = asyncio.run(transform_func(response))
233249
else:
234-
output = transformer(response)
250+
output = transform_func(response)
235251
return output
236252

237253

@@ -257,23 +273,28 @@ def render(
257273
list_jinja2_transforms(config=repository_config)
258274
return
259275

276+
# Load transform config
260277
try:
261278
transform_config = repository_config.get_jinja2_transform(name=transform_name)
262279
except KeyError as exc:
263280
console.print(f'[red]Unable to find "{transform_name}" in {config.INFRAHUB_REPO_CONFIG_FILE}')
264281
list_jinja2_transforms(config=repository_config)
265282
raise typer.Exit(1) from exc
266283

267-
transformer = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
284+
# Construct transform function used to transform data returned from the API
285+
transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
286+
287+
# Query GQL and run the transform
268288
result = _run_transform(
269-
query=transform_config.query,
289+
query_name=transform_config.query,
270290
variables=variables_dict,
271-
transformer=transformer,
291+
transform_func=transform_func,
272292
branch=branch,
273293
debug=debug,
274294
repository_config=repository_config,
275295
)
276296

297+
# Output data
277298
if out:
278299
write_to_file(Path(out), result)
279300
else:
@@ -302,31 +323,41 @@ def transform(
302323
list_transforms(config=repository_config)
303324
return
304325

305-
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
306-
307-
if not matched:
326+
# Load transform config
327+
try:
328+
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
329+
if not matched:
330+
raise ValueError(f"{transform_name} does not exist")
331+
except ValueError as exc:
308332
console.print(f"[red]Unable to find requested transform: {transform_name}")
309333
list_transforms(config=repository_config)
310-
return
334+
raise typer.Exit(1) from exc
311335

312336
transform_config = matched[0]
313337

338+
# Get client
339+
client = initialize_client()
340+
341+
# Get python transform class instance
314342
try:
315-
transform_instance = get_transform_class_instance(transform_config=transform_config)
343+
transform = get_transform_class_instance(
344+
transform_config=transform_config,
345+
branch=branch,
346+
client=client,
347+
)
316348
except InfrahubTransformNotFoundError as exc:
317349
console.print(f"Unable to load {transform_name} from python_transforms")
318350
raise typer.Exit(1) from exc
319351

320-
transformer = functools.partial(transform_instance.transform)
321-
result = _run_transform(
322-
query=transform_instance.query,
323-
variables=variables_dict,
324-
transformer=transformer,
325-
branch=branch,
326-
debug=debug,
327-
repository_config=repository_config,
352+
# Get data
353+
query_str = repository_config.get_query(name=transform.query).load_query()
354+
data = asyncio.run(
355+
transform.client.execute_graphql(query=query_str, variables=variables_dict, branch_name=transform.branch_name)
328356
)
329357

358+
# Run Transform
359+
result = asyncio.run(transform.run(data=data))
360+
330361
json_string = ujson.dumps(result, indent=2, sort_keys=True)
331362
if out:
332363
write_to_file(Path(out), json_string)

infrahub_sdk/ctl/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from infrahub_sdk.ctl import config
66

77

8-
async def initialize_client(
8+
def initialize_client(
99
branch: Optional[str] = None,
1010
identifier: Optional[str] = None,
1111
timeout: Optional[int] = None,

infrahub_sdk/ctl/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def run(
4343
if param_key:
4444
identifier = param_key[0]
4545

46-
client = await initialize_client()
46+
client = initialize_client()
4747
if variables_dict:
4848
data = execute_graphql_query(
4949
query=generator_config.query,

infrahub_sdk/ctl/menu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def load(
3838
logging.getLogger("infrahub_sdk").setLevel(logging.INFO)
3939

4040
files = load_yamlfile_from_disk_and_exit(paths=menus, file_type=MenuFile, console=console)
41-
client = await initialize_client()
41+
client = initialize_client()
4242

4343
for file in files:
4444
file.validate_content()

infrahub_sdk/ctl/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def load(
3838
logging.getLogger("infrahub_sdk").setLevel(logging.INFO)
3939

4040
files = load_yamlfile_from_disk_and_exit(paths=paths, file_type=ObjectFile, console=console)
41-
client = await initialize_client()
41+
client = initialize_client()
4242

4343
for file in files:
4444
file.validate_content()

infrahub_sdk/ctl/repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def add(
8888
},
8989
}
9090

91-
client = await initialize_client()
91+
client = initialize_client()
9292

9393
if username:
9494
credential = await client.create(kind="CorePasswordCredential", name=name, username=username, password=password)

infrahub_sdk/ctl/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def load(
115115

116116
schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console)
117117
schema_definition = "schema" if len(schemas_data) == 1 else "schemas"
118-
client = await initialize_client()
118+
client = initialize_client()
119119
validate_schema_content_and_exit(client=client, schemas=schemas_data)
120120

121121
start_time = time.time()
@@ -164,7 +164,7 @@ async def check(
164164
init_logging(debug=debug)
165165

166166
schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console)
167-
client = await initialize_client()
167+
client = initialize_client()
168168
validate_schema_content_and_exit(client=client, schemas=schemas_data)
169169

170170
success, response = await client.schema.check(schemas=[item.content for item in schemas_data], branch=branch)

0 commit comments

Comments
 (0)