Skip to content

Commit 48c0a0a

Browse files
committed
feat(inference): multi-gpu inference support
1 parent 00c0440 commit 48c0a0a

File tree

5 files changed

+75
-32
lines changed

5 files changed

+75
-32
lines changed

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import torch
99
import torch.nn as nn
10+
import yaml
1011
from pathos.multiprocessing import ThreadPool as Pool
1112
from torch.utils.data import DataLoader
1213
from tqdm import tqdm
@@ -31,6 +32,7 @@ def __init__(
3132
batch_size: int = 8,
3233
normalization: str = None,
3334
device: str = "cuda",
35+
n_devices: int = 1,
3436
save_masks: bool = True,
3537
save_intermediate: bool = False,
3638
save_dir: Union[Path, str] = None,
@@ -72,6 +74,9 @@ def __init__(
7274
One of: "dataset", "minmax", "norm", "percentile", None.
7375
device : str, default="cuda"
7476
The device of the input and model. One of: "cuda", "cpu"
77+
n_devices : int, default=1
78+
Number of GPUs used for inference. Only takes effect when more than
79+
one CUDA device is available; the model is then replicated across them.
7580
save_masks : bool, default=True
7681
If True, the resulting segmentation masks will be saved into `out_masks`
7782
variable.
@@ -95,6 +100,16 @@ def __init__(
95100
**postproc_kwargs:
96101
Arbitrary keyword arguments for the post-processing.
97102
"""
103+
# basic inits
104+
self.model = model
105+
self.out_heads = self._get_out_info() # the names and num channels of out heads
106+
self.batch_size = batch_size
107+
self.patch_size = patch_size
108+
self.padding = padding
109+
self.out_activations = out_activations
110+
self.out_boundary_weights = out_boundary_weights
111+
self.head_kwargs = self._check_and_set_head_args()
112+
98113
self.save_dir = Path(save_dir) if save_dir is not None else None
99114
self.save_masks = save_masks
100115
self.save_intermediate = save_intermediate
@@ -106,17 +121,17 @@ def __init__(
106121
folder_ds, batch_size=batch_size, shuffle=False, pin_memory=True
107122
)
108123

109-
# model and device
110-
self.model = model
111-
if device == "cpu":
112-
self.model.cpu()
113-
self.device = torch.device("cpu")
114-
if torch.cuda.is_available() and device == "cuda":
115-
self.model.cuda()
116-
self.device = torch.device("cuda")
117-
118-
self.model.eval()
124+
# Set post processor
125+
self.postprocessor = PostProcessor(
126+
instance_postproc,
127+
inst_key=self.model.inst_key,
128+
aux_key=self.model.aux_key,
129+
type_post_proc=type_post_proc,
130+
sem_post_proc=sem_post_proc,
131+
**postproc_kwargs,
132+
)
119133

134+
# load weights and set devices
120135
if checkpoint_path is not None:
121136
ckpt = torch.load(
122137
checkpoint_path, map_location=lambda storage, loc: storage
@@ -130,30 +145,41 @@ def __init__(
130145
except BaseException as e:
131146
print(e)
132147

133-
#
148+
assert device in ("cuda", "cpu")
149+
if device == "cpu":
150+
self.device = torch.device("cpu")
151+
if torch.cuda.is_available() and device == "cuda":
152+
self.device = torch.device("cuda")
153+
154+
if torch.cuda.device_count() > 1 and n_devices > 1:
155+
self.model = nn.DataParallel(self.model, device_ids=range(n_devices))
156+
157+
self.model.to(self.device)
158+
self.model.eval()
159+
160+
# Helper class to perform forward + extra processing
134161
self.predictor = Predictor(
135162
model=self.model,
136163
patch_size=patch_size,
137164
normalization=normalization,
138165
device=self.device,
139166
)
140-
self.out_heads = self._get_out_info() # the names and num channels of out heads
141-
self.batch_size = batch_size
142-
self.patch_size = patch_size
143-
self.padding = padding
144-
self.out_activations = out_activations
145-
self.out_boundary_weights = out_boundary_weights
146-
self.head_kwargs = self._check_and_set_head_args()
147167

148-
#
149-
self.postprocessor = PostProcessor(
150-
instance_postproc,
151-
inst_key=self.model.inst_key,
152-
aux_key=self.model.aux_key,
153-
type_post_proc=type_post_proc,
154-
sem_post_proc=sem_post_proc,
155-
**postproc_kwargs,
156-
)
168+
@classmethod
169+
def from_yaml(cls, model: nn.Module, yaml_path: str):
170+
"""Initialize the inferer from a yaml-file.
171+
172+
Parameters
173+
----------
174+
model : nn.Module
175+
Initialized segmentation model.
176+
yaml_path : str
177+
Path to the yaml file containing the rest of the params.
178+
"""
179+
with open(yaml_path, "r") as stream:
180+
kwargs = yaml.full_load(stream)
181+
182+
return cls(model, **kwargs)
157183

158184
@abstractmethod
159185
def _infer_batch(self):

cellseg_models_pytorch/inference/predictor.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,15 @@ def classify(
184184
if apply_weights:
185185
# work out the tensor shape first for the weight mat
186186
B, C = patch.shape[:2]
187-
W = torch.repeat_interleave(
188-
self.weight_mat,
189-
dim=1,
190-
repeats=C,
191-
).repeat_interleave(repeats=B, dim=0)
187+
W = (
188+
torch.repeat_interleave(
189+
self.weight_mat,
190+
dim=1,
191+
repeats=C,
192+
)
193+
.repeat_interleave(repeats=B, dim=0)
194+
.to(patch.device)
195+
)
192196
patch *= W
193197

194198
# apply classification activation

cellseg_models_pytorch/inference/resize_inferer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def __init__(
2121
batch_size: int = 8,
2222
normalization: str = None,
2323
device: str = "cuda",
24+
n_devices: int = 1,
2425
save_masks: bool = True,
2526
save_intermediate: bool = False,
2627
save_dir: Union[Path, str] = None,
@@ -69,6 +70,9 @@ def __init__(
6970
"minmax", "norm", "percentile", None.
7071
device : str, default="cuda"
7172
The device of the input and model. One of: "cuda", "cpu"
73+
n_devices : int, default=1
74+
Number of GPUs used for inference. Only takes effect when more than
75+
one CUDA device is available; the model is then replicated across them.
7276
save_masks : bool, default=True
7377
If True, the resulting segmentation masks will be saved into `out_masks`
7478
variable.
@@ -103,6 +107,7 @@ def __init__(
103107
normalization=normalization,
104108
instance_postproc=instance_postproc,
105109
device=device,
110+
n_devices=n_devices,
106111
save_masks=save_masks,
107112
save_intermediate=save_intermediate,
108113
save_dir=save_dir,

cellseg_models_pytorch/inference/sliding_window_inferer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
batch_size: int = 8,
2525
normalization: str = None,
2626
device: str = "cuda",
27+
n_devices: int = 1,
2728
save_masks: bool = True,
2829
save_intermediate: bool = False,
2930
save_dir: Union[Path, str] = None,
@@ -71,6 +72,9 @@ def __init__(
7172
"minmax", "norm", "percentile", None.
7273
device : str, default="cuda"
7374
The device of the input and model. One of: "cuda", "cpu"
75+
n_devices : int, default=1
76+
Number of GPUs used for inference. Only takes effect when more than
77+
one CUDA device is available; the model is then replicated across them.
7478
save_masks : bool, default=True
7579
If True, the resulting segmentation masks will be saved into `out_masks`
7680
variable.
@@ -110,6 +114,7 @@ def __init__(
110114
save_dir=save_dir,
111115
checkpoint_path=checkpoint_path,
112116
n_images=n_images,
117+
n_devices=n_devices,
113118
type_post_proc=type_post_proc,
114119
sem_post_proc=sem_post_proc,
115120
**postproc_kwargs,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Features
2+
3+
- Add multi-gpu inference via DataParallel

0 commit comments

Comments
 (0)