Skip to content

Commit 95e8788

Browse files
committed
fix(nnUNet): Correct background mask in BraTS preprocessor
1 parent 729963d commit 95e8788

File tree

1 file changed

+55
-14
lines changed

1 file changed

+55
-14
lines changed

PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,15 @@
2525
from skimage.transform import resize
2626
from utils.utils import get_task_code, make_empty_dir
2727

28-
from data_preprocessing.configs import ct_max, ct_mean, ct_min, ct_std, patch_size, spacings, task
28+
from data_preprocessing.configs import (
29+
ct_max,
30+
ct_mean,
31+
ct_min,
32+
ct_std,
33+
patch_size,
34+
spacings,
35+
task,
36+
)
2937

3038

3139
class Preprocessor:
@@ -45,9 +53,15 @@ def __init__(self, args):
4553
self.ct_min, self.ct_max, self.ct_mean, self.ct_std = (0,) * 4
4654
if not self.training:
4755
self.results = os.path.join(self.results, self.args.exec_mode)
48-
self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
49-
nonzero = True if self.modality != "CT" else False # normalize only non-zero region for MRI
50-
self.normalize_intensity = transforms.NormalizeIntensity(nonzero=nonzero, channel_wise=True)
56+
self.crop_foreg = transforms.CropForegroundd(
57+
keys=["image", "label"], source_key="image"
58+
)
59+
nonzero = (
60+
True if self.modality != "CT" else False
61+
) # normalize only non-zero region for MRI
62+
self.normalize_intensity = transforms.NormalizeIntensity(
63+
nonzero=nonzero, channel_wise=True
64+
)
5165
if self.args.exec_mode == "val":
5266
dataset_json = json.load(open(metadata_path, "r"))
5367
dataset_json["val"] = dataset_json["training"]
@@ -76,7 +90,9 @@ def run(self):
7690
_mean = round(self.ct_mean, 2)
7791
_std = round(self.ct_std, 2)
7892
if self.verbose:
79-
print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
93+
print(
94+
f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}"
95+
)
8096

8197
self.run_parallel(self.preprocess_pair, self.args.exec_mode)
8298

@@ -114,7 +130,7 @@ def preprocess_pair(self, pair):
114130
if self.args.ohe:
115131
mask = np.ones(image.shape[1:], dtype=np.float32)
116132
for i in range(image.shape[0]):
117-
zeros = np.where(image[i] <= 0)
133+
zeros = np.where(image[i] == 0)
118134
mask[zeros] *= 0.0
119135
image = self.normalize_intensity(image).astype(np.float32)
120136
mask = np.expand_dims(mask, 0)
@@ -131,15 +147,28 @@ def standardize(self, image, label):
131147
pad_shape = self.calculate_pad_shape(image)
132148
image_shape = image.shape[1:]
133149
if pad_shape != image_shape:
134-
paddings = [(pad_sh - image_sh) / 2 for (pad_sh, image_sh) in zip(pad_shape, image_shape)]
150+
paddings = [
151+
(pad_sh - image_sh) / 2
152+
for (pad_sh, image_sh) in zip(pad_shape, image_shape)
153+
]
135154
image = self.pad(image, paddings)
136155
label = self.pad(label, paddings)
137156
if self.args.dim == 2: # Center cropping 2D images.
138157
_, _, height, weight = image.shape
139158
start_h = (height - self.patch_size[0]) // 2
140159
start_w = (weight - self.patch_size[1]) // 2
141-
image = image[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
142-
label = label[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
160+
image = image[
161+
:,
162+
:,
163+
start_h : start_h + self.patch_size[0],
164+
start_w : start_w + self.patch_size[1],
165+
]
166+
label = label[
167+
:,
168+
:,
169+
start_h : start_h + self.patch_size[0],
170+
start_w : start_w + self.patch_size[1],
171+
]
143172
return image, label
144173

145174
def normalize(self, image):
@@ -148,7 +177,9 @@ def normalize(self, image):
148177
return self.normalize_intensity(image)
149178

150179
def save(self, image, label, fname, image_metadata):
151-
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
180+
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(
181+
np.std(image, (1, 2, 3)), 2
182+
)
152183
if self.verbose:
153184
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
154185
self.save_npy(image, fname, "_x.npy")
@@ -191,7 +222,9 @@ def calculate_pad_shape(self, image):
191222
image_shape = image.shape[1:]
192223
if len(min_shape) == 2: # In 2D case we don't want to pad depth axis.
193224
min_shape.insert(0, image_shape[0])
194-
pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)]
225+
pad_shape = [
226+
max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)
227+
]
195228
return pad_shape
196229

197230
def get_intensities(self, pair):
@@ -233,10 +266,16 @@ def calculate_new_shape(self, spacing, shape):
233266
return new_shape
234267

235268
def save_npy(self, image, fname, suffix):
236-
np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), image, allow_pickle=False)
269+
np.save(
270+
os.path.join(self.results, fname.replace(".nii.gz", suffix)),
271+
image,
272+
allow_pickle=False,
273+
)
237274

238275
def run_parallel(self, func, exec_mode):
239-
return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])
276+
return Parallel(n_jobs=self.args.n_jobs)(
277+
delayed(func)(pair) for pair in self.metadata[exec_mode]
278+
)
240279

241280
def load_nifty(self, fname):
242281
return nibabel.load(os.path.join(self.data_path, fname))
@@ -266,7 +305,9 @@ def standardize_layout(data):
266305

267306
@staticmethod
268307
def resize_fn(image, shape, order, mode):
269-
return resize(image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False)
308+
return resize(
309+
image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False
310+
)
270311

271312
def resample_anisotrophic_image(self, image, shape):
272313
resized_channels = []

0 commit comments

Comments
 (0)