From 162f8fcc5d653e65791149163771cadd24f9ee4f Mon Sep 17 00:00:00 2001 From: NicolasGensollen Date: Thu, 8 Aug 2024 16:52:49 +0200 Subject: [PATCH 1/2] use antspy in pet-linear --- clinica/pipelines/pet/linear/cli.py | 7 + clinica/pipelines/pet/linear/pipeline.py | 261 +++++++++++-- .../t1_linear/anat_linear_pipeline.py | 145 +++---- .../pipelines/t1_linear/anat_linear_utils.py | 220 ----------- clinica/pipelines/t1_linear/tasks.py | 50 --- clinica/pipelines/tasks.py | 135 +++++++ clinica/pipelines/utils.py | 366 ++++++++++++++++++ 7 files changed, 804 insertions(+), 380 deletions(-) delete mode 100644 clinica/pipelines/t1_linear/tasks.py create mode 100644 clinica/pipelines/utils.py diff --git a/clinica/pipelines/pet/linear/cli.py b/clinica/pipelines/pet/linear/cli.py index cb7d8a7dd..25004fe6e 100644 --- a/clinica/pipelines/pet/linear/cli.py +++ b/clinica/pipelines/pet/linear/cli.py @@ -34,6 +34,11 @@ @cli_param.option.working_directory @option.global_option_group @option.n_procs +@cli_param.option.option( + "--use-antspy", + is_flag=True, + help="Use ANTsPy instead of ANTs.", +) def cli( bids_directory: str, caps_directory: str, @@ -46,6 +51,7 @@ def cli( subjects_sessions_tsv: Optional[str] = None, working_directory: Optional[str] = None, n_procs: Optional[int] = None, + use_antspy: bool = False, ) -> None: """Affine registration of PET images to the MNI standard space. @@ -81,6 +87,7 @@ def cli( base_dir=working_directory, parameters=parameters, name=pipeline_name, + use_antspy=use_antspy, ) exec_pipeline = ( diff --git a/clinica/pipelines/pet/linear/pipeline.py b/clinica/pipelines/pet/linear/pipeline.py index 2013b4e1b..55cf813bf 100644 --- a/clinica/pipelines/pet/linear/pipeline.py +++ b/clinica/pipelines/pet/linear/pipeline.py @@ -1,7 +1,8 @@ # Use hash instead of parameters for iterables folder names # Otherwise path will be too long and generate OSError -from typing import List +from typing import List, Optional +import nipype.pipeline.engine as npe from nipype import config from clinica.pipelines.pet.engine import PETPipeline @@ -26,6 +27,44 @@ class PETLinear(PETPipeline): A clinica pipeline object containing the pet_linear pipeline. """ + def __init__( + self, + bids_directory: Optional[str] = None, + caps_directory: Optional[str] = None, + tsv_file: Optional[str] = None, + overwrite_caps: Optional[bool] = False, + base_dir: Optional[str] = None, + parameters: Optional[dict] = None, + name: Optional[str] = None, + ignore_dependencies: Optional[List[str]] = None, + use_antspy: bool = False, + ): + from clinica.utils.stream import cprint + + super().__init__( + bids_directory=bids_directory, + caps_directory=caps_directory, + tsv_file=tsv_file, + overwrite_caps=overwrite_caps, + base_dir=base_dir, + parameters=parameters, + ignore_dependencies=ignore_dependencies, + name=name, + ) + self.use_antspy = use_antspy + if self.use_antspy: + self._ignore_dependencies.append("ants") + cprint( + ( + "The PETLinear pipeline has been configured to use ANTsPy instead of ANTs.\n" + "This means that no installation of ANTs is required, but the antspyx Python " + "package must be installed in your environment.\nThis functionality has been " + "introduced in Clinica 0.9.0 and is considered experimental.\n" + "Please report any issue or unexpected results to the Clinica developer team." + ), + lvl="warning", + ) + def _check_custom_dependencies(self) -> None: """Check dependencies that can not be listed in the `info.json` file.""" pass @@ -263,10 +302,13 @@ def _build_output_node(self): def _build_core_nodes(self): """Build and connect the core nodes of the pipeline.""" import nipype.interfaces.utility as nutil - import nipype.pipeline.engine as npe from nipype.interfaces import ants - from clinica.pipelines.tasks import crop_nifti_task + from clinica.pipelines.tasks import ( + crop_nifti_task, + get_filename_no_ext_task, + run_ants_apply_transforms_task, + ) from .tasks import ( clip_task, @@ -282,6 +324,14 @@ def _build_core_nodes(self): ), name="initPipeline", ) + # image_id_node = npe.Node( + # interface=nutil.Function( + # input_names=["filename"], + # output_names=["image_id"], + # function=get_filename_no_ext_task, + # ), + # name="ImageID", + # ) concatenate_node = npe.Node( interface=nutil.Function( input_names=["pet_to_t1w_transform", "t1w_to_mni_transform"], @@ -291,8 +341,6 @@ def _build_core_nodes(self): name="concatenateTransforms", ) - # The core (processing) nodes - # 1. Clipping node clipping_node = npe.Node( name="clipping", @@ -304,52 +352,53 @@ def _build_core_nodes(self): ) clipping_node.inputs.output_dir = self.base_dir - # 2. `RegistrationSynQuick` by *ANTS*. It uses nipype interface. - ants_registration_node = npe.Node( - name="antsRegistration", interface=ants.RegistrationSynQuick() - ) - ants_registration_node.inputs.dimension = 3 - ants_registration_node.inputs.transform_type = "r" - - # 3. `ApplyTransforms` by *ANTS*. It uses nipype interface. PET to MRI + ants_registration_node = self._build_ants_registration_node() ants_applytransform_node = npe.Node( - name="antsApplyTransformPET2MNI", interface=ants.ApplyTransforms() + name="antsApplyTransformPET2MNI", + interface=( + nutil.Function( + function=run_ants_apply_transforms_task, + input_names=[ + "reference_image", + "input_image", + "transforms", + "output_dir", + ], + output_names=["output_image"], + ) + if self.use_antspy + else ants.ApplyTransforms() + ), ) - ants_applytransform_node.inputs.dimension = 3 + if not self.use_antspy: + ants_applytransform_node.inputs.dimension = 3 ants_applytransform_node.inputs.reference_image = self.ref_template - # 4. Normalize the image (using nifti). It uses custom interface, from utils file - ants_registration_nonlinear_node = npe.Node( - name="antsRegistrationT1W2MNI", interface=ants.Registration() + ants_registration_nonlinear_node = ( + self._build_ants_registration_nonlinear_node() ) - ants_registration_nonlinear_node.inputs.fixed_image = self.ref_template - ants_registration_nonlinear_node.inputs.metric = ["MI"] - ants_registration_nonlinear_node.inputs.metric_weight = [1.0] - ants_registration_nonlinear_node.inputs.transforms = ["SyN"] - ants_registration_nonlinear_node.inputs.transform_parameters = [(0.1, 3, 0)] - ants_registration_nonlinear_node.inputs.dimension = 3 - ants_registration_nonlinear_node.inputs.shrink_factors = [[8, 4, 2]] - ants_registration_nonlinear_node.inputs.smoothing_sigmas = [[3, 2, 1]] - ants_registration_nonlinear_node.inputs.sigma_units = ["vox"] - ants_registration_nonlinear_node.inputs.number_of_iterations = [[200, 50, 10]] - ants_registration_nonlinear_node.inputs.convergence_threshold = [1e-05] - ants_registration_nonlinear_node.inputs.convergence_window_size = [10] - ants_registration_nonlinear_node.inputs.radius_or_number_of_bins = [32] - ants_registration_nonlinear_node.inputs.winsorize_lower_quantile = 0.005 - ants_registration_nonlinear_node.inputs.winsorize_upper_quantile = 0.995 - ants_registration_nonlinear_node.inputs.collapse_output_transforms = True - ants_registration_nonlinear_node.inputs.use_histogram_matching = False - ants_registration_nonlinear_node.inputs.verbose = True ants_applytransform_nonlinear_node = npe.Node( - name="antsApplyTransformNonLinear", interface=ants.ApplyTransforms() + name="antsApplyTransformNonLinear", + interface=( + nutil.Function( + function=run_ants_apply_transforms_task, + input_names=[ + "reference_image", + "input_image", + "transforms", + "output_dir", + ], + output_names=["output_image"], + ) + if self.use_antspy + else ants.ApplyTransforms() + ), ) - ants_applytransform_nonlinear_node.inputs.dimension = 3 + if not self.use_antspy: + ants_applytransform_nonlinear_node.inputs.dimension = 3 ants_applytransform_nonlinear_node.inputs.reference_image = self.ref_template - if random_seed := self.parameters.get("random_seed", None): - ants_registration_nonlinear_node.inputs.random_seed = random_seed - normalize_intensity_node = npe.Node( name="intensityNormalization", interface=nutil.Function( @@ -385,9 +434,24 @@ def _build_core_nodes(self): # 7. Optional node: compute PET image in T1w ants_applytransform_optional_node = npe.Node( - name="antsApplyTransformPET2T1w", interface=ants.ApplyTransforms() + name="antsApplyTransformPET2T1w", + interface=( + nutil.Function( + function=run_ants_apply_transforms_task, + input_names=[ + "reference_image", + "input_image", + "transforms", + "output_dir", + ], + output_names=["output_image"], + ) + if self.use_antspy + else ants.ApplyTransforms() + ), ) - ants_applytransform_optional_node.inputs.dimension = 3 + if not self.use_antspy: + ants_applytransform_optional_node.inputs.dimension = 3 self.connect( [ @@ -514,3 +578,114 @@ def _build_core_nodes(self): ), ] ) + + def _build_ants_registration_node(self) -> npe.Node: + import nipype.interfaces.utility as nutil + from nipype.interfaces import ants + + from clinica.pipelines.tasks import run_ants_registration_synquick_task + from clinica.pipelines.utils import AntsRegistrationSynQuickTransformType + + ants_registration_node = npe.Node( + name="antsRegistration", + interface=( + nutil.Function( + function=run_ants_registration_synquick_task, + input_names=[ + "fixed_image", + "moving_image", + "random_seed", + "transform_type", + "output_prefix", + "output_dir", + ], + output_names=["warped_image", "out_matrix"], + ) + if self.use_antspy + else ants.RegistrationSynQuick() + ), + ) + ants_registration_node.inputs.fixed_image = self.ref_template + if self.use_antspy: + ants_registration_node.inputs.output_dir = str(self.base_dir) + ants_registration_node.inputs.transform_type = ( + AntsRegistrationSynQuickTransformType.RIGID + ) + else: + ants_registration_node.inputs.transform_type = "r" + ants_registration_node.inputs.dimension = 3 + ants_registration_node.inputs.random_seed = ( + self.parameters.get("random_seed", None) or 0 + ) + + return ants_registration_node + + def _build_ants_registration_nonlinear_node(self) -> npe.Node: + import nipype.interfaces.utility as nutil + from nipype.interfaces import ants + + from clinica.pipelines.tasks import run_ants_registration_task + from clinica.pipelines.utils import AntsRegistrationTransformType + + ants_registration_nonlinear_node = npe.Node( + name="antsRegistrationT1W2MNI", + interface=( + nutil.Function( + function=run_ants_registration_task, + input_names=[ + "fixed_image", + "moving_image", + "random_seed", + "transform_type", + "output_prefix", + "output_dir", + "shrink_factors", + "smoothing_sigmas", + "number_of_iterations", + "return_inverse_transform", + ], + output_names=[ + "warped_image", + "out_matrix", + "reverse_forward_transforms", + ], + ) + if self.use_antspy + else ants.Registration() + ), + ) + ants_registration_nonlinear_node.inputs.fixed_image = self.ref_template + if self.use_antspy: + ants_registration_nonlinear_node.inputs.transform_type = ( + AntsRegistrationTransformType.SYN + ) + ants_registration_nonlinear_node.inputs.shrink_factors = (8, 4, 2) + ants_registration_nonlinear_node.inputs.smoothing_sigmas = (3, 2, 1) + ants_registration_nonlinear_node.inputs.number_of_iterations = (200, 50, 10) + ants_registration_nonlinear_node.inputs.return_inverse_transform = True + else: + ants_registration_nonlinear_node.inputs.metric = ["MI"] + ants_registration_nonlinear_node.inputs.metric_weight = [1.0] + ants_registration_nonlinear_node.inputs.transforms = ["SyN"] + ants_registration_nonlinear_node.inputs.dimension = 3 + ants_registration_nonlinear_node.inputs.shrink_factors = [[8, 4, 2]] + ants_registration_nonlinear_node.inputs.smoothing_sigmas = [[3, 2, 1]] + ants_registration_nonlinear_node.inputs.sigma_units = ["vox"] + ants_registration_nonlinear_node.inputs.number_of_iterations = [ + [200, 50, 10] + ] + ants_registration_nonlinear_node.inputs.radius_or_number_of_bins = [32] + + ants_registration_nonlinear_node.inputs.transform_parameters = [(0.1, 3, 0)] + ants_registration_nonlinear_node.inputs.convergence_threshold = [1e-05] + ants_registration_nonlinear_node.inputs.convergence_window_size = [10] + ants_registration_nonlinear_node.inputs.winsorize_lower_quantile = 0.005 + ants_registration_nonlinear_node.inputs.winsorize_upper_quantile = 0.995 + ants_registration_nonlinear_node.inputs.collapse_output_transforms = True + ants_registration_nonlinear_node.inputs.use_histogram_matching = False + ants_registration_nonlinear_node.inputs.verbose = True + ants_registration_nonlinear_node.inputs.random_seed = ( + self.parameters.get("random_seed", None) or 0 + ) + + return ants_registration_nonlinear_node diff --git a/clinica/pipelines/t1_linear/anat_linear_pipeline.py b/clinica/pipelines/t1_linear/anat_linear_pipeline.py index daa0c714d..8fed68090 100644 --- a/clinica/pipelines/t1_linear/anat_linear_pipeline.py +++ b/clinica/pipelines/t1_linear/anat_linear_pipeline.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import List, Optional +import nipype.pipeline.engine as npe from nipype import config from clinica.pipelines.engine import Pipeline @@ -254,13 +255,7 @@ def _build_output_node(self): def _build_core_nodes(self): """Build and connect the core nodes of the pipeline.""" import nipype.interfaces.utility as nutil - import nipype.pipeline.engine as npe - from nipype.interfaces import ants - from clinica.pipelines.t1_linear.tasks import ( - run_ants_registration_task, - run_n4biasfieldcorrection_task, - ) from clinica.pipelines.tasks import crop_nifti_task, get_filename_no_ext_task from .anat_linear_utils import print_end_pipeline @@ -273,65 +268,8 @@ def _build_core_nodes(self): ), name="ImageID", ) - - # 1. N4biascorrection by ANTS. It uses nipype interface. - n4biascorrection = npe.Node( - name="n4biascorrection", - interface=( - nutil.Function( - function=run_n4biasfieldcorrection_task, - input_names=[ - "input_image", - "bspline_fitting_distance", - "output_prefix", - "output_dir", - "save_bias", - "verbose", - ], - output_names=["output_image"], - ) - if self.use_antspy - else ants.N4BiasFieldCorrection(dimension=3) - ), - ) - n4biascorrection.inputs.save_bias = True - if self.use_antspy: - n4biascorrection.inputs.output_dir = str(self.base_dir) - n4biascorrection.inputs.verbose = True - if self.name == "t1-linear": - n4biascorrection.inputs.bspline_fitting_distance = 600 - else: - n4biascorrection.inputs.bspline_fitting_distance = 100 - - # 2. `RegistrationSynQuick` by *ANTS*. It uses nipype interface. - ants_registration_node = npe.Node( - name="antsRegistrationSynQuick", - interface=( - nutil.Function( - function=run_ants_registration_task, - input_names=[ - "fixed_image", - "moving_image", - "random_seed", - "output_prefix", - "output_dir", - ], - output_names=["warped_image", "out_matrix"], - ) - if self.use_antspy - else ants.RegistrationSynQuick() - ), - ) - ants_registration_node.inputs.fixed_image = self.ref_template - if not self.use_antspy: - ants_registration_node.inputs.transform_type = "a" - ants_registration_node.inputs.dimension = 3 - - random_seed = self.parameters.get("random_seed", None) - ants_registration_node.inputs.random_seed = random_seed or 0 - - # 3. Crop image (using nifti). It uses custom interface, from utils file - + n4biascorrection = self._build_n4biascorrection_node() + ants_registration_node = self._build_ants_registration_node() cropnifti = npe.Node( name="cropnifti", interface=nutil.Function( @@ -341,8 +279,6 @@ def _build_core_nodes(self): ), ) cropnifti.inputs.output_path = self.base_dir - - # 4. Print end message print_end_message = npe.Node( interface=nutil.Function( input_names=["anat", "final_file"], function=print_end_pipeline @@ -410,3 +346,78 @@ def _build_core_nodes(self): ), ] ) + + def _build_n4biascorrection_node(self) -> npe.Node: + import nipype.interfaces.utility as nutil + from nipype.interfaces import ants + + from clinica.pipelines.tasks import run_n4biasfieldcorrection_task + + n4biascorrection = npe.Node( + name="n4biascorrection", + interface=( + nutil.Function( + function=run_n4biasfieldcorrection_task, + input_names=[ + "input_image", + "bspline_fitting_distance", + "output_prefix", + "output_dir", + "save_bias", + "verbose", + ], + output_names=["output_image"], + ) + if self.use_antspy + else ants.N4BiasFieldCorrection(dimension=3) + ), + ) + n4biascorrection.inputs.save_bias = True + if self.use_antspy: + n4biascorrection.inputs.output_dir = str(self.base_dir) + n4biascorrection.inputs.verbose = True + n4biascorrection.inputs.bspline_fitting_distance = ( + 600 if self.name == "t1-linear" else 100 + ) + + return n4biascorrection + + def _build_ants_registration_node(self) -> npe.Node: + import nipype.interfaces.utility as nutil + from nipype.interfaces import ants + + from clinica.pipelines.tasks import run_ants_registration_synquick_task + from clinica.pipelines.utils import AntsRegistrationSynQuickTransformType + + ants_registration_node = npe.Node( + name="antsRegistrationSynQuick", + interface=( + nutil.Function( + function=run_ants_registration_synquick_task, + input_names=[ + "fixed_image", + "moving_image", + "random_seed", + "transform_type", + "output_prefix", + "output_dir", + ], + output_names=["warped_image", "out_matrix"], + ) + if self.use_antspy + else ants.RegistrationSynQuick() + ), + ) + ants_registration_node.inputs.fixed_image = self.ref_template + if self.use_antspy: + ants_registration_node.inputs.output_dir = str(self.base_dir) + ants_registration_node.inputs.transform_type = ( + AntsRegistrationSynQuickTransformType.AFFINE + ) + else: + ants_registration_node.inputs.dimension = 3 + ants_registration_node.inputs.transform_type = "a" + random_seed = self.parameters.get("random_seed", None) + ants_registration_node.inputs.random_seed = random_seed or 0 + + return ants_registration_node diff --git a/clinica/pipelines/t1_linear/anat_linear_utils.py b/clinica/pipelines/t1_linear/anat_linear_utils.py index 0d2a35046..3248e558c 100644 --- a/clinica/pipelines/t1_linear/anat_linear_utils.py +++ b/clinica/pipelines/t1_linear/anat_linear_utils.py @@ -1,7 +1,3 @@ -from pathlib import Path -from typing import Optional, Tuple - - def get_substitutions_datasink_flair(bids_image_id: str) -> list: from clinica.pipelines.t1_linear.anat_linear_utils import ( # noqa _get_substitutions_datasink, @@ -64,219 +60,3 @@ def print_end_pipeline(anat, final_file): from clinica.utils.ux import print_end_image print_end_image(get_subject_id(anat)) - - -def run_n4biasfieldcorrection( - input_image: Path, - bspline_fitting_distance: int, - output_prefix: Optional[str] = None, - output_dir: Optional[Path] = None, - save_bias: bool = False, - verbose: bool = False, -) -> Path: - """Run n4biasfieldcorrection using antsPy. - - Parameters - ---------- - input_image : Path - The path to the input image. - - bspline_fitting_distance : int - This is the 'spline_param' of n4biasfieldcorrection. - - output_prefix : str, optional - The prefix to be put at the beginning of the output file names. - Ex: 'sub-XXX_ses-MYYY'. - - output_dir : Path, optional - The directory in which to write the output files. - If not provided, these files will be written in the current directory. - - save_bias : bool, optional - Whether to save the bias image or not. - If set to True, the bias image is not returned but saved in the - provided output_dir with a name of the form '{output_prefix}_bias_image.nii.gz'. - Default=False. - - verbose : bool, optional - Control the verbose mode of n4biasfieldcorrection. Set to True can be - useful for debugging. - Default=False. - - Returns - ------- - bias_corrected_output_path : Path - The path to the bias corrected image. - """ - from clinica.utils.exceptions import ClinicaMissingDependencyError - from clinica.utils.stream import cprint, log_and_raise - - try: - import ants - except ImportError: - log_and_raise( - "The package 'antsPy' is required to run antsRegistration in Python.", - ClinicaMissingDependencyError, - ) - - output_prefix = output_prefix or "" - bias_corrected_image = _call_n4_bias_field_correction( - input_image, bspline_fitting_distance, save_bias=False, verbose=verbose - ) - if save_bias: - bias_image = _call_n4_bias_field_correction( - input_image, - bspline_fitting_distance, - save_bias=True, - verbose=verbose, - ) - bias_output_path = ( - output_dir or Path.cwd() - ) / f"{output_prefix}_bias_image.nii.gz" - ants.image_write(bias_image, str(bias_output_path)) - cprint(f"Writing bias image to {bias_output_path}.", lvl="debug") - bias_corrected_output_path = ( - output_dir or Path.cwd() - ) / f"{output_prefix}_bias_corrected_image.nii.gz" - cprint( - f"Writing bias corrected image to {bias_corrected_output_path}.", lvl="debug" - ) - ants.image_write(bias_corrected_image, str(bias_corrected_output_path)) - - return bias_corrected_output_path - - -def _call_n4_bias_field_correction( - input_image: Path, - bspline_fitting_distance: int, - save_bias: bool = False, - verbose: bool = False, -) -> Path: - import ants - from ants.utils.bias_correction import n4_bias_field_correction - - return n4_bias_field_correction( - ants.image_read(str(input_image)), - spline_param=bspline_fitting_distance, - return_bias_field=save_bias, - verbose=verbose, - ) - - -def run_ants_registration( - fixed_image: Path, - moving_image: Path, - random_seed: int, - output_prefix: Optional[str] = None, - output_dir: Optional[Path] = None, - verbose: bool = False, -) -> Tuple[Path, Path]: - """Run antsRegistration using antsPy. - - Parameters - ---------- - fixed_image : Path - The path to the fixed image. - - moving_image : Path - The path to the moving image. - - random_seed : int - The random seed to be used. - - output_prefix : str, optional - The prefix to be put at the beginning of the output file names. - Ex: 'sub-XXX_ses-MYYY'. - - output_dir : Path, optional - The directory in which to write the output files. - If not provided, these files will be written in the current directory. - - verbose : bool, optional - Control the verbose mode of antsRegistration. Set to True can be - useful for debugging. - Default=False. - - Returns - ------- - warped_image_output_path : Path - The path to the warped nifti image generated by antsRegistration. - - transformation_matrix_output_path : Path - The path to the transforms to move from moving to fixed image. - This is a .mat file. - - Raises - ------ - RuntimeError : - If results cannot be extracted. - """ - from clinica.utils.stream import log_and_raise - - registration_results = _call_ants_registration( - fixed_image, moving_image, random_seed, verbose=verbose - ) - try: - warped_image = registration_results["warpedmovout"] - transformation_matrix = registration_results["fwdtransforms"][-1] - except (KeyError, IndexError): - msg = ( - "Something went wrong when calling antsRegistration with the following parameters :\n" - f"- fixed_image = {fixed_image}\n- moving_image = {moving_image}\n" - f"- random_seed = {random_seed}\n- type_of_transformation='antsRegistrationSyN[a]'\n" - ) - log_and_raise(msg, RuntimeError) - - return _write_ants_registration_results( - warped_image, transformation_matrix, output_prefix or "", output_dir - ) - - -def _call_ants_registration( - fixed_image: Path, - moving_image: Path, - random_seed: int, - verbose: bool = False, -) -> dict: - from clinica.utils.exceptions import ClinicaMissingDependencyError - from clinica.utils.stream import log_and_raise - - try: - import ants - except ImportError: - log_and_raise( - "The package 'antsPy' is required to run antsRegistration in Python.", - ClinicaMissingDependencyError, - ) - return ants.registration( - ants.image_read(str(fixed_image)), - ants.image_read(str(moving_image)), - type_of_transformation="antsRegistrationSyN[a]", - random_seed=random_seed, - verbose=verbose, - ) - - -def _write_ants_registration_results( - warped_image, - transformation_matrix, - output_prefix: str, - output_dir: Optional[Path] = None, -) -> Tuple[Path, Path]: - import shutil - - import ants - - from clinica.utils.stream import cprint - - warped_image_output_path = ( - output_dir or Path.cwd() - ) / f"{output_prefix}Warped.nii.gz" - transformation_matrix_output_path = ( - output_dir or Path.cwd() - ) / f"{output_prefix}0GenericAffine.mat" - cprint(f"Writing warped image to {warped_image_output_path}.", lvl="debug") - ants.image_write(warped_image, str(warped_image_output_path)) - shutil.copy(transformation_matrix, transformation_matrix_output_path) - - return warped_image_output_path, transformation_matrix_output_path diff --git a/clinica/pipelines/t1_linear/tasks.py b/clinica/pipelines/t1_linear/tasks.py deleted file mode 100644 index 73081e263..000000000 --- a/clinica/pipelines/t1_linear/tasks.py +++ /dev/null @@ -1,50 +0,0 @@ -def run_n4biasfieldcorrection_task( - input_image: str, - bspline_fitting_distance: int, - output_prefix=None, - output_dir=None, - save_bias=False, - verbose=False, -) -> str: - from pathlib import Path - - from clinica.pipelines.t1_linear.anat_linear_utils import run_n4biasfieldcorrection - - if output_dir: - output_dir = Path(output_dir) - - return str( - run_n4biasfieldcorrection( - Path(input_image), - bspline_fitting_distance, - output_prefix, - output_dir, - save_bias, - verbose, - ) - ) - - -def run_ants_registration_task( - fixed_image: str, - moving_image: str, - random_seed: int, - output_prefix=None, - output_dir=None, -) -> tuple: - from pathlib import Path - - from clinica.pipelines.t1_linear.anat_linear_utils import run_ants_registration - - if output_dir: - output_dir = Path(output_dir) - - warped_image_output_path, transformation_matrix_output_path = run_ants_registration( - Path(fixed_image), - Path(moving_image), - random_seed, - output_prefix, - output_dir, - ) - - return str(warped_image_output_path), str(transformation_matrix_output_path) diff --git a/clinica/pipelines/tasks.py b/clinica/pipelines/tasks.py index 25827a64d..8e9298ccc 100644 --- a/clinica/pipelines/tasks.py +++ b/clinica/pipelines/tasks.py @@ -15,3 +15,138 @@ def get_filename_no_ext_task(filename: str) -> str: from clinica.utils.filemanip import get_filename_no_ext return get_filename_no_ext(Path(filename)) + + +def run_n4biasfieldcorrection_task( + input_image: str, + bspline_fitting_distance: int, + output_prefix=None, + output_dir=None, + save_bias=False, + verbose=False, +) -> str: + from pathlib import Path + + from clinica.pipelines.utils import run_n4biasfieldcorrection + + if output_dir: + output_dir = Path(output_dir) + + return str( + run_n4biasfieldcorrection( + Path(input_image), + bspline_fitting_distance, + output_prefix, + output_dir, + save_bias, + verbose, + ) + ) + + +def run_ants_registration_synquick_task( + fixed_image: str, + moving_image: str, + random_seed: int, + transform_type: str, + output_prefix=None, + output_dir=None, + verbose: bool = False, + return_inverse_transform: bool = False, +) -> tuple: + from pathlib import Path + + from clinica.pipelines.utils import run_ants_registration_synquick + + if output_dir: + output_dir = Path(output_dir) + + ( + warped_image_output_path, + transformation_matrix_output_path, + transformation_matrix_inverse_output_path, + ) = run_ants_registration_synquick( + Path(fixed_image), + Path(moving_image), + random_seed, + transform_type, + output_prefix, + output_dir, + verbose=verbose, + ) + if return_inverse_transform: + return ( + str(warped_image_output_path), + str(transformation_matrix_output_path), + str(transformation_matrix_inverse_output_path), + ) + return str(warped_image_output_path), str(transformation_matrix_output_path) + + +def run_ants_registration_task( + fixed_image: str, + moving_image: str, + random_seed: int, + transform_type: str, + output_prefix=None, + output_dir=None, + verbose: bool = False, + shrink_factors=None, + smoothing_sigmas=None, + number_of_iterations=None, + return_inverse_transform: bool = False, +) -> tuple: + from pathlib import Path + + from clinica.pipelines.utils import run_ants_registration + + if output_dir: + output_dir = Path(output_dir) + + ( + warped_image_output_path, + transformation_matrix_output_path, + transformation_matrix_inverse_output_path, + ) = run_ants_registration( + Path(fixed_image), + Path(moving_image), + random_seed, + transform_type, + output_prefix, + output_dir, + verbose=verbose, + shrink_factors=shrink_factors, + smoothing_sigmas=smoothing_sigmas, + number_of_iterations=number_of_iterations, + ) + + if return_inverse_transform: + return ( + str(warped_image_output_path), + str(transformation_matrix_output_path), + str(transformation_matrix_inverse_output_path), + ) + return str(warped_image_output_path), str(transformation_matrix_output_path) + + +def run_ants_apply_transforms_task( + reference_image: str, + input_image: str, + transforms: list, + output_dir=None, +): + from pathlib import Path + + from clinica.pipelines.utils import run_ants_apply_transforms + + if output_dir: + output_dir = Path(output_dir) + + return str( + run_ants_apply_transforms( + Path(reference_image), + Path(input_image), + transforms, + output_dir=output_dir, + ) + ) diff --git a/clinica/pipelines/utils.py b/clinica/pipelines/utils.py new file mode 100644 index 000000000..f22de9a1d --- /dev/null +++ b/clinica/pipelines/utils.py @@ -0,0 +1,366 @@ +from enum import Enum +from pathlib import Path +from typing import List, Optional, Tuple, Union + +__all__ = [ + "AntsRegistrationTransformType", + "AntsRegistrationSynQuickTransformType", + "run_n4biasfieldcorrection", + "run_ants_registration", + "run_ants_registration_synquick", + "run_ants_apply_transforms", +] + + +class AntsRegistrationSynQuickTransformType(str, Enum): + """The possible values for the transform type of AntsRegistrationSynQuick.""" + + TRANSLATION = "antsRegistrationSyN[t]" + RIGID = "antsRegistrationSyN[r]" + SIMILARITY = "antsRegistrationSyN[s]" + AFFINE = "antsRegistrationSyN[a]" + + +class AntsRegistrationTransformType(str, Enum): + """The possible values for the transform type of AntsRegistration.""" + + TRANSLATION = "Translation" + RIGID = "Rigid" + SIMILARITY = "Similarity" + AFFINE = "Affine" + SYN = "SyN" + + +def run_n4biasfieldcorrection( + input_image: Path, + bspline_fitting_distance: int, + output_prefix: Optional[str] = None, + output_dir: Optional[Path] = None, + save_bias: bool = False, + verbose: bool = False, +) -> Path: + """Run n4biasfieldcorrection using antsPy. + + Parameters + ---------- + input_image : Path + The path to the input image. + + bspline_fitting_distance : int + This is the 'spline_param' of n4biasfieldcorrection. + + output_prefix : str, optional + The prefix to be put at the beginning of the output file names. + Ex: 'sub-XXX_ses-MYYY'. + + output_dir : Path, optional + The directory in which to write the output files. + If not provided, these files will be written in the current directory. + + save_bias : bool, optional + Whether to save the bias image or not. + If set to True, the bias image is not returned but saved in the + provided output_dir with a name of the form '{output_prefix}_bias_image.nii.gz'. + Default=False. + + verbose : bool, optional + Control the verbose mode of n4biasfieldcorrection. Set to True can be + useful for debugging. + Default=False. + + Returns + ------- + bias_corrected_output_path : Path + The path to the bias corrected image. + """ + from clinica.utils.exceptions import ClinicaMissingDependencyError + from clinica.utils.stream import cprint, log_and_raise + + try: + import ants + except ImportError: + log_and_raise( + "The package 'antsPy' is required to run antsRegistration in Python.", + ClinicaMissingDependencyError, + ) + + output_prefix = output_prefix or "" + bias_corrected_image = _call_n4_bias_field_correction( + input_image, bspline_fitting_distance, save_bias=False, verbose=verbose + ) + if save_bias: + bias_image = _call_n4_bias_field_correction( + input_image, + bspline_fitting_distance, + save_bias=True, + verbose=verbose, + ) + bias_output_path = ( + output_dir or Path.cwd() + ) / f"{output_prefix}_bias_image.nii.gz" + ants.image_write(bias_image, str(bias_output_path)) + cprint(f"Writing bias image to {bias_output_path}.", lvl="debug") + bias_corrected_output_path = ( + output_dir or Path.cwd() + ) / f"{output_prefix}_bias_corrected_image.nii.gz" + cprint( + f"Writing bias corrected image to {bias_corrected_output_path}.", lvl="debug" + ) + ants.image_write(bias_corrected_image, str(bias_corrected_output_path)) + + return bias_corrected_output_path + + +def _call_n4_bias_field_correction( + input_image: Path, + bspline_fitting_distance: int, + save_bias: bool = False, + verbose: bool = False, +) -> Path: + import ants + from ants.utils.bias_correction import n4_bias_field_correction + + return n4_bias_field_correction( + ants.image_read(str(input_image)), + spline_param=bspline_fitting_distance, + return_bias_field=save_bias, + verbose=verbose, + ) + + +def run_ants_registration_synquick( + fixed_image: Path, + moving_image: Path, + random_seed: int, + transform_type: Union[str, AntsRegistrationSynQuickTransformType], + output_prefix: Optional[str] = None, + output_dir: Optional[Path] = None, + verbose: bool = False, +) -> Tuple[Path, Path, Path]: + transform_type = AntsRegistrationSynQuickTransformType(transform_type) + return _run_ants_registration( + fixed_image, + moving_image, + random_seed, + transform_type, + output_prefix, + output_dir, + verbose, + ) + + +def run_ants_registration( + fixed_image: Path, + moving_image: Path, + random_seed: int, + transform_type: Union[str, AntsRegistrationTransformType], + output_prefix: Optional[str] = None, + output_dir: Optional[Path] = None, + verbose: bool = False, + shrink_factors: Optional[Tuple[int, ...]] = None, + smoothing_sigmas: Optional[Tuple[int, ...]] = None, + number_of_iterations: Optional[Tuple[int, ...]] = None, +) -> Tuple[Path, Path, Path]: + transform_type = AntsRegistrationTransformType(transform_type) + return _run_ants_registration( + fixed_image, + moving_image, + random_seed, + transform_type, + output_prefix, + output_dir, + verbose, + shrink_factors=shrink_factors, + smoothing_sigmas=smoothing_sigmas, + number_of_iterations=number_of_iterations, + ) + + +def _run_ants_registration( + fixed_image: Path, + moving_image: Path, + random_seed: int, + transform_type: Union[ + AntsRegistrationTransformType, AntsRegistrationSynQuickTransformType + ], + output_prefix: Optional[str] = None, + output_dir: Optional[Path] = None, + verbose: bool = False, + shrink_factors: Optional[Tuple[int, ...]] = None, + smoothing_sigmas: Optional[Tuple[int, ...]] = None, + number_of_iterations: Optional[Tuple[int, ...]] = None, +) -> Tuple[Path, Path, Path]: + """Run antsRegistration using antsPy. + + Parameters + ---------- + fixed_image : Path + The path to the fixed image. + + moving_image : Path + The path to the moving image. + + random_seed : int + The random seed to be used. + + transform_type : AntsRegistrationTransformType or AntsRegistrationSynQuickTransformType + The type of transformation to be applied. + + output_prefix : str, optional + The prefix to be put at the beginning of the output file names. + Ex: 'sub-XXX_ses-MYYY'. + + output_dir : Path, optional + The directory in which to write the output files. + If not provided, these files will be written in the current directory. + + verbose : bool, optional + Control the verbose mode of antsRegistration. Set to True can be + useful for debugging. + Default=False. + + Returns + ------- + warped_image_output_path : Path + The path to the warped nifti image generated by antsRegistration. + + transformation_matrix_output_path : Path + The path to the transforms to move from moving to fixed image. + This is a .mat file. + + Raises + ------ + RuntimeError : + If results cannot be extracted. + """ + from clinica.utils.stream import log_and_raise + + registration_results = _call_ants_registration( + fixed_image, + moving_image, + random_seed, + transform_type, + verbose=verbose, + shrink_factors=shrink_factors, + smoothing_sigmas=smoothing_sigmas, + number_of_iterations=number_of_iterations, + ) + try: + warped_image = registration_results["warpedmovout"] + transformation_matrix = registration_results["fwdtransforms"][-1] + transformation_matrix_inverse = registration_results["invtransforms"][0] + except (KeyError, IndexError): + msg = ( + "Something went wrong when calling antsRegistration with the following parameters :\n" + f"- fixed_image = {fixed_image}\n- moving_image = {moving_image}\n" + f"- random_seed = {random_seed}\n- type_of_transformation='{transform_type.value}'\n" + ) + log_and_raise(msg, RuntimeError) + + return _write_ants_registration_results( + warped_image, + transformation_matrix, + transformation_matrix_inverse, + output_prefix or "", + output_dir, + ) + + +def _call_ants_registration( + fixed_image: Path, + moving_image: Path, + random_seed: int, + transform_type: Union[ + AntsRegistrationTransformType, AntsRegistrationSynQuickTransformType + ], + verbose: bool = False, + shrink_factors: Optional[Tuple[int, ...]] = None, + smoothing_sigmas: Optional[Tuple[int, ...]] = None, + number_of_iterations: Optional[Tuple[int, ...]] = None, +) -> dict: + from clinica.utils.exceptions import ClinicaMissingDependencyError + from clinica.utils.stream import log_and_raise + + try: + import ants + except ImportError: + log_and_raise( + "The package 'antsPy' is required to run antsRegistration in Python.", + ClinicaMissingDependencyError, + ) + kwargs = {} + if shrink_factors is not None: + kwargs["aff_shrink_factors"] = shrink_factors + if smoothing_sigmas is not None: + kwargs["aff_smoothing_sigmas"] = smoothing_sigmas + if number_of_iterations is not None: + kwargs["aff_iterations"] = number_of_iterations + + return ants.registration( + ants.image_read(str(fixed_image)), + ants.image_read(str(moving_image)), + type_of_transformation=transform_type.value, + random_seed=random_seed, + verbose=verbose, + **kwargs, + ) + + +def _write_ants_registration_results( + warped_image, + transformation_matrix, + transformation_matrix_inverse, + output_prefix: str, + output_dir: Optional[Path] = None, +) -> Tuple[Path, Path, Path]: + import shutil + + import ants + + from clinica.utils.stream import cprint + + warped_image_output_path = ( + output_dir or Path.cwd() + ) / f"{output_prefix}Warped.nii.gz" + transformation_matrix_output_path = ( + output_dir or Path.cwd() + ) / f"{output_prefix}0GenericAffine.mat" + transformation_matrix_inverse_output_path = ( + output_dir or Path.cwd() + ) / f"{output_prefix}inverse.mat" + cprint(f"Writing warped image to {warped_image_output_path}.", lvl="debug") + ants.image_write(warped_image, str(warped_image_output_path)) + shutil.copy(transformation_matrix, transformation_matrix_output_path) + shutil.copy( + transformation_matrix_inverse, transformation_matrix_inverse_output_path + ) + + return ( + warped_image_output_path, + transformation_matrix_output_path, + transformation_matrix_inverse_output_path, + ) + + +def run_ants_apply_transforms( + fixed_image: Path, + moving_image: Path, + transformlist: List[str], + output_dir: Optional[Path] = None, +) -> Path: + import ants + + from clinica.utils.stream import cprint + + transformed_image = ants.apply_transforms( + ants.image_read(str(fixed_image)), + ants.image_read(str(moving_image)), + transformlist=transformlist, + ) + transformed_image_output_path = (output_dir or Path.cwd()) / "transformed.nii.gz" + cprint( + f"Writing transformed image to {transformed_image_output_path}.", lvl="debug" + ) + ants.image_write(transformed_image, str(transformed_image_output_path)) + + return transformed_image_output_path From f759acb1a8a45b4c357a3d7fade54be2998221cb Mon Sep 17 00:00:00 2001 From: NicolasGensollen Date: Thu, 22 Aug 2024 16:10:10 +0200 Subject: [PATCH 2/2] fix broken unit tests --- .../t1_linear/test_anat_linear_utils.py | 165 --------------- .../pipelines/test_pipelines_utils.py | 188 ++++++++++++++++++ 2 files changed, 188 insertions(+), 165 deletions(-) create mode 100644 test/unittests/pipelines/test_pipelines_utils.py diff --git a/test/unittests/pipelines/t1_linear/test_anat_linear_utils.py b/test/unittests/pipelines/t1_linear/test_anat_linear_utils.py index 8f3388fe1..859105d33 100644 --- a/test/unittests/pipelines/t1_linear/test_anat_linear_utils.py +++ b/test/unittests/pipelines/t1_linear/test_anat_linear_utils.py @@ -1,10 +1,4 @@ -from pathlib import Path -from unittest.mock import patch - -import nibabel as nib -import numpy as np import pytest -from numpy.testing import assert_array_equal @pytest.mark.parametrize("suffix", ["T1w", "FLAIR", "fooo"]) @@ -30,162 +24,3 @@ def test_get_substitutions_datasink(suffix): f"sub-ADNI022S0004_ses-M000_{suffix}Warped.nii.gz", f"sub-ADNI022S0004_ses-M000_space-MNI152NLin2009cSym_res-1x1x1_{suffix}.nii.gz", ) - - -def n4biasfieldcorrection_mock( - input_image: Path, - bspline_fitting_distance: int, - save_bias: bool = False, - verbose: bool = False, -): - """The mock simply returns the input image without any processing.""" - return nib.load(input_image) - - -def test_run_n4biasfieldcorrection_no_bias_saving(tmp_path): - from clinica.pipelines.t1_linear.anat_linear_utils import run_n4biasfieldcorrection - - data = np.random.random((10, 10, 10)) - nib.save(nib.Nifti1Image(data, np.eye(4)), tmp_path / "test.nii.gz") - output_dir = tmp_path / "out" - output_dir.mkdir() - - with patch("ants.image_write", wraps=nib.save) as image_write_mock: - with patch( - "clinica.pipelines.t1_linear.anat_linear_utils._call_n4_bias_field_correction", - wraps=n4biasfieldcorrection_mock, - ) as ants_bias_correction_mock: - bias_corrected_image = run_n4biasfieldcorrection( - tmp_path / "test.nii.gz", - bspline_fitting_distance=300, - output_prefix="sub-01_ses-M000", - output_dir=output_dir, - ) - image_write_mock.assert_called_once() - ants_bias_correction_mock.assert_called_once_with( - tmp_path / "test.nii.gz", - 300, - save_bias=False, - verbose=False, - ) - # Verify that the bias corrected image exists - # If all went well, it will be the same as the input image because of the mocks. - assert [f.name for f in output_dir.iterdir()] == [ - "sub-01_ses-M000_bias_corrected_image.nii.gz" - ] - assert bias_corrected_image.exists() - bias_corrected_nifti = nib.load(bias_corrected_image) - assert_array_equal(bias_corrected_nifti.affine, np.eye(4)) - assert_array_equal(bias_corrected_nifti.get_fdata(), data) - - -def test_run_n4biasfieldcorrection(tmp_path): - from clinica.pipelines.t1_linear.anat_linear_utils import run_n4biasfieldcorrection - - data = np.random.random((10, 10, 10)) - nib.save(nib.Nifti1Image(data, np.eye(4)), tmp_path / "test.nii.gz") - output_dir = tmp_path / "out" - output_dir.mkdir() - - with patch("ants.image_write", wraps=nib.save) as image_write_mock: - with patch( - "clinica.pipelines.t1_linear.anat_linear_utils._call_n4_bias_field_correction", - wraps=n4biasfieldcorrection_mock, - ) as ants_bias_correction_mock: - bias_corrected_image = run_n4biasfieldcorrection( - tmp_path / "test.nii.gz", - bspline_fitting_distance=300, - output_prefix="sub-01_ses-M000", - output_dir=output_dir, - save_bias=True, - verbose=True, - ) - image_write_mock.assert_called() - ants_bias_correction_mock.assert_called_with( - tmp_path / "test.nii.gz", - 300, - save_bias=True, - verbose=True, - ) - assert set([f.name for f in output_dir.iterdir()]) == { - "sub-01_ses-M000_bias_corrected_image.nii.gz", - "sub-01_ses-M000_bias_image.nii.gz", - } - assert bias_corrected_image.exists() - bias_corrected_nifti = nib.load(bias_corrected_image) - assert_array_equal(bias_corrected_nifti.affine, np.eye(4)) - assert_array_equal(bias_corrected_nifti.get_fdata(), data) - - -def generate_fake_fixed_and_moving_images(folder: Path): - data = np.random.random((10, 10, 10)) - nib.save(nib.Nifti1Image(data, np.eye(4)), folder / "fixed.nii.gz") - nib.save(nib.Nifti1Image(data, np.eye(4)), folder / "moving.nii.gz") - - -def test_run_ants_registration_error(tmp_path, mocker): - import re - - from clinica.pipelines.t1_linear.anat_linear_utils import run_ants_registration - - generate_fake_fixed_and_moving_images(tmp_path) - mocker.patch( - "clinica.pipelines.t1_linear.anat_linear_utils._call_ants_registration", - return_value={}, - ) - with pytest.raises( - RuntimeError, - match=re.escape( - "Something went wrong when calling antsRegistration with the following parameters :\n" - f"- fixed_image = {tmp_path / 'fixed.nii.gz'}\n" - f"- moving_image = {tmp_path / 'moving.nii.gz'}\n" - f"- random_seed = 0\n" - f"- type_of_transformation='antsRegistrationSyN[a]'\n" - ), - ): - run_ants_registration( - tmp_path / "fixed.nii.gz", - tmp_path / "moving.nii.gz", - random_seed=0, - ) - - -def ants_registration_mock( - fixed_image: Path, - moving_image: Path, - random_seed: int, - verbose: bool = False, -) -> dict: - workdir = fixed_image.parent / "workdir" - workdir.mkdir() - mocked_transform = workdir / "transform.mat" - mocked_transform.touch() - return { - "warpedmovout": nib.load(fixed_image), - "fwdtransforms": ["fooo.txt", mocked_transform], - "foo": "bar", - } - - -def test_run_ants_registration(tmp_path): - from clinica.pipelines.t1_linear.anat_linear_utils import run_ants_registration - - output_dir = tmp_path / "out" - output_dir.mkdir() - generate_fake_fixed_and_moving_images(tmp_path) - - with patch( - "clinica.pipelines.t1_linear.anat_linear_utils._call_ants_registration", - wraps=ants_registration_mock, - ) as mock1: - with patch("ants.image_write", wraps=nib.save) as mock2: - run_ants_registration( - tmp_path / "fixed.nii.gz", - tmp_path / "moving.nii.gz", - random_seed=12, - output_dir=output_dir, - ) - mock1.assert_called_once_with( - tmp_path / "fixed.nii.gz", tmp_path / "moving.nii.gz", 12, verbose=False - ) - mock2.assert_called_once() diff --git a/test/unittests/pipelines/test_pipelines_utils.py b/test/unittests/pipelines/test_pipelines_utils.py new file mode 100644 index 000000000..e6a281f6a --- /dev/null +++ b/test/unittests/pipelines/test_pipelines_utils.py @@ -0,0 +1,188 @@ +from pathlib import Path +from typing import Optional, Tuple, Union +from unittest.mock import patch + +import nibabel as nib +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from clinica.pipelines.utils import ( + AntsRegistrationSynQuickTransformType, + AntsRegistrationTransformType, +) + + +def n4biasfieldcorrection_mock( + input_image: Path, + bspline_fitting_distance: int, + save_bias: bool = False, + verbose: bool = False, +): + """The mock simply returns the input image without any processing.""" + return nib.load(input_image) + + +def test_run_n4biasfieldcorrection_no_bias_saving(tmp_path): + from clinica.pipelines.utils import run_n4biasfieldcorrection + + data = np.random.random((10, 10, 10)) + nib.save(nib.Nifti1Image(data, np.eye(4)), tmp_path / "test.nii.gz") + output_dir = tmp_path / "out" + output_dir.mkdir() + + with patch("ants.image_write", wraps=nib.save) as image_write_mock: + with patch( + "clinica.pipelines.utils._call_n4_bias_field_correction", + wraps=n4biasfieldcorrection_mock, + ) as ants_bias_correction_mock: + bias_corrected_image = run_n4biasfieldcorrection( + tmp_path / "test.nii.gz", + bspline_fitting_distance=300, + output_prefix="sub-01_ses-M000", + output_dir=output_dir, + ) + image_write_mock.assert_called_once() + ants_bias_correction_mock.assert_called_once_with( + tmp_path / "test.nii.gz", + 300, + save_bias=False, + verbose=False, + ) + # Verify that the bias corrected image exists + # If all went well, it will be the same as the input image because of the mocks. + assert [f.name for f in output_dir.iterdir()] == [ + "sub-01_ses-M000_bias_corrected_image.nii.gz" + ] + assert bias_corrected_image.exists() + bias_corrected_nifti = nib.load(bias_corrected_image) + assert_array_equal(bias_corrected_nifti.affine, np.eye(4)) + assert_array_equal(bias_corrected_nifti.get_fdata(), data) + + +def test_run_n4biasfieldcorrection(tmp_path): + from clinica.pipelines.utils import run_n4biasfieldcorrection + + data = np.random.random((10, 10, 10)) + nib.save(nib.Nifti1Image(data, np.eye(4)), tmp_path / "test.nii.gz") + output_dir = tmp_path / "out" + output_dir.mkdir() + + with patch("ants.image_write", wraps=nib.save) as image_write_mock: + with patch( + "clinica.pipelines.utils._call_n4_bias_field_correction", + wraps=n4biasfieldcorrection_mock, + ) as ants_bias_correction_mock: + bias_corrected_image = run_n4biasfieldcorrection( + tmp_path / "test.nii.gz", + bspline_fitting_distance=300, + output_prefix="sub-01_ses-M000", + output_dir=output_dir, + save_bias=True, + verbose=True, + ) + image_write_mock.assert_called() + ants_bias_correction_mock.assert_called_with( + tmp_path / "test.nii.gz", + 300, + save_bias=True, + verbose=True, + ) + assert set([f.name for f in output_dir.iterdir()]) == { + "sub-01_ses-M000_bias_corrected_image.nii.gz", + "sub-01_ses-M000_bias_image.nii.gz", + } + assert bias_corrected_image.exists() + bias_corrected_nifti = nib.load(bias_corrected_image) + assert_array_equal(bias_corrected_nifti.affine, np.eye(4)) + assert_array_equal(bias_corrected_nifti.get_fdata(), data) + + +def generate_fake_fixed_and_moving_images(folder: Path): + data = np.random.random((10, 10, 10)) + nib.save(nib.Nifti1Image(data, np.eye(4)), folder / "fixed.nii.gz") + nib.save(nib.Nifti1Image(data, np.eye(4)), folder / "moving.nii.gz") + + +def test_run_ants_registration_synquick_error(tmp_path, mocker): + import re + + from clinica.pipelines.utils import run_ants_registration_synquick + + generate_fake_fixed_and_moving_images(tmp_path) + mocker.patch( + "clinica.pipelines.utils._call_ants_registration", + return_value={}, + ) + with pytest.raises( + RuntimeError, + match=re.escape( + "Something went wrong when calling antsRegistration with the following parameters :\n" + f"- fixed_image = {tmp_path / 'fixed.nii.gz'}\n" + f"- moving_image = {tmp_path / 'moving.nii.gz'}\n" + f"- random_seed = 0\n" + f"- type_of_transformation='antsRegistrationSyN[a]'\n" + ), + ): + run_ants_registration_synquick( + tmp_path / "fixed.nii.gz", + tmp_path / "moving.nii.gz", + random_seed=0, + transform_type=AntsRegistrationSynQuickTransformType.AFFINE, + ) + + +def ants_registration_mock( + fixed_image: Path, + moving_image: Path, + random_seed: int, + transform_type: Union[ + AntsRegistrationTransformType, AntsRegistrationSynQuickTransformType + ], + verbose: bool = False, + shrink_factors: Optional[Tuple[int, ...]] = None, + smoothing_sigmas: Optional[Tuple[int, ...]] = None, + number_of_iterations: Optional[Tuple[int, ...]] = None, +) -> dict: + workdir = fixed_image.parent / "workdir" + workdir.mkdir() + mocked_transform = workdir / "transform.mat" + mocked_transform.touch() + return { + "warpedmovout": nib.load(fixed_image), + "fwdtransforms": ["fooo.txt", mocked_transform], + "invtransforms": [mocked_transform], + "foo": "bar", + } + + +def test_run_ants_registration_synquick(tmp_path): + from clinica.pipelines.utils import run_ants_registration_synquick + + output_dir = tmp_path / "out" + output_dir.mkdir() + generate_fake_fixed_and_moving_images(tmp_path) + + with patch( + "clinica.pipelines.utils._call_ants_registration", + wraps=ants_registration_mock, + ) as mock1: + with patch("ants.image_write", wraps=nib.save) as mock2: + run_ants_registration_synquick( + tmp_path / "fixed.nii.gz", + tmp_path / "moving.nii.gz", + random_seed=12, + transform_type=AntsRegistrationSynQuickTransformType.AFFINE, + output_dir=output_dir, + ) + mock1.assert_called_once_with( + tmp_path / "fixed.nii.gz", + tmp_path / "moving.nii.gz", + 12, + AntsRegistrationSynQuickTransformType.AFFINE, + verbose=False, + shrink_factors=None, + smoothing_sigmas=None, + number_of_iterations=None, + ) + mock2.assert_called_once()