Skip to content

Commit 9e59086

Browse files
authored
added cristae model (#113)
Add cristae model to the top-level segmentation function
1 parent 07a6b61 commit 9e59086

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

synapse_net/inference/cristae.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,17 @@ def segment_cristae(
6262
The segmentation mask as a numpy array, or a tuple containing the segmentation mask
6363
and the predictions if return_predictions is True.
6464
"""
65+
mitochondria = kwargs.pop("extra_segmentation")
6566
with_channels = kwargs.pop("with_channels", True)
6667
channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
6768
if verbose:
6869
print("Segmenting cristae in volume of shape", input_volume.shape)
6970
# Create the scaler to handle prediction with a different scaling factor.
7071
scaler = _Scaler(scale, verbose)
71-
input_volume = scaler.scale_input(input_volume)
72+
# rescale each channel
73+
volume = scaler.scale_input(input_volume)
74+
mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
75+
input_volume = np.stack([volume, mito_seg], axis=0)
7276

7377
# Run prediction and segmentation.
7478
if mask is not None:

synapse_net/inference/inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def get_model_training_resolution(model_type: str) -> Dict[str, float]:
112112
"active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
113113
"compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
114114
"mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
115+
"cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
115116
"ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
116117
"vesicles_2d": {"x": 1.35, "y": 1.35},
117118
"vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},

synapse_net/inference/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def scale_input(self, input_volume, is_segmentation=False):
5959

6060
if self._original_shape is None:
6161
self._original_shape = input_volume.shape
62-
elif self._oringal_shape != input_volume.shape:
62+
elif self._original_shape != input_volume.shape:
6363
raise RuntimeError(
6464
"Scaler was called with different input shapes. "
6565
"This is not supported, please create a new instance of the class for it."

synapse_net/tools/segmentation_widget.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,11 @@ def on_predict(self):
189189
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
190190
kwargs = {"extra_segmentation": extra_seg}
191191
elif model_type == "cristae": # Cristae model expects 2 3D volumes
192-
image = np.stack([image, self._get_layer_selector_data(self.extra_seg_selector_name)], axis=0)
193-
kwargs = {}
192+
kwargs = {
193+
"extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
194+
"with_channels": True,
195+
"channels_to_standardize": [0]
196+
}
194197
else:
195198
kwargs = {}
196199
segmentation = run_segmentation(

0 commit comments

Comments
 (0)