Skip to content

fix(nnUNet): Correct background mask in BraTS preprocessor #1457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,15 @@
from skimage.transform import resize
from utils.utils import get_task_code, make_empty_dir

from data_preprocessing.configs import ct_max, ct_mean, ct_min, ct_std, patch_size, spacings, task
from data_preprocessing.configs import (
ct_max,
ct_mean,
ct_min,
ct_std,
patch_size,
spacings,
task,
)


class Preprocessor:
Expand All @@ -45,9 +53,15 @@ def __init__(self, args):
self.ct_min, self.ct_max, self.ct_mean, self.ct_std = (0,) * 4
if not self.training:
self.results = os.path.join(self.results, self.args.exec_mode)
self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
nonzero = True if self.modality != "CT" else False # normalize only non-zero region for MRI
self.normalize_intensity = transforms.NormalizeIntensity(nonzero=nonzero, channel_wise=True)
self.crop_foreg = transforms.CropForegroundd(
keys=["image", "label"], source_key="image"
)
nonzero = (
True if self.modality != "CT" else False
) # normalize only non-zero region for MRI
self.normalize_intensity = transforms.NormalizeIntensity(
nonzero=nonzero, channel_wise=True
)
if self.args.exec_mode == "val":
dataset_json = json.load(open(metadata_path, "r"))
dataset_json["val"] = dataset_json["training"]
Expand Down Expand Up @@ -76,7 +90,9 @@ def run(self):
_mean = round(self.ct_mean, 2)
_std = round(self.ct_std, 2)
if self.verbose:
print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
print(
f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}"
)

self.run_parallel(self.preprocess_pair, self.args.exec_mode)

Expand Down Expand Up @@ -114,7 +130,7 @@ def preprocess_pair(self, pair):
if self.args.ohe:
mask = np.ones(image.shape[1:], dtype=np.float32)
for i in range(image.shape[0]):
zeros = np.where(image[i] <= 0)
zeros = np.where(image[i] == 0)
mask[zeros] *= 0.0
image = self.normalize_intensity(image).astype(np.float32)
mask = np.expand_dims(mask, 0)
Expand All @@ -131,15 +147,28 @@ def standardize(self, image, label):
pad_shape = self.calculate_pad_shape(image)
image_shape = image.shape[1:]
if pad_shape != image_shape:
paddings = [(pad_sh - image_sh) / 2 for (pad_sh, image_sh) in zip(pad_shape, image_shape)]
paddings = [
(pad_sh - image_sh) / 2
for (pad_sh, image_sh) in zip(pad_shape, image_shape)
]
image = self.pad(image, paddings)
label = self.pad(label, paddings)
if self.args.dim == 2: # Center cropping 2D images.
_, _, height, weight = image.shape
start_h = (height - self.patch_size[0]) // 2
start_w = (weight - self.patch_size[1]) // 2
image = image[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
label = label[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
image = image[
:,
:,
start_h : start_h + self.patch_size[0],
start_w : start_w + self.patch_size[1],
]
label = label[
:,
:,
start_h : start_h + self.patch_size[0],
start_w : start_w + self.patch_size[1],
]
return image, label

def normalize(self, image):
Expand All @@ -148,7 +177,9 @@ def normalize(self, image):
return self.normalize_intensity(image)

def save(self, image, label, fname, image_metadata):
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(
np.std(image, (1, 2, 3)), 2
)
if self.verbose:
print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
self.save_npy(image, fname, "_x.npy")
Expand Down Expand Up @@ -191,7 +222,9 @@ def calculate_pad_shape(self, image):
image_shape = image.shape[1:]
if len(min_shape) == 2: # In 2D case we don't want to pad depth axis.
min_shape.insert(0, image_shape[0])
pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)]
pad_shape = [
max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)
]
return pad_shape

def get_intensities(self, pair):
Expand Down Expand Up @@ -233,10 +266,16 @@ def calculate_new_shape(self, spacing, shape):
return new_shape

def save_npy(self, image, fname, suffix):
np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), image, allow_pickle=False)
np.save(
os.path.join(self.results, fname.replace(".nii.gz", suffix)),
image,
allow_pickle=False,
)

def run_parallel(self, func, exec_mode):
return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])
return Parallel(n_jobs=self.args.n_jobs)(
delayed(func)(pair) for pair in self.metadata[exec_mode]
)

def load_nifty(self, fname):
return nibabel.load(os.path.join(self.data_path, fname))
Expand Down Expand Up @@ -266,7 +305,9 @@ def standardize_layout(data):

@staticmethod
def resize_fn(image, shape, order, mode):
return resize(image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False)
return resize(
image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False
)

def resample_anisotrophic_image(self, image, shape):
resized_channels = []
Expand Down