diff --git a/tests/IVIMmodels/unit_tests/test_ivim_fit.py b/tests/IVIMmodels/unit_tests/test_ivim_fit.py
index fb58d13..f1357c4 100644
--- a/tests/IVIMmodels/unit_tests/test_ivim_fit.py
+++ b/tests/IVIMmodels/unit_tests/test_ivim_fit.py
@@ -29,6 +29,7 @@ def tolerances_helper(tolerances, data):
         tolerances["atol"] = tolerances.get("atol", {"f": 2e-1, "D": 5e-4, "Dp": 4e-2})
     return tolerances
 
+
 def data_ivim_fit_saved():
     # Find the algorithms from algorithms.json
     file = pathlib.Path(__file__)
@@ -45,8 +46,8 @@ def data_ivim_fit_saved():
     bvals = all_data.pop('config')
     bvals = bvals['bvalues']
     first=True
-    for name, data in all_data.items():
-        for algorithm in algorithms:
+    for algorithm in algorithms:
+        for name, data in all_data.items():
             algorithm_dict = algorithm_information.get(algorithm, {})
             xfail = {"xfail": name in algorithm_dict.get("xfail_names", {}),
                 "strict": algorithm_dict.get("xfail_names", {}).get(name, True)}
@@ -59,15 +60,38 @@ def data_ivim_fit_saved():
                 first = False
             yield name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime
 
+
+def make_hashable(obj):
+    if isinstance(obj, dict):
+        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
+    elif isinstance(obj, (list, tuple)):
+        return tuple(make_hashable(i) for i in obj)
+    else:
+        return obj
+
+
+@pytest.fixture(scope="module")
+def algorithm_cache():
+    cache = {}
+
+    def get_instance(algorithm, kwargs):
+        hashable_key = (algorithm, make_hashable(kwargs))
+        if hashable_key not in cache:
+            cache[hashable_key] = OsipiBase(algorithm=algorithm, **kwargs)
+        return cache[hashable_key]
+
+    return get_instance
+
+
 @pytest.mark.parametrize("name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime", data_ivim_fit_saved())
-def test_ivim_fit_saved(name, bvals, data, algorithm, xfail, kwargs, tolerances,skiptime, request, record_property):
+def test_ivim_fit_saved(name, bvals, data, algorithm, xfail, kwargs, tolerances, skiptime, request, record_property, algorithm_cache):
     if xfail["xfail"]:
         mark = pytest.mark.xfail(reason="xfail", strict=xfail["strict"])
         request.node.add_marker(mark)
     signal = signal_helper(data["data"])
     tolerances = tolerances_helper(tolerances, data)
+    fit = algorithm_cache(algorithm, kwargs)
     start_time = time.time()  # Record the start time
-    fit = OsipiBase(algorithm=algorithm, **kwargs)
     fit_result = fit.osipi_fit(signal, bvals)
     elapsed_time = time.time() - start_time  # Calculate elapsed time
     def to_list_if_needed(value):
@@ -153,10 +177,14 @@ def bound_input():
 
 
 @pytest.mark.parametrize("name, bvals, data, algorithm, xfail, kwargs, tolerances", bound_input())
-def test_bounds(name, bvals, data, algorithm, xfail, kwargs, tolerances, request):
+def test_bounds(name, bvals, data, algorithm, xfail, kwargs, tolerances, request, algorithm_cache):
+    if xfail["xfail"]:
+        mark = pytest.mark.xfail(reason="xfail", strict=xfail["strict"])
+        request.node.add_marker(mark)
     bounds = ([0.0008, 0.2, 0.01, 1.1], [0.0012, 0.3, 0.02, 1.3])
     # deliberately have silly bounds to see whether they are used
-    fit = OsipiBase(algorithm=algorithm, bounds=bounds, initial_guess = [0.001, 0.25, 0.015, 1.2], **kwargs)
+    extended_kwargs = {**kwargs, "bounds": bounds, "initial_guess": [0.001, 0.25, 0.015, 1.2]}
+    fit = algorithm_cache(algorithm, extended_kwargs)
     if fit.use_bounds:
         signal = signal_helper(data["data"])
         fit_result = fit.osipi_fit(signal, bvals)
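
Reviewer note: below is a minimal, standalone sketch of the caching idea the patch introduces, not part of the patch itself. It assumes a hypothetical FakeFitter class as a stand-in for OsipiBase, purely to make the snippet runnable in isolation. The point is that make_hashable recursively converts nested dict/list kwargs into tuples, so an (algorithm, kwargs) pair can serve as a dictionary key and equal configurations reuse one constructed instance; with scope="module" on the fixture, each distinct configuration is built once per test module instead of once per parametrized case.

# Standalone sketch (assumptions: FakeFitter is hypothetical, algorithm name is a placeholder).
def make_hashable(obj):
    # Recursively turn dicts and lists into tuples so kwargs become hashable keys.
    if isinstance(obj, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    elif isinstance(obj, (list, tuple)):
        return tuple(make_hashable(i) for i in obj)
    else:
        return obj


class FakeFitter:
    def __init__(self, algorithm, **kwargs):
        self.algorithm = algorithm
        self.kwargs = kwargs


cache = {}

def get_instance(algorithm, kwargs):
    # Equal (algorithm, kwargs) pairs map to the same cache entry.
    key = (algorithm, make_hashable(kwargs))
    if key not in cache:
        cache[key] = FakeFitter(algorithm, **kwargs)
    return cache[key]


a = get_instance("example_algorithm", {"bounds": ([0.0008, 0.2], [0.0012, 0.3])})
b = get_instance("example_algorithm", {"bounds": ([0.0008, 0.2], [0.0012, 0.3])})
assert a is b  # same configuration, same cached instance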