Skip to content

Commit 2f82384

Browse files
committed
feat(inference): add parallel postproc utils
1 parent afde4de commit 2f82384

File tree

11 files changed

+272
-57
lines changed

11 files changed

+272
-57
lines changed

cellseg_models_pytorch/inference/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from ._base_inferer import BaseInferer
12
from .folder_dataset import FolderDataset
23
from .post_processor import PostProcessor
34
from .predictor import Predictor
45
from .resize_inferer import ResizeInferer
56
from .sliding_window_inferer import SlidingWindowInferer
67

78
__all__ = [
9+
"BaseInferer",
810
"Predictor",
911
"PostProcessor",
1012
"ResizeInferer",

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import OrderedDict
33
from itertools import chain
44
from pathlib import Path
5-
from typing import Dict, List, Tuple, Union
5+
from typing import Callable, Dict, List, Tuple, Union
66

77
import numpy as np
88
import torch
@@ -36,6 +36,8 @@ def __init__(
3636
save_dir: Union[Path, str] = None,
3737
checkpoint_path: Union[Path, str] = None,
3838
n_images: int = None,
39+
type_post_proc: Callable = None,
40+
sem_post_proc: Callable = None,
3941
**postproc_kwargs,
4042
) -> None:
4143
"""Inference for an image folder.
@@ -84,6 +86,12 @@ def __init__(
8486
Path to the model weight checkpoints.
8587
n_images : int, optional
8688
First n-number of images used from the `input_folder`.
89+
type_post_proc : Callable, optional
90+
A post-processing function for the type maps. If not None, overrides
91+
the default.
92+
sem_post_proc : Callable, optional
93+
A post-processing function for the semantc seg maps. If not None,
94+
overrides the default.
8795
**postproc_kwargs:
8896
Arbitrary keyword arguments for the post-processing.
8997
"""
@@ -142,6 +150,8 @@ def __init__(
142150
instance_postproc,
143151
inst_key=self.model.inst_key,
144152
aux_key=self.model.aux_key,
153+
type_post_proc=type_post_proc,
154+
sem_post_proc=sem_post_proc,
145155
**postproc_kwargs,
146156
)
147157

cellseg_models_pytorch/inference/post_processor.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
from typing import Dict, List
1+
from typing import Callable, Dict, List
22

33
import numpy as np
44
from pathos.multiprocessing import ThreadPool as Pool
5-
from skimage.filters import rank
6-
from skimage.morphology import closing, disk, opening
75
from skimage.util import img_as_ubyte
86
from tqdm import tqdm
97

108
from ..postproc import POSTPROC_LOOKUP
11-
from ..utils import binarize, fill_holes_semantic, remove_debris_semantic
9+
from ..utils import (
10+
fill_holes_semantic,
11+
majority_vote_parallel,
12+
majority_vote_sequential,
13+
med_filt_parallel,
14+
med_filt_sequential,
15+
remove_debris_semantic,
16+
)
1217

1318
__all__ = ["PostProcessor"]
1419

@@ -28,7 +33,13 @@
2833

2934
class PostProcessor:
3035
def __init__(
31-
self, instance_postproc: str, inst_key: str, aux_key: str, **kwargs
36+
self,
37+
instance_postproc: str,
38+
inst_key: str,
39+
aux_key: str,
40+
type_post_proc: Callable = None,
41+
sem_post_proc: Callable = None,
42+
**kwargs,
3243
) -> None:
3344
"""Multi-threaded post-processing.
3445
@@ -42,6 +53,12 @@ def __init__(
4253
aux_key : Tuple[str, ...]:
4354
The key/name of the model auxilliary output that is used for the
4455
instance segmentation post-processing pipeline.
56+
type_post_proc : Callable, optional
57+
A post-processing function for the type maps. If not None, overrides
58+
the default.
59+
sem_post_proc : Callable, optional
60+
A post-processing function for the semantc seg maps. If not None,
61+
overrides the default.
4562
**kwargs
4663
Arbitrary keyword arguments that can be used for any of the private
4764
post-processing functions of this class.
@@ -57,36 +74,32 @@ def __init__(
5774
self.inst_key = inst_key
5875
self.aux_key = aux_key
5976
self.kwargs = kwargs
77+
self.sem_post_proc = sem_post_proc
78+
self.type_post_proc = type_post_proc
6079

6180
def _get_sem_map(
6281
self,
6382
prob_map: np.ndarray,
64-
use_blur: bool = False,
65-
use_closing: bool = False,
66-
use_opening: bool = True,
67-
disk_size: int = 10,
83+
parallel: bool = False,
84+
kernel_width: int = 15,
6885
**kwargs,
6986
) -> np.ndarray:
7087
"""Run the semantic segmentation post-processing."""
71-
# Median filtering to get rid of noise. Adds a lot of overhead sop optional.
72-
if use_blur:
73-
sem = np.zeros_like(prob_map)
74-
for i in range(prob_map.shape[0]):
75-
sem[i, ...] = rank.median(
76-
img_as_ubyte(prob_map[i, ...]), footprint=disk(disk_size)
77-
)
78-
prob_map = sem
79-
80-
sem = np.argmax(prob_map, axis=0)
88+
sem_map = img_as_ubyte(prob_map)
8189

82-
if use_opening:
83-
sem = opening(sem, disk(disk_size))
84-
85-
if use_closing:
86-
sem = closing(sem, disk(disk_size))
90+
if self.sem_post_proc is not None:
91+
sem = self.sem_post_proc(sem_map)
92+
else:
93+
if parallel:
94+
sem = med_filt_parallel(
95+
sem_map, kernel_size=(kernel_width, kernel_width)
96+
)
97+
else:
98+
sem = med_filt_sequential(sem_map, kernel_width)
8799

88-
sem = remove_debris_semantic(sem)
89-
sem = fill_holes_semantic(sem)
100+
sem = np.argmax(sem, axis=0)
101+
sem = remove_debris_semantic(sem)
102+
sem = fill_holes_semantic(sem)
90103

91104
return sem
92105

@@ -105,31 +118,19 @@ def _get_type_map(
105118
self,
106119
prob_map: np.ndarray,
107120
inst_map: np.ndarray,
108-
use_mask: bool = False,
121+
parallel: bool = True,
109122
**kwargs,
110123
) -> np.ndarray:
111-
"""Run the type map post-processing. Majority voting for each instance.
112-
113-
Adapted from:
114-
https://github.com/vqdang/hover_net/blob/master/models/hovernet/post_proc.py
115-
"""
124+
"""Run the type map post-processing. Majority voting for each instance."""
116125
type_map = np.argmax(prob_map, axis=0)
117-
if use_mask:
118-
type_map = binarize(inst_map) * type_map
119-
120-
pred_id_list = np.unique(inst_map)[1:]
121-
for inst_id in pred_id_list:
122-
inst_type = type_map[inst_map == inst_id]
123-
type_list, type_pixels = np.unique(inst_type, return_counts=True)
124-
type_list = list(zip(type_list, type_pixels))
125-
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
126-
cell_type = type_list[0][0]
127-
128-
if cell_type == 0:
129-
if len(type_list) > 1:
130-
cell_type = type_list[1][0]
131126

132-
type_map[inst_map == inst_id] = cell_type
127+
if self.type_post_proc is not None:
128+
type_map = self.type_post_proc(type_map, inst_map, **kwargs)
129+
else:
130+
if parallel:
131+
type_map = majority_vote_parallel(type_map, inst_map)
132+
else:
133+
type_map = majority_vote_sequential(type_map, inst_map)
133134

134135
return type_map
135136

@@ -175,7 +176,6 @@ def run_parallel(
175176
progress_bar : bool, default=False
176177
If True, a tqdm progress bar is shown.
177178
178-
179179
Returns
180180
-------
181181
List[Dict[str, np.ndarray]]:

cellseg_models_pytorch/inference/predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def classify(
154154
patch: torch.Tensor,
155155
act: Union[str, None] = "softmax",
156156
apply_weights: bool = False,
157-
) -> np.ndarray:
157+
) -> torch.Tensor:
158158
"""Take in logits and output probabilities.
159159
160160
Additionally apply a weight matrix to help with boundary artefacts.
@@ -171,7 +171,7 @@ def classify(
171171
172172
Returns
173173
-------
174-
np.ndarray:
174+
torch.Tensor:
175175
The model prediction. Same shape as input `patch`.
176176
"""
177177
allowed = ("sigmoid", "softmax", "tanh", None)

cellseg_models_pytorch/inference/resize_inferer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Dict, Tuple, Union
2+
from typing import Callable, Dict, Tuple, Union
33

44
import torch
55
import torch.nn as nn
@@ -26,6 +26,8 @@ def __init__(
2626
save_dir: Union[Path, str] = None,
2727
checkpoint_path: Union[Path, str] = None,
2828
n_images: int = None,
29+
type_post_proc: Callable = None,
30+
sem_post_proc: Callable = None,
2931
**postproc_kwargs,
3032
) -> None:
3133
"""Resize inference for a folder of images.
@@ -81,6 +83,12 @@ def __init__(
8183
Path to the model weight checkpoints.
8284
n_images : int, optional
8385
First n-number of images used from the `input_folder`.
86+
type_post_proc : Callable, optional
87+
A post-processing function for the type maps. If not None, overrides
88+
the default.
89+
sem_post_proc : Callable, optional
90+
A post-processing function for the semantc seg maps. If not None,
91+
overrides the default.
8492
**postproc_kwargs:
8593
Arbitrary keyword arguments for the post-processing.
8694
"""
@@ -100,6 +108,8 @@ def __init__(
100108
save_dir=save_dir,
101109
checkpoint_path=checkpoint_path,
102110
n_images=n_images,
111+
type_post_proc=type_post_proc,
112+
sem_post_proc=sem_post_proc,
103113
**postproc_kwargs,
104114
)
105115

cellseg_models_pytorch/inference/sliding_window_inferer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Dict, Tuple, Union
2+
from typing import Callable, Dict, Tuple, Union
33

44
import torch
55
import torch.nn as nn
@@ -29,6 +29,8 @@ def __init__(
2929
save_dir: Union[Path, str] = None,
3030
checkpoint_path: Union[Path, str] = None,
3131
n_images: int = None,
32+
type_post_proc: Callable = None,
33+
sem_post_proc: Callable = None,
3234
**postproc_kwargs,
3335
) -> None:
3436
"""Sliding window inference for a folder of images.
@@ -83,6 +85,12 @@ def __init__(
8385
Path to the model weight checkpoints.
8486
n_images : int, optional
8587
First n-number of images used from the `ìnput_folder`.
88+
type_post_proc : Callable, optional
89+
A post-processing function for the type maps. If not None, overrides
90+
the default.
91+
sem_post_proc : Callable, optional
92+
A post-processing function for the semantc seg maps. If not None,
93+
overrides the default.
8694
**postproc_kwargs:
8795
Arbitrary keyword arguments for the post-processing.
8896
"""
@@ -102,6 +110,8 @@ def __init__(
102110
save_dir=save_dir,
103111
checkpoint_path=checkpoint_path,
104112
n_images=n_images,
113+
type_post_proc=type_post_proc,
114+
sem_post_proc=sem_post_proc,
105115
**postproc_kwargs,
106116
)
107117

cellseg_models_pytorch/inference/tests/test_inference.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def test_slidingwin_inference(img_dir, batch_size):
2020
batch_size=batch_size,
2121
save_intermediate=False,
2222
device="cpu",
23-
use_blur=True,
24-
use_closing=True,
23+
parallel=False,
2524
)
2625

2726
inferer.infer()
@@ -45,8 +44,7 @@ def test_resize_inference(img_dir, batch_size):
4544
batch_size=batch_size,
4645
save_intermediate=False,
4746
device="cpu",
48-
use_blur=True,
49-
use_closing=True,
47+
parallel=False,
5048
)
5149

5250
inferer.infer()

cellseg_models_pytorch/utils/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
get_inst_types,
1919
get_type_instances,
2020
label_semantic,
21+
majority_vote_parallel,
22+
majority_vote_sequential,
23+
med_filt_parallel,
24+
med_filt_sequential,
2125
one_hot,
2226
remap_label,
2327
remove_1px_boundary,
@@ -124,4 +128,8 @@
124128
"NORM_LOOKUP",
125129
"draw_stuff_contours",
126130
"draw_thing_contours",
131+
"majority_vote_sequential",
132+
"majority_vote_parallel",
133+
"med_filt_parallel",
134+
"med_filt_sequential",
127135
]

0 commit comments

Comments
 (0)