1
- from typing import List , Dict
1
+ import multiprocessing as mp
2
2
from concurrent import futures
3
+ from functools import partial
4
+ from typing import List , Dict , Optional
3
5
4
6
import numpy as np
5
7
6
8
from scipy .ndimage import binary_erosion , binary_dilation
7
- from skimage .filters import gaussian , sobel
9
+ from skimage import img_as_ubyte
10
+ from skimage .filters import gaussian , rank , sato , sobel
8
11
from skimage .measure import regionprops
12
+ from skimage .morphology import disk
9
13
from skimage .segmentation import watershed
10
14
from tqdm import tqdm
11
15
14
18
except ImportError :
15
19
vigra = None
16
20
17
- FILTERS = ("sobel" , "laplace" , "ggm" , "structure-tensor" )
21
+ FILTERS = ("sobel" , "laplace" , "ggm" , "structure-tensor" , "sato" )
22
+
23
+
24
+ def _sato_filter (raw , sigma , max_window = 16 ):
25
+ if raw .ndim != 2 :
26
+ raise NotImplementedError ("The sato filter is only implemented for 2D data." )
27
+ hmap = sato (raw )
28
+ hmap = gaussian (hmap , sigma = sigma )
29
+ hmap -= hmap .min ()
30
+ hmap /= hmap .max ()
31
+ hmap = rank .autolevel (img_as_ubyte (hmap ), disk (max_window )).astype ("float" ) / 255
32
+ return hmap
18
33
19
34
20
35
def edge_filter (
21
36
data : np .ndarray ,
22
37
sigma : float ,
23
38
method : str = "sobel" ,
24
- per_slice : bool = False ,
39
+ per_slice : bool = True ,
40
+ n_threads : Optional [int ] = None ,
25
41
) -> np .ndarray :
26
42
"""Find edges in the image data.
27
43
@@ -33,7 +49,9 @@ def edge_filter(
33
49
- "laplace": Edges are found with a laplacian of gaussian filter.
34
50
- "ggm": Edges are found with a gaussian gradient magnitude filter.
35
51
- "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor.
52
+ - "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
36
53
per_slice: Compute the filter per slice instead of for the whole volume.
54
+ n_threads: Number of threads for parallel computation over the slices.
37
55
Returns:
38
56
Volume with edge strength.
39
57
"""
@@ -42,10 +60,12 @@ def edge_filter(
42
60
if method in FILTERS [1 :] and vigra is None :
43
61
raise ValueError (f"Filter { method } requires vigra." )
44
62
45
- if per_slice :
46
- edge_map = np .zeros (data .shape , dtype = "float32" )
47
- for z in range (data .shape [0 ]):
48
- edge_map [z ] = edge_filter (data [z ], sigma = sigma , method = method )
63
+ if per_slice and data .ndim == 2 :
64
+ n_threads = mp .cpu_count () if n_threads is None else n_threads
65
+ filter_func = partial (edge_filter , sigma = sigma , method = method , per_slice = False )
66
+ with futures .ThreadPoolExecutor (n_threads ) as tp :
67
+ edge_map = tp .map (filter_func , data )
68
+ edge_map = np .stack (edge_map )
49
69
return edge_map
50
70
51
71
if method == "sobel" :
@@ -60,6 +80,8 @@ def edge_filter(
60
80
edge_map = vigra .filters .structureTensorEigenvalues (
61
81
data .astype ("float32" ), innerScale = inner_scale , outerScale = outer_scale
62
82
)[..., 1 ]
83
+ elif method == "sato" :
84
+ edge_map = _sato_filter (data , sigma )
63
85
64
86
return edge_map
65
87
0 commit comments