@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import yaml
 from pathos.multiprocessing import ThreadPool as Pool
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -31,6 +32,7 @@ def __init__(
         batch_size: int = 8,
         normalization: str = None,
         device: str = "cuda",
+        n_devices: int = 1,
         save_masks: bool = True,
         save_intermediate: bool = False,
         save_dir: Union[Path, str] = None,
@@ -72,6 +74,9 @@ def __init__(
             One of: "dataset", "minmax", "norm", "percentile", None.
         device : str, default="cuda"
             The device of the input and model. One of: "cuda", "cpu".
+        n_devices : int, default=1
+            Number of devices (CPUs/GPUs) used for inference.
+            The model will be copied to each of these devices.
         save_masks : bool, default=True
             If True, the resulting segmentation masks will be saved into the
             `out_masks` variable.
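For context on the new `n_devices` option: a later hunk in this diff wraps the model in `nn.DataParallel`, which copies the module onto each listed device and splits every input batch across the replicas. A minimal, self-contained sketch of that behavior (the toy model and shapes are illustrative, not from this repo):

import torch
import torch.nn as nn

n_devices = 2
model = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for a segmentation model

# Same guard as in the diff: only wrap when several GPUs are actually present.
if torch.cuda.device_count() > 1 and n_devices > 1:
    model = nn.DataParallel(model, device_ids=range(n_devices))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

x = torch.randn(8, 3, 64, 64, device=device)
with torch.no_grad():
    y = model(x)  # with 2 GPUs, the batch of 8 is split 4/4 across the replicas
print(y.shape)    # torch.Size([8, 8, 64, 64])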
@@ -95,6 +100,16 @@ def __init__(
         **postproc_kwargs:
             Arbitrary keyword arguments for the post-processing.
         """
+        # basic inits
+        self.model = model
+        self.out_heads = self._get_out_info()  # the names and num channels of out heads
+        self.batch_size = batch_size
+        self.patch_size = patch_size
+        self.padding = padding
+        self.out_activations = out_activations
+        self.out_boundary_weights = out_boundary_weights
+        self.head_kwargs = self._check_and_set_head_args()
+
         self.save_dir = Path(save_dir) if save_dir is not None else None
         self.save_masks = save_masks
         self.save_intermediate = save_intermediate
@@ -106,17 +121,17 @@ def __init__(
106
121
folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True
107
122
)
108
123
109
- # model and device
110
- self .model = model
111
- if device == "cpu" :
112
- self .model .cpu ()
113
- self .device = torch .device ("cpu" )
114
- if torch .cuda .is_available () and device == "cuda" :
115
- self .model .cuda ()
116
- self .device = torch .device ("cuda" )
117
-
118
- self .model .eval ()
124
+ # Set post processor
125
+ self .postprocessor = PostProcessor (
126
+ instance_postproc ,
127
+ inst_key = self .model .inst_key ,
128
+ aux_key = self .model .aux_key ,
129
+ type_post_proc = type_post_proc ,
130
+ sem_post_proc = sem_post_proc ,
131
+ ** postproc_kwargs ,
132
+ )
119
133
134
+ # load weights and set devices
120
135
if checkpoint_path is not None :
121
136
ckpt = torch .load (
122
137
checkpoint_path , map_location = lambda storage , loc : storage
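One idiom in the checkpoint-loading lines above worth spelling out: `map_location=lambda storage, loc: storage` returns each storage unchanged, so all checkpoint tensors land on the CPU at load time regardless of the device they were saved from. A small demonstration (the file name and tensor are stand-ins):

import torch

torch.save({"w": torch.randn(3)}, "weights.pth")  # stand-in checkpoint

# The lambda maps every storage to itself, i.e. loads onto CPU.
ckpt = torch.load("weights.pth", map_location=lambda storage, loc: storage)

# Equivalent and usually clearer:
ckpt = torch.load("weights.pth", map_location="cpu")
print(ckpt["w"].device)  # cpu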
@@ -130,30 +145,41 @@ def __init__(
130
145
except BaseException as e :
131
146
print (e )
132
147
133
- #
148
+ assert device in ("cuda" , "cpu" )
149
+ if device == "cpu" :
150
+ self .device = torch .device ("cpu" )
151
+ if torch .cuda .is_available () and device == "cuda" :
152
+ self .device = torch .device ("cuda" )
153
+
154
+ if torch .cuda .device_count () > 1 and n_devices > 1 :
155
+ self .model = nn .DataParallel (self .model , device_ids = range (n_devices ))
156
+
157
+ self .model .to (self .device )
158
+ self .model .eval ()
159
+
160
+ # Helper class to perform forward + extra processing
134
161
self .predictor = Predictor (
135
162
model = self .model ,
136
163
patch_size = patch_size ,
137
164
normalization = normalization ,
138
165
device = self .device ,
139
166
)
140
- self .out_heads = self ._get_out_info () # the names and num channels of out heads
141
- self .batch_size = batch_size
142
- self .patch_size = patch_size
143
- self .padding = padding
144
- self .out_activations = out_activations
145
- self .out_boundary_weights = out_boundary_weights
146
- self .head_kwargs = self ._check_and_set_head_args ()
147
167
148
- #
149
- self .postprocessor = PostProcessor (
150
- instance_postproc ,
151
- inst_key = self .model .inst_key ,
152
- aux_key = self .model .aux_key ,
153
- type_post_proc = type_post_proc ,
154
- sem_post_proc = sem_post_proc ,
155
- ** postproc_kwargs ,
156
- )
168
+ @classmethod
169
+ def from_yaml (cls , model : nn .Module , yaml_path : str ):
170
+ """Initialize the inferer from a yaml-file.
171
+
172
+ Parameters
173
+ ----------
174
+ model : nn.Module
175
+ Initialized segmentation model.
176
+ yaml_path : str
177
+ Path to the yaml file containing rest of the params
178
+ """
179
+ with open (yaml_path , "r" ) as stream :
180
+ kwargs = yaml .full_load (stream )
181
+
182
+ return cls (model , ** kwargs )
157
183
158
184
@abstractmethod
159
185
def _infer_batch (self ):
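Finally, a sketch of the mechanism behind the new `from_yaml` classmethod: `yaml.full_load` parses the file into a plain dict that is then splatted into `__init__` as keyword arguments. The config keys below are illustrative assumptions, not taken from this diff:

import yaml

config_text = """\
batch_size: 8
patch_size: [256, 256]
device: cuda
n_devices: 1
save_masks: true
"""

# full_load parses scalars, lists, and booleans into native Python types.
kwargs = yaml.full_load(config_text)
print(kwargs["patch_size"], kwargs["n_devices"])  # [256, 256] 1

# An inferer subclass would then be constructed as: cls(model, **kwargs),
# so every key in the yaml file must match an __init__ parameter name.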