
Commit 9472a47

Merge pull request #51 from computational-cell-analytics/more-plugin-updates
Updates to napari plugin and start to implement CLI
2 parents: 38d6edf + 4e19b6e · commit 9472a47

15 files changed (+345, -427 lines)

run_correction.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

setup.py

Lines changed: 2 additions & 3 deletions
@@ -8,13 +8,12 @@
     name="synaptic_reconstruction",
     packages=find_packages(exclude=["test"]),
     version=__version__,
-    author="Constantin Pape; Sarah Muth",
+    author="Constantin Pape; Sarah Muth; Luca Freckmann",
     url="https://github.com/computational-cell-analytics/synaptic_reconstruction",
     license="MIT",
     entry_points={
         "console_scripts": [
-            "sr_tools.correct_segmentation = synaptic_reconstruction.tools.segmentation_correction:main",
-            "sr_tools.measure_distances = synaptic_reconstruction.tools.distance_measurement:main",
+            "synapse_net.run_segmentation = synaptic_reconstruction.tools.cli:segmentation_cli"
         ],
         "napari.manifest": [
             "synaptic_reconstruction = synaptic_reconstruction:napari.yaml",

synaptic_reconstruction/distance_measurements.py

Lines changed: 13 additions & 12 deletions
@@ -1,6 +1,6 @@
 import os
 import multiprocessing as mp
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import numpy as np

@@ -61,8 +61,9 @@ def _compute_boundary_distances(segmentation, resolution, n_threads):
     n = len(seg_ids)

     pairwise_distances = np.zeros((n, n))
-    end_points1 = np.zeros((n, n, 3), dtype="int")
-    end_points2 = np.zeros((n, n, 3), dtype="int")
+    ndim = segmentation.ndim
+    end_points1 = np.zeros((n, n, ndim), dtype="int")
+    end_points2 = np.zeros((n, n, ndim), dtype="int")

     properties = regionprops(segmentation)
     properties = {prop.label: prop for prop in properties}
@@ -80,8 +81,11 @@ def compute_distances_for_object(i):
            prop = properties[ngb_id]

            bb = prop.bbox
-           offset = np.array(bb[:3])
-           bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
+           offset = np.array(bb[:ndim])
+           if ndim == 2:
+               bb = np.s_[bb[0]:bb[2], bb[1]:bb[3]]
+           else:
+               bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]

            mask = segmentation[bb] == ngb_id
            ngb_dist, ngb_index = distances[bb].copy(), indices[(slice(None),) + bb]
@@ -168,10 +172,9 @@ def _compute_seg_object_distances(segmentation, segmented_object, resolution, ve
     for prop in tqdm(props, disable=not verbose):
         bb = prop.bbox
         offset = np.array(bb[:ndim])
-        # bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
-        if len(bb) == 4:  # 2D bounding box
+        if ndim == 2:
             bb = np.s_[bb[0]:bb[2], bb[1]:bb[3]]
-        elif len(bb) == 6:  # 3D bounding box
+        else:
             bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]

         label = prop.label
@@ -296,7 +299,7 @@ def create_pairwise_distance_lines(
     distances: np.ndarray,
     endpoints1: np.ndarray,
     endpoints2: np.ndarray,
-    seg_ids: np.ndarray,
+    seg_ids: List[List[int]],
     n_neighbors: Optional[int] = None,
     pairs: Optional[np.ndarray] = None,
     bb: Optional[Tuple[slice]] = None,
@@ -323,9 +326,7 @@ def create_pairwise_distance_lines(
     if pairs is None and n_neighbors is not None:
         pairs = _extract_nearest_neighbors(distances, seg_ids, n_neighbors, remove_duplicates=remove_duplicates)
     elif pairs is None:
-        pairs = [
-            [id1, id2] for id1 in seg_ids for id2 in seg_ids if id1 < id2
-        ]
+        pairs = [[id1, id2] for id1 in seg_ids for id2 in seg_ids if id1 < id2]

     assert pairs is not None
     pair_indices = (
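These changes replace the hard-coded 3D bounding-box handling with branching on segmentation.ndim. A minimal sketch (not from the commit) of why this works: skimage's regionprops returns a 2*ndim bbox tuple, so the slice has to be assembled differently for 2D and 3D segmentations.

# Minimal sketch illustrating the ndim-dependent bounding box handling.
import numpy as np
from skimage.measure import regionprops

seg = np.zeros((64, 64), dtype="uint32")
seg[10:20, 30:40] = 1

prop = regionprops(seg)[0]
bb, ndim = prop.bbox, seg.ndim   # bb == (10, 30, 20, 40) for this 2D example
offset = np.array(bb[:ndim])
if ndim == 2:
    roi = np.s_[bb[0]:bb[2], bb[1]:bb[3]]
else:
    roi = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
assert seg[roi].shape == (10, 10)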

synaptic_reconstruction/inference/mitochondria.py

Lines changed: 21 additions & 8 deletions
@@ -14,30 +14,43 @@ def _run_segmentation(
     block_shape=(128, 256, 256),
     halo=(48, 48, 48)
 ):
-
-    # get the segmentation via seeded watershed
     t0 = time.time()
-    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
+    boundary_threshold = 0.25
+    dist = parallel.distance_transform(
+        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
+    )
     if verbose:
-        print("Compute connected components in", time.time() - t0, "s")
+        print("Compute distance transform in", time.time() - t0, "s")

+    # Get the segmentation via seeded watershed.
     t0 = time.time()
-    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
+    seed_distance = 6
+    seeds = np.logical_and(foreground > 0.5, dist > seed_distance)
+    seeds = parallel.label(seeds, block_shape=block_shape, verbose=verbose)
     if verbose:
-        print("Compute distance transform in", time.time() - t0, "s")
+        print("Compute connected components in", time.time() - t0, "s")
+
+    # import napari
+    # v = napari.Viewer()
+    # v.add_image(boundaries)
+    # v.add_image(dist)
+    # v.add_labels(seeds)
+    # napari.run()

     t0 = time.time()
+    hmap = boundaries + ((dist.max() - dist) / dist.max())
     mask = (foreground + boundaries) > 0.5
+
     seg = np.zeros_like(seeds)
     seg = parallel.seeded_watershed(
-        dist, seeds, block_shape=block_shape,
+        hmap, seeds, block_shape=block_shape,
         out=seg, mask=mask, verbose=verbose, halo=halo,
     )
     if verbose:
         print("Compute watershed in", time.time() - t0, "s")

     seg = apply_size_filter(seg, min_size, verbose=verbose, block_shape=block_shape)
-    seg = _postprocess_seg_3d(seg)
+    seg = _postprocess_seg_3d(seg, area_threshold=5000)
     return seg


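The seeding strategy changes here: instead of labeling foreground minus boundaries, seeds are now connected components of confident foreground that lie at least seed_distance away from boundaries, and the watershed grows on a combined height map rather than on the distance transform. A minimal, non-parallel sketch of the same idea, with a hypothetical helper name and scipy/scikit-image standing in for the blockwise parallel helpers (an assumption, not the commit's code):

# Minimal sketch of the new seeding and height-map logic (assumption, not the
# repository's implementation; the real code uses blockwise parallel helpers).
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.measure import label
from skimage.segmentation import watershed

def simple_mito_segmentation(foreground, boundaries, boundary_threshold=0.25, seed_distance=6):
    # Distance to the closest boundary pixel; large values lie deep inside objects.
    dist = distance_transform_edt(boundaries < boundary_threshold)
    # Seed only in confident foreground that is far away from any boundary.
    seeds = label(np.logical_and(foreground > 0.5, dist > seed_distance))
    # Height map: boundary probabilities plus the inverted, normalized distance,
    # so the watershed grows from object centers towards the boundaries.
    hmap = boundaries + (dist.max() - dist) / dist.max()
    mask = (foreground + boundaries) > 0.5
    return watershed(hmap, markers=seeds, mask=mask)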
synaptic_reconstruction/inference/util.py

Lines changed: 20 additions & 18 deletions
@@ -255,21 +255,23 @@ def _get_file_paths(input_path, ext=".mrc"):


 def _load_input(img_path, extra_files, i):
-    # Load the input data data
-    with open_file(img_path, "r") as f:
-
-        # Try to automatically derive the key with the raw data.
-        keys = list(f.keys())
-        if len(keys) == 1:
-            key = keys[0]
-        elif "data" in keys:
-            key = "data"
-        elif "raw" in keys:
-            key = "raw"
-
-        input_volume = f[key][:]
-    assert input_volume.ndim == 3
+    # Load the input data.
+    if os.path.splitext(img_path)[-1] == ".tif":
+        input_volume = imageio.imread(img_path)

+    else:
+        with open_file(img_path, "r") as f:
+            # Try to automatically derive the key with the raw data.
+            keys = list(f.keys())
+            if len(keys) == 1:
+                key = keys[0]
+            elif "data" in keys:
+                key = "data"
+            elif "raw" in keys:
+                key = "raw"
+            input_volume = f[key][:]
+
+    assert input_volume.ndim in (2, 3)
     # For now we assume this is always tif.
     if extra_files is not None:
         extra_input = imageio.imread(extra_files[i])
@@ -470,7 +472,7 @@ def apply_size_filter(
     return segmentation


-def _postprocess_seg_3d(seg):
+def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
     # Structure lement for 2d dilation in 3d.
     structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
     structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
@@ -483,9 +485,9 @@ def _postprocess_seg_3d(seg):
         mask = seg[bb] == prop.label

         # Fill small holes and apply closing.
-        mask = remove_small_holes(mask, area_threshold=1000)
-        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
-        mask = np.logical_or(binary_closing(mask, iterations=8, structure=structure_3d), mask)
+        mask = remove_small_holes(mask, area_threshold=area_threshold)
+        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
+        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
         seg[bb][mask] = prop.label

     return seg
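_load_input now branches on the file extension: tifs are read directly with imageio, container formats keep the key auto-detection, and 2D inputs are accepted. A minimal standalone sketch of that loading logic, with a hypothetical load_volume helper and h5py standing in for the repository's open_file helper (an assumption):

# Minimal sketch of the new loading behavior (assumption; the real code uses
# the open_file helper and therefore supports more container formats).
import os
import h5py
import imageio.v3 as imageio

def load_volume(path):
    if os.path.splitext(path)[-1] == ".tif":
        return imageio.imread(path)
    with h5py.File(path, "r") as f:
        keys = list(f.keys())
        # Auto-detect the dataset key: single dataset, else "data", else "raw".
        key = keys[0] if len(keys) == 1 else ("data" if "data" in keys else "raw")
        return f[key][:]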

synaptic_reconstruction/napari.yaml

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@ categories: ["Image Processing", "Annotation"]
 contributions:
   commands:
     - id: synaptic_reconstruction.segment
-      python_name: synaptic_reconstruction.tools.synaptic_plugin.segmentation_widget:get_segmentation_widget
+      python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget
       title: Segment
     - id: synaptic_reconstruction.distance_measure
-      python_name: synaptic_reconstruction.tools.synaptic_plugin.distance_measure_widget:get_distance_measure_widget
+      python_name: synaptic_reconstruction.tools.distance_measure_widget:DistanceMeasureWidget
       title: Distance Measurement
     - id: synaptic_reconstruction.file_reader
       title: Read volumetric data
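The python_name entries now point at widget classes in the flattened tools package instead of factory functions. A minimal sketch (an assumption, not the repository's widget) of a class that napari can instantiate from such a manifest entry, since a no-argument class is simply a callable returning a QWidget:

# Hypothetical stand-in for the real SegmentationWidget contribution.
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton

class SegmentationWidget(QWidget):
    def __init__(self):
        super().__init__()
        layout = QVBoxLayout()
        layout.addWidget(QPushButton("Run Segmentation"))
        self.setLayout(layout)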

synaptic_reconstruction/tools/synaptic_plugin/base_widget.py renamed to synaptic_reconstruction/tools/base_widget.py

Lines changed: 13 additions & 10 deletions
@@ -1,9 +1,12 @@
 from pathlib import Path
+
 import napari
-from qtpy.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QSpinBox, QLineEdit, QGroupBox, QFormLayout, QFrame, QComboBox, QCheckBox
 import qtpy.QtWidgets as QtWidgets
+
+from qtpy.QtWidgets import (
+    QWidget, QVBoxLayout, QHBoxLayout, QLabel, QSpinBox, QComboBox, QCheckBox
+)
 from superqt import QCollapsible
-from magicgui.widgets import create_widget


 class BaseWidget(QWidget):
@@ -15,7 +18,7 @@ def __init__(self):
     def _create_layer_selector(self, selector_name, layer_type="Image"):
         """
         Create a layer selector for an image or labels and store it in a dictionary.
-
+
         Parameters:
         - selector_name (str): The name of the selector, used as a key in the dictionary.
         - layer_type (str): The type of layer to filter for ("Image" or "Labels").
@@ -34,24 +37,24 @@ def _create_layer_selector(self, selector_name, layer_type="Image"):
         selector_widget = QtWidgets.QWidget()
         image_selector = QtWidgets.QComboBox()
         layer_label = QtWidgets.QLabel(f"{selector_name} Layer:")
-
+
         # Populate initial options
         self._update_selector(selector=image_selector, layer_filter=layer_filter)
-
+
         # Update selector on layer events
         self.viewer.layers.events.inserted.connect(lambda event: self._update_selector(image_selector, layer_filter))
         self.viewer.layers.events.removed.connect(lambda event: self._update_selector(image_selector, layer_filter))

         # Store the selector in the dictionary
         self.layer_selectors[selector_name] = selector_widget
-
+
         # Set up layout
         layout = QVBoxLayout()
         layout.addWidget(layer_label)
         layout.addWidget(image_selector)
         selector_widget.setLayout(layout)
         return selector_widget
-
+
     def _update_selector(self, selector, layer_filter):
         """Update a single selector with the current image layers in the viewer."""
         selector.clear()
@@ -62,10 +65,10 @@ def _get_layer_selector_data(self, selector_name):
         """Return the data for the layer currently selected in a given selector."""
         if selector_name in self.layer_selectors:
             selector_widget = self.layer_selectors[selector_name]
-
+
             # Retrieve the QComboBox from the QWidget's layout
             image_selector = selector_widget.layout().itemAt(1).widget()
-
+
             if isinstance(image_selector, QComboBox):
                 selected_layer_name = image_selector.currentText()
                 if selected_layer_name in self.viewer.layers:
@@ -176,7 +179,7 @@ def _make_collapsible(self, widget, title):
         collapsible.addWidget(widget)
         parent_widget.layout().addWidget(collapsible)
         return parent_widget
-
+
     def _add_boolean_param(self, name, value, title=None, tooltip=None):
         checkbox = QCheckBox(name if title is None else title)
         checkbox.setChecked(value)
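Apart from the rename, the import cleanup, and trailing-whitespace fixes, the widget keeps its layer-selector pattern: a QComboBox that is refreshed whenever layers are inserted into or removed from the viewer. A minimal standalone sketch of that pattern (not the commit's code; the refresh helper and dock title are made up for illustration):

# Minimal sketch of keeping a QComboBox in sync with napari's image layers.
import napari
from qtpy.QtWidgets import QComboBox

viewer = napari.Viewer()
selector = QComboBox()

def refresh(event=None):
    # Repopulate the selector with the names of all image layers.
    selector.clear()
    selector.addItems([layer.name for layer in viewer.layers if isinstance(layer, napari.layers.Image)])

viewer.layers.events.inserted.connect(refresh)
viewer.layers.events.removed.connect(refresh)
refresh()

viewer.window.add_dock_widget(selector, name="Image selector")
napari.run()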

synaptic_reconstruction/tools/cli.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import argparse
+from functools import partial
+
+from .util import run_segmentation, get_model
+from ..inference.util import inference_helper, parse_tiling
+
+
+# TODO: handle kwargs
+def segmentation_cli():
+    parser = argparse.ArgumentParser(description="Run segmentation.")
+    parser.add_argument(
+        "--input_path", "-i", required=True,
+        help="The filepath to the mrc file or the directory containing the tomogram data."
+    )
+    parser.add_argument(
+        "--output_path", "-o", required=True,
+        help="The filepath to directory where the segmentations will be saved."
+    )
+    parser.add_argument(
+        "--model", "-m", required=True, help="The model type."
+    )
+    parser.add_argument(
+        "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
+        "Can also be a directory with tifs if the filestructure matches input_path."
+    )
+    parser.add_argument("--input_key", "-k", required=False)
+    parser.add_argument(
+        "--force", action="store_true",
+        help="Whether to over-write already present segmentation results."
+    )
+    parser.add_argument(
+        "--tile_shape", type=int, nargs=3,
+        help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
+    )
+    parser.add_argument(
+        "--halo", type=int, nargs=3,
+        help="The halo for prediction. Increase the halo to minimize boundary artifacts."
+    )
+    parser.add_argument(
+        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
+    )
+    args = parser.parse_args()
+
+    model = get_model(args.model)
+    tiling = parse_tiling(args.tile_shape, args.halo)
+
+    segmentation_function = partial(
+        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
+    )
+    inference_helper(
+        args.input_path, args.output_path, segmentation_function,
+        mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
+    )
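Together with the setup.py change above, this function is exposed as the synapse_net.run_segmentation console script. A minimal sketch of invoking it programmatically by faking the command line; the paths and the model name are hypothetical placeholders, and the valid model names depend on get_model:

# Minimal sketch; all paths and the model name below are placeholders.
import sys
from synaptic_reconstruction.tools.cli import segmentation_cli

sys.argv = [
    "synapse_net.run_segmentation",
    "-i", "/path/to/tomograms",       # directory with .mrc files (placeholder)
    "-o", "/path/to/segmentations",   # output directory (placeholder)
    "-m", "vesicles",                 # placeholder model type
]
segmentation_cli()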
