diff --git a/httomo/sweep_runner/param_sweep_json_loader.py b/httomo/sweep_runner/param_sweep_json_loader.py new file mode 100644 index 000000000..072eb5157 --- /dev/null +++ b/httomo/sweep_runner/param_sweep_json_loader.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List, Optional, Tuple + +import json +import numpy as np + + +class ParamSweepJsonLoader: + """ + Loader for JSON pipelines containing parameter sweep + """ + + def __init__(self, json_string: str) -> None: + self.json_string = json_string + + def load(self) -> List[Dict[str, Any]]: + """ + Convert JSON data to python dict + """ + data: List[Dict[str, Any]] = json.loads(self.json_string) + res = self._find_range_sweep_param(data[1:]) + if res is not None: + sweep_dict = data[res[1] + 1]["parameters"][res[0]] + sweep_vals = tuple( + np.arange(sweep_dict["start"], sweep_dict["stop"], sweep_dict["step"]) + ) + data[res[1] + 1]["parameters"][res[0]] = sweep_vals + + res = self._find_manual_sweep_param(data[1:]) + if res is not None: + sweep_vals = data[res[1] + 1]["parameters"][res[0]] + data[res[1] + 1]["parameters"][res[0]] = tuple(sweep_vals) + + return data + + def _find_range_sweep_param( + self, methods: List[Dict[str, Any]] + ) -> Optional[Tuple[str, int]]: + for idx, method in enumerate(methods): + for name, value in method["parameters"].items(): + if isinstance(value, dict): + keys = value.keys() + has_keys_for_sweep = ( + "start" in keys and "stop" in keys and "step" in keys + ) + if has_keys_for_sweep and len(keys) == 3: + return name, idx + + def _find_manual_sweep_param( + self, methods: List[Dict[str, Any]] + ) -> Optional[Tuple[str, int]]: + for idx, method in enumerate(methods): + for name, value in method["parameters"].items(): + if isinstance(value, list): + return name, idx diff --git a/tests/sweep_runner/test_param_sweep_json_loader.py b/tests/sweep_runner/test_param_sweep_json_loader.py new file mode 100644 index 000000000..e70d4217f --- /dev/null +++ b/tests/sweep_runner/test_param_sweep_json_loader.py @@ -0,0 +1,92 @@ +import numpy as np + +from httomo.sweep_runner.param_sweep_json_loader import ParamSweepJsonLoader + + +def test_load_range_sweep(): + PARAM_NAME = "parameter_1" + JSON_STRING = """ +[ + { + "method": "standard_tomo", + "module_path": "httomo.data.hdf.loaders", + "parameters": {} + }, + { + "method": "some_method", + "module_path": "some.module.path", + "parameters": { + "parameter_1": { + "start": 10, + "stop": 110, + "step": 5 + } + } + } +] +""" + data = ParamSweepJsonLoader(JSON_STRING).load() + assert isinstance(data[1]["parameters"][PARAM_NAME], tuple) + assert data[1]["parameters"][PARAM_NAME] == tuple(np.arange(10, 110, 5)) + + +def test_param_value_with_start_stop_step_and_other_keys_unaffected_by_range_sweep_parsing(): + PARAM_NAME = "parameter_1" + JSON_STRING = """ +[ + { + "method": "standard_tomo", + "module_path": "httomo.data.hdf.loaders", + "parameters": {} + }, + { + "method": "some_method", + "module_path": "some.module.path", + "parameters": { + "parameter_1": { + "start": 10, + "stop": 110, + "step": 5, + "another": 0 + } + } + } +] +""" + data = ParamSweepJsonLoader(JSON_STRING).load() + assert isinstance(data[1]["parameters"][PARAM_NAME], dict) + assert data[1]["parameters"][PARAM_NAME] == { + "start": 10, + "stop": 110, + "step": 5, + "another": 0, + } + + +def test_load_manual_sweep(): + PARAM_NAME = "parameter_1" + JSON_STRING = """ +[ + { + "method": "standard_tomo", + "module_path": "httomo.data.hdf.loaders", + "parameters": {} + }, + { + "method": "some_method", + "module_path": "some.module.path", + "parameters": { + "parameter_1": [ + 1, + 5, + 12, + 13, + 14 + ] + } + } +] +""" + data = ParamSweepJsonLoader(JSON_STRING).load() + assert isinstance(data[1]["parameters"][PARAM_NAME], tuple) + assert data[1]["parameters"][PARAM_NAME] == (1, 5, 12, 13, 14)