Skip to content

Commit 2afed55

Browse files
authored
added cristae model and single channel transfrom (#100)
Update input standardization to later support cristae model
1 parent 92be799 commit 2afed55

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

synapse_net/inference/cristae.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def segment_cristae(
4242
return_predictions: bool = False,
4343
scale: Optional[List[float]] = None,
4444
mask: Optional[np.ndarray] = None,
45+
**kwargs
4546
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
4647
"""Segment cristae in an input volume.
4748
@@ -61,6 +62,8 @@ def segment_cristae(
6162
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
6263
and the predictions if return_predictions is True.
6364
"""
65+
with_channels = kwargs.pop("with_channels", True)
66+
channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
6467
if verbose:
6568
print("Segmenting cristae in volume of shape", input_volume.shape)
6669
# Create the scaler to handle prediction with a different scaling factor.
@@ -72,7 +75,7 @@ def segment_cristae(
7275
mask = scaler.scale_input(mask, is_segmentation=True)
7376
pred = get_prediction(
7477
input_volume, model_path=model_path, model=model, mask=mask,
75-
tiling=tiling, with_channels=True, verbose=verbose
78+
tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
7679
)
7780
foreground, boundaries = pred[:2]
7881
seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)

synapse_net/inference/inference.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .mitochondria import segment_mitochondria
1111
from .ribbon_synapse import segment_ribbon_synapse_structures
1212
from .vesicles import segment_vesicles
13+
from .cristae import segment_cristae
1314
from .util import get_device
1415
from ..file_utils import get_cache_dir
1516

@@ -25,6 +26,7 @@ def _get_model_registry():
2526
"compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
2627
"mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
2728
"mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
29+
"cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
2830
"ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
2931
"vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
3032
"vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
@@ -35,6 +37,7 @@ def _get_model_registry():
3537
"compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
3638
"mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
3739
"mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
40+
"cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
3841
"ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
3942
"vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
4043
"vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
@@ -214,14 +217,16 @@ def run_segmentation(
214217
"""
215218
if model_type.startswith("vesicles"):
216219
segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
217-
elif model_type == "mitochondria":
220+
elif model_type == "mitochondria" or model_type == "mitochondria2":
218221
segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
219222
elif model_type == "active_zone":
220223
segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
221224
elif model_type == "compartments":
222225
segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
223226
elif model_type == "ribbon":
224227
segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
228+
elif model_type == "cristae":
229+
segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
225230
else:
226231
raise ValueError(f"Unknown model type: {model_type}")
227232
return segmentation

synapse_net/inference/util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import warnings
44
from glob import glob
5-
from typing import Dict, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Tuple, Union
66

77
# # Suppress annoying import warnings.
88
# with warnings.catch_warnings():
@@ -101,6 +101,7 @@ def get_prediction(
101101
model: Optional[torch.nn.Module] = None,
102102
verbose: bool = True,
103103
with_channels: bool = False,
104+
channels_to_standardize: Optional[List[int]] = None,
104105
mask: Optional[np.ndarray] = None,
105106
) -> np.ndarray:
106107
"""Run prediction on a given volume.
@@ -115,6 +116,7 @@ def get_prediction(
115116
tiling: The tiling configuration for the prediction.
116117
verbose: Whether to print timing information.
117118
with_channels: Whether to predict with channels.
119+
channels_to_standardize: List of channels to standardize. Defaults to None.
118120
mask: Optional binary mask. If given, the prediction will only be run in
119121
the foreground region of the mask.
120122
@@ -136,8 +138,12 @@ def get_prediction(
136138
# We standardize the data for the whole volume beforehand.
137139
# If we have channels then the standardization is done independently per channel.
138140
if with_channels:
141+
input_volume = input_volume.astype(np.float32, copy=False)
139142
# TODO Check that this is the correct axis.
140-
input_volume = torch_em.transform.raw.standardize(input_volume, axis=(1, 2, 3))
143+
if channels_to_standardize is None: # assume all channels
144+
channels_to_standardize = range(input_volume.shape[0])
145+
for ch in channels_to_standardize:
146+
input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
141147
else:
142148
input_volume = torch_em.transform.raw.standardize(input_volume)
143149

synapse_net/tools/segmentation_widget.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ def on_predict(self):
183183
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
184184
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
185185
kwargs = {"extra_segmentation": extra_seg}
186+
elif model_type == "cristae": # Cristae model expects 2 3D volumes
187+
image = np.stack([image, self._get_layer_selector_data(self.extra_seg_selector_name)], axis=0)
188+
kwargs = {}
186189
else:
187190
kwargs = {}
188191
segmentation = run_segmentation(

0 commit comments

Comments
 (0)