Skip to content

Commit 22ea714

Browse files
Refactor segmentation functionality
1 parent 8b5cf12 commit 22ea714

File tree

9 files changed

+405
-324
lines changed

9 files changed

+405
-324
lines changed

synapse_net/file_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33

44
import mrcfile
55
import numpy as np
6+
import pooch
7+
8+
9+
def get_cache_dir() -> str:
10+
"""Get the cache directory of synapse net.
11+
12+
Returns:
13+
The cache directory.
14+
"""
15+
cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
16+
return cache_dir
617

718

819
def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]:

synapse_net/inference/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""This submodule implements SynapseNet's segmentation functionality.
22
"""
3-
from .vesicles import segment_vesicles
3+
from .inference import run_segmentation, get_model

synapse_net/inference/inference.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import os
2+
from typing import Dict, List, Optional, Union
3+
4+
import torch
5+
import numpy as np
6+
import pooch
7+
8+
from .active_zone import segment_active_zone
9+
from .compartments import segment_compartments
10+
from .mitochondria import segment_mitochondria
11+
from .ribbon_synapse import segment_ribbon_synapse_structures
12+
from .vesicles import segment_vesicles
13+
from .util import get_device
14+
from ..file_utils import get_cache_dir
15+
16+
17+
#
18+
# Functions to access SynapseNet's pretrained models.
19+
#
20+
21+
22+
def _get_model_registry():
23+
registry = {
24+
"active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
25+
"compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
26+
"mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
27+
"ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
28+
"vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
29+
"vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
30+
"vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
31+
}
32+
urls = {
33+
"active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
34+
"compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
35+
"mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
36+
"ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
37+
"vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
38+
"vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
39+
"vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
40+
}
41+
cache_dir = get_cache_dir()
42+
models = pooch.create(
43+
path=os.path.join(cache_dir, "models"),
44+
base_url="",
45+
registry=registry,
46+
urls=urls,
47+
)
48+
return models
49+
50+
51+
def get_model_path(model_type: str) -> str:
52+
"""Get the local path to a pretrained model.
53+
54+
Args:
55+
The model type.
56+
57+
Returns:
58+
The local path to the model.
59+
"""
60+
model_registry = _get_model_registry()
61+
model_path = model_registry.fetch(model_type)
62+
return model_path
63+
64+
65+
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
66+
"""Get the model for a specific segmentation type.
67+
68+
Args:
69+
model_type: The model for one of the following segmentation tasks:
70+
'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
71+
device: The device to use.
72+
73+
Returns:
74+
The model.
75+
"""
76+
if device is None:
77+
device = get_device(device)
78+
model_path = get_model_path(model_type)
79+
model = torch.load(model_path, weights_only=False)
80+
model.to(device)
81+
return model
82+
83+
84+
#
85+
# Functions for training resolution / voxel size.
86+
#
87+
88+
89+
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
90+
"""Get the average resolution / voxel size of the training data for a given pretrained model.
91+
92+
Args:
93+
model_type: The name of the pretrained model.
94+
95+
Returns:
96+
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
97+
"""
98+
resolutions = {
99+
"active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
100+
"compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
101+
"mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
102+
"ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
103+
"vesicles_2d": {"x": 1.35, "y": 1.35},
104+
"vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
105+
"vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
106+
}
107+
return resolutions[model_type]
108+
109+
110+
def compute_scale_from_voxel_size(
111+
voxel_size: Dict[str, float],
112+
model_type: str
113+
) -> List[float]:
114+
"""Compute the appropriate scale factor for inference with a given pretrained model.
115+
116+
Args:
117+
voxel_size: The voxel size of the data for inference.
118+
model_type: The name of the pretrained model.
119+
120+
Returns:
121+
The scale factor, as a list in zyx order.
122+
"""
123+
training_voxel_size = get_model_training_resolution(model_type)
124+
scale = [
125+
voxel_size["x"] / training_voxel_size["x"],
126+
voxel_size["y"] / training_voxel_size["y"],
127+
]
128+
if len(voxel_size) == 3 and len(training_voxel_size) == 3:
129+
scale.append(
130+
voxel_size["z"] / training_voxel_size["z"]
131+
)
132+
return scale
133+
134+
135+
#
136+
# Convenience functions for segmentation.
137+
#
138+
139+
140+
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
141+
from synapse_net.inference.postprocessing import (
142+
segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
143+
)
144+
145+
ribbon = segment_ribbon(
146+
predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
147+
max_vesicle_distance=40,
148+
)
149+
PD = segment_presynaptic_density(
150+
predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
151+
)
152+
ref_segmentation = PD if PD.sum() > 0 else ribbon
153+
membrane = segment_membrane_distance_based(
154+
predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
155+
)
156+
157+
segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
158+
return segmentations
159+
160+
161+
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
162+
# Parse additional keyword arguments from the kwargs.
163+
vesicles = kwargs.pop("extra_segmentation")
164+
threshold = kwargs.pop("threshold", 0.5)
165+
n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
166+
n_ribbons = kwargs.pop("n_slices_exclude", 1)
167+
168+
predictions = segment_ribbon_synapse_structures(
169+
image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
170+
)
171+
172+
# Otherwise, just return the predictions.
173+
if vesicles is None:
174+
if verbose:
175+
print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
176+
segmentations = predictions
177+
178+
# If the vesicles were passed then run additional post-processing.
179+
else:
180+
if verbose:
181+
print("Vesicle segmentation was passed, WILL run post-processing.")
182+
segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)
183+
184+
if return_predictions:
185+
return segmentations, predictions
186+
return segmentations
187+
188+
189+
def run_segmentation(
190+
image: np.ndarray,
191+
model: torch.nn.Module,
192+
model_type: str,
193+
tiling: Optional[Dict[str, Dict[str, int]]] = None,
194+
scale: Optional[List[float]] = None,
195+
verbose: bool = False,
196+
**kwargs,
197+
) -> np.ndarray | Dict[str, np.ndarray]:
198+
"""Run synaptic structure segmentation.
199+
200+
Args:
201+
image: The input image or image volume.
202+
model: The segmentation model.
203+
model_type: The model type. This will determine which segmentation post-processing is used.
204+
tiling: The tiling settings for inference.
205+
scale: A scale factor for resizing the input before applying the model.
206+
The output will be scaled back to the initial size.
207+
verbose: Whether to print detailed information about the prediction and segmentation.
208+
kwargs: Optional parameters for the segmentation function.
209+
210+
Returns:
211+
The segmentation. For models that return multiple segmentations, this function returns a dictionary.
212+
"""
213+
if model_type.startswith("vesicles"):
214+
segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
215+
elif model_type == "mitochondria":
216+
segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
217+
elif model_type == "active_zone":
218+
segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
219+
elif model_type == "compartments":
220+
segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
221+
elif model_type == "ribbon":
222+
segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
223+
else:
224+
raise ValueError(f"Unknown model type: {model_type}")
225+
return segmentation

synapse_net/inference/util.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import warnings
44
from glob import glob
5-
from typing import Dict, Optional, Tuple
5+
from typing import Dict, Optional, Tuple, Union
66

77
# # Suppress annoying import warnings.
88
# with warnings.catch_warnings():
@@ -26,6 +26,11 @@
2626
from tqdm import tqdm
2727

2828

29+
#
30+
# Utils for prediction.
31+
#
32+
33+
2934
class _Scaler:
3035
def __init__(self, scale, verbose):
3136
self.scale = scale
@@ -474,6 +479,11 @@ def parse_tiling(
474479
return tiling
475480

476481

482+
#
483+
# Utils for post-processing.
484+
#
485+
486+
477487
def apply_size_filter(
478488
segmentation: np.ndarray,
479489
min_size: int,
@@ -525,3 +535,54 @@ def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8)
525535
seg[bb][mask] = prop.label
526536

527537
return seg
538+
539+
540+
#
541+
# Utils for torch device.
542+
#
543+
544+
def _get_default_device():
545+
# Check that we're in CI and use the CPU if we are.
546+
# Otherwise the tests may run out of memory on MAC if MPS is used.
547+
if os.getenv("GITHUB_ACTIONS") == "true":
548+
return "cpu"
549+
# Use cuda enabled gpu if it's available.
550+
if torch.cuda.is_available():
551+
device = "cuda"
552+
# As second priority use mps.
553+
# See https://pytorch.org/docs/stable/notes/mps.html for details
554+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
555+
device = "mps"
556+
# Use the CPU as fallback.
557+
else:
558+
device = "cpu"
559+
return device
560+
561+
562+
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
563+
"""Get the torch device.
564+
565+
If no device is passed the default device for your system is used.
566+
Else it will be checked if the device you have passed is supported.
567+
568+
Args:
569+
device: The input device.
570+
571+
Returns:
572+
The device.
573+
"""
574+
if device is None or device == "auto":
575+
device = _get_default_device()
576+
else:
577+
device_type = device if isinstance(device, str) else device.type
578+
if device_type.lower() == "cuda":
579+
if not torch.cuda.is_available():
580+
raise RuntimeError("PyTorch CUDA backend is not available.")
581+
elif device_type.lower() == "mps":
582+
if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
583+
raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
584+
elif device_type.lower() == "cpu":
585+
pass # cpu is always available
586+
else:
587+
raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
588+
return device

synapse_net/sample_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import pooch
33

4-
from .file_utils import read_mrc
4+
from .file_utils import read_mrc, get_cache_dir
55

66

77
def get_sample_data(name: str) -> str:
@@ -27,7 +27,7 @@ def get_sample_data(name: str) -> str:
2727
valid_names = [k[:-4] for k in registry.keys()]
2828
raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.")
2929

30-
cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
30+
cache_dir = get_cache_dir()
3131
data_registry = pooch.create(
3232
path=os.path.join(cache_dir, "sample_data"),
3333
base_url="",

synapse_net/tools/cli.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import argparse
22
from functools import partial
33

4-
from .util import (
5-
run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model
6-
)
4+
import torch
75
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
6+
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
87
from ..inference.util import inference_helper, parse_tiling
98

109

@@ -108,7 +107,7 @@ def segmentation_cli():
108107
"--output_path", "-o", required=True,
109108
help="The filepath to directory where the segmentations will be saved."
110109
)
111-
model_names = list(get_model_registry().urls.keys())
110+
model_names = list(_get_model_registry().urls.keys())
112111
model_names = ", ".join(model_names)
113112
parser.add_argument(
114113
"--model", "-m", required=True,
@@ -152,7 +151,7 @@ def segmentation_cli():
152151
if args.checkpoint is None:
153152
model = get_model(args.model)
154153
else:
155-
model = load_custom_model(args.checkpoint)
154+
model = torch.load(args.checkpoint, weights_only=False)
156155
assert model is not None, f"The model from {args.checkpoint} could not be loaded."
157156

158157
is_2d = "2d" in args.model

0 commit comments

Comments
 (0)