Skip to content

Add simple JSON pipeline loader that handles range and manual parameter sweeps #583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions httomo/sweep_runner/param_sweep_json_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Dict, List, Optional, Tuple

import json
import numpy as np


class ParamSweepJsonLoader:
"""
Loader for JSON pipelines containing parameter sweep
"""

def __init__(self, json_string: str) -> None:
self.json_string = json_string

def load(self) -> List[Dict[str, Any]]:
"""
Convert JSON data to python dict
"""
data: List[Dict[str, Any]] = json.loads(self.json_string)
res = self._find_range_sweep_param(data[1:])
if res is not None:
sweep_dict = data[res[1] + 1]["parameters"][res[0]]
sweep_vals = tuple(
np.arange(sweep_dict["start"], sweep_dict["stop"], sweep_dict["step"])
)
data[res[1] + 1]["parameters"][res[0]] = sweep_vals

res = self._find_manual_sweep_param(data[1:])
if res is not None:
sweep_vals = data[res[1] + 1]["parameters"][res[0]]
data[res[1] + 1]["parameters"][res[0]] = tuple(sweep_vals)

return data

def _find_range_sweep_param(
self, methods: List[Dict[str, Any]]
) -> Optional[Tuple[str, int]]:
for idx, method in enumerate(methods):
for name, value in method["parameters"].items():
if isinstance(value, dict):
keys = value.keys()
has_keys_for_sweep = (
"start" in keys and "stop" in keys and "step" in keys
)
if has_keys_for_sweep and len(keys) == 3:
return name, idx

def _find_manual_sweep_param(
self, methods: List[Dict[str, Any]]
) -> Optional[Tuple[str, int]]:
for idx, method in enumerate(methods):
for name, value in method["parameters"].items():
if isinstance(value, list):
return name, idx
92 changes: 92 additions & 0 deletions tests/sweep_runner/test_param_sweep_json_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np

from httomo.sweep_runner.param_sweep_json_loader import ParamSweepJsonLoader


def test_load_range_sweep():
PARAM_NAME = "parameter_1"
JSON_STRING = """
[
{
"method": "standard_tomo",
"module_path": "httomo.data.hdf.loaders",
"parameters": {}
},
{
"method": "some_method",
"module_path": "some.module.path",
"parameters": {
"parameter_1": {
"start": 10,
"stop": 110,
"step": 5
}
}
}
]
"""
data = ParamSweepJsonLoader(JSON_STRING).load()
assert isinstance(data[1]["parameters"][PARAM_NAME], tuple)
assert data[1]["parameters"][PARAM_NAME] == tuple(np.arange(10, 110, 5))


def test_param_value_with_start_stop_step_and_other_keys_unaffected_by_range_sweep_parsing():
PARAM_NAME = "parameter_1"
JSON_STRING = """
[
{
"method": "standard_tomo",
"module_path": "httomo.data.hdf.loaders",
"parameters": {}
},
{
"method": "some_method",
"module_path": "some.module.path",
"parameters": {
"parameter_1": {
"start": 10,
"stop": 110,
"step": 5,
"another": 0
}
}
}
]
"""
data = ParamSweepJsonLoader(JSON_STRING).load()
assert isinstance(data[1]["parameters"][PARAM_NAME], dict)
assert data[1]["parameters"][PARAM_NAME] == {
"start": 10,
"stop": 110,
"step": 5,
"another": 0,
}


def test_load_manual_sweep():
PARAM_NAME = "parameter_1"
JSON_STRING = """
[
{
"method": "standard_tomo",
"module_path": "httomo.data.hdf.loaders",
"parameters": {}
},
{
"method": "some_method",
"module_path": "some.module.path",
"parameters": {
"parameter_1": [
1,
5,
12,
13,
14
]
}
}
]
"""
data = ParamSweepJsonLoader(JSON_STRING).load()
assert isinstance(data[1]["parameters"][PARAM_NAME], tuple)
assert data[1]["parameters"][PARAM_NAME] == (1, 5, 12, 13, 14)