Skip to content

Commit 58f3e6b

Browse files
Add sato as better default edge filter
1 parent 238b577 commit 58f3e6b

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

synaptic_reconstruction/ground_truth/shape_refinement.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from typing import List, Dict
1+
import multiprocessing as mp
22
from concurrent import futures
3+
from functools import partial
4+
from typing import List, Dict, Optional
35

46
import numpy as np
57

68
from scipy.ndimage import binary_erosion, binary_dilation
7-
from skimage.filters import gaussian, sobel
9+
from skimage import img_as_ubyte
10+
from skimage.filters import gaussian, rank, sato, sobel
811
from skimage.measure import regionprops
12+
from skimage.morphology import disk
913
from skimage.segmentation import watershed
1014
from tqdm import tqdm
1115

@@ -14,14 +18,26 @@
1418
except ImportError:
1519
vigra = None
1620

17-
FILTERS = ("sobel", "laplace", "ggm", "structure-tensor")
21+
FILTERS = ("sobel", "laplace", "ggm", "structure-tensor", "sato")
22+
23+
24+
def _sato_filter(raw, sigma, max_window=16):
25+
if raw.ndim != 2:
26+
raise NotImplementedError("The sato filter is only implemented for 2D data.")
27+
hmap = sato(raw)
28+
hmap = gaussian(hmap, sigma=sigma)
29+
hmap -= hmap.min()
30+
hmap /= hmap.max()
31+
hmap = rank.autolevel(img_as_ubyte(hmap), disk(max_window)).astype("float") / 255
32+
return hmap
1833

1934

2035
def edge_filter(
2136
data: np.ndarray,
2237
sigma: float,
2338
method: str = "sobel",
24-
per_slice: bool = False,
39+
per_slice: bool = True,
40+
n_threads: Optional[int] = None,
2541
) -> np.ndarray:
2642
"""Find edges in the image data.
2743
@@ -33,7 +49,9 @@ def edge_filter(
3349
- "laplace": Edges are found with a laplacian of gaussian filter.
3450
- "ggm": Edges are found with a gaussian gradient magnitude filter.
3551
- "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor.
52+
- "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
3653
per_slice: Compute the filter per slice instead of for the whole volume.
54+
n_threads: Number of threads for parallel computation over the slices.
3755
Returns:
3856
Volume with edge strength.
3957
"""
@@ -42,10 +60,12 @@ def edge_filter(
4260
if method in FILTERS[1:] and vigra is None:
4361
raise ValueError(f"Filter {method} requires vigra.")
4462

45-
if per_slice:
46-
edge_map = np.zeros(data.shape, dtype="float32")
47-
for z in range(data.shape[0]):
48-
edge_map[z] = edge_filter(data[z], sigma=sigma, method=method)
63+
if per_slice and data.ndim == 2:
64+
n_threads = mp.cpu_count() if n_threads is None else n_threads
65+
filter_func = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
66+
with futures.ThreadPoolExecutor(n_threads) as tp:
67+
edge_map = tp.map(filter_func, data)
68+
edge_map = np.stack(edge_map)
4969
return edge_map
5070

5171
if method == "sobel":
@@ -60,6 +80,8 @@ def edge_filter(
6080
edge_map = vigra.filters.structureTensorEigenvalues(
6181
data.astype("float32"), innerScale=inner_scale, outerScale=outer_scale
6282
)[..., 1]
83+
elif method == "sato":
84+
edge_map = _sato_filter(data, sigma)
6385

6486
return edge_map
6587

0 commit comments

Comments
 (0)