Skip to content

Commit 5a37995

Browse files
authored
Merge pull request #338 from Climate-REF/151-specified-solve
2 parents 147fa71 + e29a22a commit 5a37995

File tree

5 files changed

+207
-9
lines changed

5 files changed

+207
-9
lines changed

changelog/338.feature.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Adds `--diagnostic` and `--provider` arguments to the `ref solve` command.
2+
This allows users to subset a specific diagnostic or provider that they wish to run.
3+
Multiple `--diagnostic` or `--provider` arguments can be used to specify multiple diagnostics or providers.
4+
The diagnostic or provider slug must contain one of the filter values to be included in the calculations.
Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,48 @@
1+
from typing import Annotated
2+
13
import typer
24

3-
from climate_ref.solver import solve_required_executions
5+
from climate_ref.solver import SolveFilterOptions, solve_required_executions
46

57
app = typer.Typer()
68

79

810
@app.command()
9-
def solve(
11+
def solve( # noqa: PLR0913
1012
ctx: typer.Context,
11-
dry_run: bool = typer.Option(False, help="Do not execute any diagnostics"),
13+
dry_run: Annotated[
14+
bool,
15+
typer.Option(help="Do not execute any diagnostics"),
16+
] = False,
17+
execute: Annotated[
18+
bool,
19+
typer.Option(help="Solve the newly identified executions"),
20+
] = True,
1221
timeout: int = typer.Option(60, help="Timeout in seconds for the solve operation"),
1322
one_per_provider: bool = typer.Option(
1423
False, help="Limit to one execution per provider. This is useful for testing"
1524
),
1625
one_per_diagnostic: bool = typer.Option(
1726
False, help="Limit to one execution per diagnostic. This is useful for testing"
1827
),
28+
diagnostic: Annotated[
29+
list[str] | None,
30+
typer.Option(
31+
help="Filters executions by the diagnostic slug. "
32+
"Diagnostics will be included if any of the filters match a case-insensitive subset "
33+
"of the diagnostic slug. "
34+
"Multiple values can be provided"
35+
),
36+
] = None,
37+
provider: Annotated[
38+
list[str] | None,
39+
typer.Option(
40+
help="Filters executions by provider slug. "
41+
"Providers will be included if any of the filters match a case-insensitive subset "
42+
"of the provider slug. "
43+
"Multiple values can be provided"
44+
),
45+
] = None,
1946
) -> None:
2047
"""
2148
Solve for executions that require recalculation
@@ -25,11 +52,19 @@ def solve(
2552
"""
2653
config = ctx.obj.config
2754
db = ctx.obj.database
55+
56+
filters = SolveFilterOptions(
57+
diagnostic=diagnostic,
58+
provider=provider,
59+
)
60+
2861
solve_required_executions(
2962
config=config,
3063
db=db,
3164
dry_run=dry_run,
65+
execute=execute,
3266
timeout=timeout,
3367
one_per_provider=one_per_provider,
3468
one_per_diagnostic=one_per_diagnostic,
69+
filters=filters,
3570
)

packages/climate-ref/src/climate_ref/solver.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,57 @@ def _solve_from_data_requirements(
245245
)
246246

247247

248+
@define
249+
class SolveFilterOptions:
250+
"""
251+
Options to filter the diagnostics that are solved
252+
"""
253+
254+
diagnostic: list[str] | None = None
255+
"""
256+
Check if the diagnostic slug contains any of the provided values
257+
"""
258+
provider: list[str] | None = None
259+
"""
260+
Check if the provider slug contains any of the provided values
261+
"""
262+
263+
264+
def matches_filter(diagnostic: Diagnostic, filters: SolveFilterOptions | None) -> bool:
265+
"""
266+
Check if a diagnostic matches the provided filters
267+
268+
Each filter is optional and a diagnostic will match if it satisfies all the provided filters.
269+
i.e. the filters are ANDed together.
270+
271+
Parameters
272+
----------
273+
diagnostic
274+
Diagnostic to check against the filters
275+
filters
276+
Collection of filters to apply to the diagnostic
277+
278+
If no filters are provided, the diagnostic is considered to match
279+
280+
Returns
281+
-------
282+
True if the diagnostic matches the filters, False otherwise
283+
"""
284+
if filters is None:
285+
return True
286+
287+
diagnostic_slug = diagnostic.slug
288+
provider_slug = diagnostic.provider.slug
289+
290+
if filters.provider and not any([f.lower() in provider_slug for f in filters.provider]):
291+
return False
292+
293+
if filters.diagnostic and not any([f.lower() in diagnostic_slug for f in filters.diagnostic]):
294+
return False
295+
296+
return True
297+
298+
248299
@define
249300
class ExecutionSolver:
250301
"""
@@ -278,7 +329,9 @@ def build_from_db(config: Config, db: Database) -> "ExecutionSolver":
278329
},
279330
)
280331

281-
def solve(self) -> typing.Generator[DiagnosticExecution, None, None]:
332+
def solve(
333+
self, filters: SolveFilterOptions | None = None
334+
) -> typing.Generator[DiagnosticExecution, None, None]:
282335
"""
283336
Solve which executions need to be calculated for a dataset
284337
@@ -293,17 +346,23 @@ def solve(self) -> typing.Generator[DiagnosticExecution, None, None]:
293346
"""
294347
for provider in self.provider_registry.providers:
295348
for diagnostic in provider.diagnostics():
349+
# Filter the diagnostic based on the provided filters
350+
if not matches_filter(diagnostic, filters):
351+
logger.debug(f"Skipping {diagnostic.full_slug()} due to filter")
352+
continue
296353
yield from solve_executions(self.data_catalog, diagnostic, provider)
297354

298355

299356
def solve_required_executions( # noqa: PLR0913
300357
db: Database,
301358
dry_run: bool = False,
359+
execute: bool = True,
302360
solver: ExecutionSolver | None = None,
303361
config: Config | None = None,
304362
timeout: int = 60,
305363
one_per_provider: bool = False,
306364
one_per_diagnostic: bool = False,
365+
filters: SolveFilterOptions | None = None,
307366
) -> None:
308367
"""
309368
Solve for executions that require recalculation
@@ -328,7 +387,7 @@ def solve_required_executions( # noqa: PLR0913
328387
diagnostic_count = {}
329388
provider_count = {}
330389

331-
for potential_execution in solver.solve():
390+
for potential_execution in solver.solve(filters):
332391
# The diagnostic output is first written to the scratch directory
333392
definition = potential_execution.build_execution_definition(output_root=config.paths.scratch)
334393

@@ -371,6 +430,7 @@ def solve_required_executions( # noqa: PLR0913
371430
logger.info(f"Created new execution group: {potential_execution.execution_slug()!r}")
372431
db.session.flush()
373432

433+
# TODO: Move this logic to the solver
374434
# Check if we should run given the one_per_provider or one_per_diagnostic flags
375435
one_of_check_failed = (
376436
one_per_provider and provider_count.get(diagnostic.provider.slug, 0) > 0
@@ -403,10 +463,11 @@ def solve_required_executions( # noqa: PLR0913
403463
# Add links to the datasets used in the execution
404464
execution.register_datasets(db, definition.datasets)
405465

406-
executor.run(
407-
definition=definition,
408-
execution=execution,
409-
)
466+
if execute:
467+
executor.run(
468+
definition=definition,
469+
execution=execution,
470+
)
410471

411472
provider_count[diagnostic.provider.slug] += 1
412473
diagnostic_count[diagnostic.full_slug()] += 1

packages/climate-ref/tests/unit/cli/test_solve.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ def test_solve(self, sample_data_dir, db, invoke_cli, mocker):
1414

1515
assert kwargs["timeout"] == 60
1616
assert not kwargs["dry_run"]
17+
assert kwargs["execute"]
18+
assert kwargs["filters"].diagnostic is None
19+
assert kwargs["filters"].provider is None
1720

1821
def test_solve_with_timeout(self, sample_data_dir, db, invoke_cli, mocker):
1922
mock_solve = mocker.patch("climate_ref.cli.solve.solve_required_executions")
@@ -28,3 +31,21 @@ def test_solve_with_dryrun(self, sample_data_dir, db, invoke_cli, mocker):
2831

2932
args, kwargs = mock_solve.call_args
3033
assert kwargs["dry_run"]
34+
35+
def test_solve_with_filters(self, sample_data_dir, db, invoke_cli, mocker):
36+
mock_solve = mocker.patch("climate_ref.cli.solve.solve_required_executions")
37+
invoke_cli(
38+
[
39+
"solve",
40+
"--diagnostic",
41+
"global-mean-timeseries",
42+
"--provider",
43+
"esmvaltool",
44+
"--provider",
45+
"ilamb",
46+
]
47+
)
48+
49+
args, kwargs = mock_solve.call_args
50+
assert kwargs["filters"].diagnostic == ["global-mean-timeseries"]
51+
assert kwargs["filters"].provider == ["esmvaltool", "ilamb"]

packages/climate-ref/tests/unit/test_solver.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from climate_ref.solver import (
1313
DiagnosticExecution,
1414
ExecutionSolver,
15+
SolveFilterOptions,
1516
extract_covered_datasets,
1617
solve_executions,
1718
solve_required_executions,
@@ -31,6 +32,19 @@ def solver(db_seeded, config) -> ExecutionSolver:
3132
return metric_solver
3233

3334

35+
@pytest.fixture
36+
def aft_solver(db_seeded, config) -> ExecutionSolver:
37+
from climate_ref_esmvaltool import provider as esmvaltool_provider
38+
from climate_ref_ilamb import provider as ilamb_provider
39+
from climate_ref_pmp import provider as pmp_provider
40+
41+
registry = ProviderRegistry(providers=[pmp_provider, esmvaltool_provider, ilamb_provider])
42+
metric_solver = ExecutionSolver.build_from_db(config, db_seeded)
43+
metric_solver.provider_registry = registry
44+
45+
return metric_solver
46+
47+
3448
@pytest.fixture
3549
def mock_metric_execution(
3650
tmp_path, db_seeded, definition_factory, mock_diagnostic, provider
@@ -289,6 +303,69 @@ def test_extract_no_groups():
289303
extract_covered_datasets(data_catalog, requirement)
290304

291305

306+
def test_solver_solve_with_filters(aft_solver):
307+
def solve_filtered(**kwargs):
308+
"""Helper function to solve with filters and return a DataFrame of results."""
309+
return pd.DataFrame(
310+
[
311+
{
312+
"diagnostic": execution.diagnostic.slug,
313+
"provider": execution.provider.slug,
314+
"dataset_key": execution.dataset_key,
315+
}
316+
for execution in aft_solver.solve(filters=SolveFilterOptions(**kwargs))
317+
]
318+
)
319+
320+
# Empty filters should return all executions
321+
executions = solve_filtered()
322+
assert not executions.empty
323+
executions = solve_filtered(provider=None, diagnostic=None)
324+
assert not executions.empty
325+
executions = solve_filtered(provider=[], diagnostic=[])
326+
assert not executions.empty
327+
328+
# ILAMB filter should only return ILAMB executions
329+
executions = solve_filtered(provider=["ilamb"])
330+
assert executions["provider"].unique().tolist() == ["ilamb"]
331+
assert executions["diagnostic"].nunique() > 1
332+
333+
# Multiple provider filters
334+
executions = solve_filtered(provider=["ilamb", "pmp"])
335+
assert sorted(executions["provider"].unique().tolist()) == ["ilamb", "pmp"]
336+
337+
# Partial diagnostic filter should return executions for that diagnostic
338+
# enso metrics exist in both pmp and esmvaltool providers
339+
executions = solve_filtered(diagnostic=["enso"])
340+
assert sorted(executions["provider"].unique().tolist()) == ["esmvaltool", "pmp"]
341+
342+
# Adding in a provider filter as well should limit the results to that provider
343+
executions = solve_filtered(provider=["pmp"], diagnostic=["enso"])
344+
assert executions["provider"].unique().tolist() == ["pmp"]
345+
assert sorted(executions["diagnostic"].unique().tolist()) == ["enso_proc", "enso_tel"]
346+
347+
# Check lowercase
348+
pd.testing.assert_frame_equal(executions, solve_filtered(provider=["PmP"], diagnostic=["enSo"]))
349+
350+
# Missing provider should return no results
351+
assert not list(
352+
aft_solver.solve(
353+
filters=SolveFilterOptions(
354+
provider=["missing"],
355+
)
356+
)
357+
)
358+
359+
# Missing diagnostic should return no results
360+
assert not list(
361+
aft_solver.solve(
362+
filters=SolveFilterOptions(
363+
diagnostic=["missing"],
364+
)
365+
)
366+
)
367+
368+
292369
def test_solve_metrics_default_solver(mocker, mock_metric_execution, mock_executor, db_seeded, solver):
293370
mock_build_solver = mocker.patch.object(ExecutionSolver, "build_from_db")
294371

0 commit comments

Comments
 (0)