|
| 1 | +import os |
| 2 | +from typing import Dict, List, Optional, Union |
| 3 | + |
| 4 | +import torch |
| 5 | +import numpy as np |
| 6 | +import pooch |
| 7 | + |
| 8 | +from .active_zone import segment_active_zone |
| 9 | +from .compartments import segment_compartments |
| 10 | +from .mitochondria import segment_mitochondria |
| 11 | +from .ribbon_synapse import segment_ribbon_synapse_structures |
| 12 | +from .vesicles import segment_vesicles |
| 13 | +from .util import get_device |
| 14 | +from ..file_utils import get_cache_dir |
| 15 | + |
| 16 | + |
| 17 | +# |
| 18 | +# Functions to access SynapseNet's pretrained models. |
| 19 | +# |
| 20 | + |
| 21 | + |
| 22 | +def _get_model_registry(): |
| 23 | + registry = { |
| 24 | + "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0", |
| 25 | + "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1", |
| 26 | + "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186", |
| 27 | + "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9", |
| 28 | + "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1", |
| 29 | + "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29", |
| 30 | + "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b", |
| 31 | + } |
| 32 | + urls = { |
| 33 | + "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download", |
| 34 | + "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download", |
| 35 | + "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download", |
| 36 | + "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download", |
| 37 | + "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download", |
| 38 | + "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download", |
| 39 | + "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download", |
| 40 | + } |
| 41 | + cache_dir = get_cache_dir() |
| 42 | + models = pooch.create( |
| 43 | + path=os.path.join(cache_dir, "models"), |
| 44 | + base_url="", |
| 45 | + registry=registry, |
| 46 | + urls=urls, |
| 47 | + ) |
| 48 | + return models |
| 49 | + |
| 50 | + |
| 51 | +def get_model_path(model_type: str) -> str: |
| 52 | + """Get the local path to a pretrained model. |
| 53 | +
|
| 54 | + Args: |
| 55 | + The model type. |
| 56 | +
|
| 57 | + Returns: |
| 58 | + The local path to the model. |
| 59 | + """ |
| 60 | + model_registry = _get_model_registry() |
| 61 | + model_path = model_registry.fetch(model_type) |
| 62 | + return model_path |
| 63 | + |
| 64 | + |
| 65 | +def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: |
| 66 | + """Get the model for a specific segmentation type. |
| 67 | +
|
| 68 | + Args: |
| 69 | + model_type: The model for one of the following segmentation tasks: |
| 70 | + 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'. |
| 71 | + device: The device to use. |
| 72 | +
|
| 73 | + Returns: |
| 74 | + The model. |
| 75 | + """ |
| 76 | + if device is None: |
| 77 | + device = get_device(device) |
| 78 | + model_path = get_model_path(model_type) |
| 79 | + model = torch.load(model_path, weights_only=False) |
| 80 | + model.to(device) |
| 81 | + return model |
| 82 | + |
| 83 | + |
| 84 | +# |
| 85 | +# Functions for training resolution / voxel size. |
| 86 | +# |
| 87 | + |
| 88 | + |
| 89 | +def get_model_training_resolution(model_type: str) -> Dict[str, float]: |
| 90 | + """Get the average resolution / voxel size of the training data for a given pretrained model. |
| 91 | +
|
| 92 | + Args: |
| 93 | + model_type: The name of the pretrained model. |
| 94 | +
|
| 95 | + Returns: |
| 96 | + Mapping of axis (x, y, z) to the voxel size (in nm) of that axis. |
| 97 | + """ |
| 98 | + resolutions = { |
| 99 | + "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44}, |
| 100 | + "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, |
| 101 | + "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, |
| 102 | + "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, |
| 103 | + "vesicles_2d": {"x": 1.35, "y": 1.35}, |
| 104 | + "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, |
| 105 | + "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88}, |
| 106 | + } |
| 107 | + return resolutions[model_type] |
| 108 | + |
| 109 | + |
| 110 | +def compute_scale_from_voxel_size( |
| 111 | + voxel_size: Dict[str, float], |
| 112 | + model_type: str |
| 113 | +) -> List[float]: |
| 114 | + """Compute the appropriate scale factor for inference with a given pretrained model. |
| 115 | +
|
| 116 | + Args: |
| 117 | + voxel_size: The voxel size of the data for inference. |
| 118 | + model_type: The name of the pretrained model. |
| 119 | +
|
| 120 | + Returns: |
| 121 | + The scale factor, as a list in zyx order. |
| 122 | + """ |
| 123 | + training_voxel_size = get_model_training_resolution(model_type) |
| 124 | + scale = [ |
| 125 | + voxel_size["x"] / training_voxel_size["x"], |
| 126 | + voxel_size["y"] / training_voxel_size["y"], |
| 127 | + ] |
| 128 | + if len(voxel_size) == 3 and len(training_voxel_size) == 3: |
| 129 | + scale.append( |
| 130 | + voxel_size["z"] / training_voxel_size["z"] |
| 131 | + ) |
| 132 | + return scale |
| 133 | + |
| 134 | + |
| 135 | +# |
| 136 | +# Convenience functions for segmentation. |
| 137 | +# |
| 138 | + |
| 139 | + |
| 140 | +def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons): |
| 141 | + from synapse_net.inference.postprocessing import ( |
| 142 | + segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, |
| 143 | + ) |
| 144 | + |
| 145 | + ribbon = segment_ribbon( |
| 146 | + predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, |
| 147 | + max_vesicle_distance=40, |
| 148 | + ) |
| 149 | + PD = segment_presynaptic_density( |
| 150 | + predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, |
| 151 | + ) |
| 152 | + ref_segmentation = PD if PD.sum() > 0 else ribbon |
| 153 | + membrane = segment_membrane_distance_based( |
| 154 | + predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude, |
| 155 | + ) |
| 156 | + |
| 157 | + segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane} |
| 158 | + return segmentations |
| 159 | + |
| 160 | + |
| 161 | +def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs): |
| 162 | + # Parse additional keyword arguments from the kwargs. |
| 163 | + vesicles = kwargs.pop("extra_segmentation") |
| 164 | + threshold = kwargs.pop("threshold", 0.5) |
| 165 | + n_slices_exclude = kwargs.pop("n_slices_exclude", 20) |
| 166 | + n_ribbons = kwargs.pop("n_slices_exclude", 1) |
| 167 | + |
| 168 | + predictions = segment_ribbon_synapse_structures( |
| 169 | + image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs |
| 170 | + ) |
| 171 | + |
| 172 | + # Otherwise, just return the predictions. |
| 173 | + if vesicles is None: |
| 174 | + if verbose: |
| 175 | + print("Vesicle segmentation was not passed, WILL NOT run post-processing.") |
| 176 | + segmentations = predictions |
| 177 | + |
| 178 | + # If the vesicles were passed then run additional post-processing. |
| 179 | + else: |
| 180 | + if verbose: |
| 181 | + print("Vesicle segmentation was passed, WILL run post-processing.") |
| 182 | + segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons) |
| 183 | + |
| 184 | + if return_predictions: |
| 185 | + return segmentations, predictions |
| 186 | + return segmentations |
| 187 | + |
| 188 | + |
| 189 | +def run_segmentation( |
| 190 | + image: np.ndarray, |
| 191 | + model: torch.nn.Module, |
| 192 | + model_type: str, |
| 193 | + tiling: Optional[Dict[str, Dict[str, int]]] = None, |
| 194 | + scale: Optional[List[float]] = None, |
| 195 | + verbose: bool = False, |
| 196 | + **kwargs, |
| 197 | +) -> np.ndarray | Dict[str, np.ndarray]: |
| 198 | + """Run synaptic structure segmentation. |
| 199 | +
|
| 200 | + Args: |
| 201 | + image: The input image or image volume. |
| 202 | + model: The segmentation model. |
| 203 | + model_type: The model type. This will determine which segmentation post-processing is used. |
| 204 | + tiling: The tiling settings for inference. |
| 205 | + scale: A scale factor for resizing the input before applying the model. |
| 206 | + The output will be scaled back to the initial size. |
| 207 | + verbose: Whether to print detailed information about the prediction and segmentation. |
| 208 | + kwargs: Optional parameters for the segmentation function. |
| 209 | +
|
| 210 | + Returns: |
| 211 | + The segmentation. For models that return multiple segmentations, this function returns a dictionary. |
| 212 | + """ |
| 213 | + if model_type.startswith("vesicles"): |
| 214 | + segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) |
| 215 | + elif model_type == "mitochondria": |
| 216 | + segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) |
| 217 | + elif model_type == "active_zone": |
| 218 | + segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) |
| 219 | + elif model_type == "compartments": |
| 220 | + segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) |
| 221 | + elif model_type == "ribbon": |
| 222 | + segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) |
| 223 | + else: |
| 224 | + raise ValueError(f"Unknown model type: {model_type}") |
| 225 | + return segmentation |
0 commit comments