From 95e8788d622a37ebaea12406b51c15d8443e7e64 Mon Sep 17 00:00:00 2001 From: Vensenmu Date: Tue, 10 Jun 2025 16:53:05 +0800 Subject: [PATCH] fix(nnUNet): Correct background mask in BraTS preprocessor --- .../nnUNet/data_preprocessing/preprocessor.py | 69 +++++++++++++++---- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py b/PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py index 898135898..e03096b23 100644 --- a/PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py +++ b/PyTorch/Segmentation/nnUNet/data_preprocessing/preprocessor.py @@ -25,7 +25,15 @@ from skimage.transform import resize from utils.utils import get_task_code, make_empty_dir -from data_preprocessing.configs import ct_max, ct_mean, ct_min, ct_std, patch_size, spacings, task +from data_preprocessing.configs import ( + ct_max, + ct_mean, + ct_min, + ct_std, + patch_size, + spacings, + task, +) class Preprocessor: @@ -45,9 +53,15 @@ def __init__(self, args): self.ct_min, self.ct_max, self.ct_mean, self.ct_std = (0,) * 4 if not self.training: self.results = os.path.join(self.results, self.args.exec_mode) - self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image") - nonzero = True if self.modality != "CT" else False # normalize only non-zero region for MRI - self.normalize_intensity = transforms.NormalizeIntensity(nonzero=nonzero, channel_wise=True) + self.crop_foreg = transforms.CropForegroundd( + keys=["image", "label"], source_key="image" + ) + nonzero = ( + True if self.modality != "CT" else False + ) # normalize only non-zero region for MRI + self.normalize_intensity = transforms.NormalizeIntensity( + nonzero=nonzero, channel_wise=True + ) if self.args.exec_mode == "val": dataset_json = json.load(open(metadata_path, "r")) dataset_json["val"] = dataset_json["training"] @@ -76,7 +90,9 @@ def run(self): _mean = round(self.ct_mean, 2) _std = round(self.ct_std, 2) if self.verbose: - print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}") + print( + f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}" + ) self.run_parallel(self.preprocess_pair, self.args.exec_mode) @@ -114,7 +130,7 @@ def preprocess_pair(self, pair): if self.args.ohe: mask = np.ones(image.shape[1:], dtype=np.float32) for i in range(image.shape[0]): - zeros = np.where(image[i] <= 0) + zeros = np.where(image[i] == 0) mask[zeros] *= 0.0 image = self.normalize_intensity(image).astype(np.float32) mask = np.expand_dims(mask, 0) @@ -131,15 +147,28 @@ def standardize(self, image, label): pad_shape = self.calculate_pad_shape(image) image_shape = image.shape[1:] if pad_shape != image_shape: - paddings = [(pad_sh - image_sh) / 2 for (pad_sh, image_sh) in zip(pad_shape, image_shape)] + paddings = [ + (pad_sh - image_sh) / 2 + for (pad_sh, image_sh) in zip(pad_shape, image_shape) + ] image = self.pad(image, paddings) label = self.pad(label, paddings) if self.args.dim == 2: # Center cropping 2D images. _, _, height, weight = image.shape start_h = (height - self.patch_size[0]) // 2 start_w = (weight - self.patch_size[1]) // 2 - image = image[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]] - label = label[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]] + image = image[ + :, + :, + start_h : start_h + self.patch_size[0], + start_w : start_w + self.patch_size[1], + ] + label = label[ + :, + :, + start_h : start_h + self.patch_size[0], + start_w : start_w + self.patch_size[1], + ] return image, label def normalize(self, image): @@ -148,7 +177,9 @@ def normalize(self, image): return self.normalize_intensity(image) def save(self, image, label, fname, image_metadata): - mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2) + mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round( + np.std(image, (1, 2, 3)), 2 + ) if self.verbose: print(f"Saving {fname} shape {image.shape} mean {mean} std {std}") self.save_npy(image, fname, "_x.npy") @@ -191,7 +222,9 @@ def calculate_pad_shape(self, image): image_shape = image.shape[1:] if len(min_shape) == 2: # In 2D case we don't want to pad depth axis. min_shape.insert(0, image_shape[0]) - pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape)] + pad_shape = [ + max(mshape, ishape) for mshape, ishape in zip(min_shape, image_shape) + ] return pad_shape def get_intensities(self, pair): @@ -233,10 +266,16 @@ def calculate_new_shape(self, spacing, shape): return new_shape def save_npy(self, image, fname, suffix): - np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), image, allow_pickle=False) + np.save( + os.path.join(self.results, fname.replace(".nii.gz", suffix)), + image, + allow_pickle=False, + ) def run_parallel(self, func, exec_mode): - return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode]) + return Parallel(n_jobs=self.args.n_jobs)( + delayed(func)(pair) for pair in self.metadata[exec_mode] + ) def load_nifty(self, fname): return nibabel.load(os.path.join(self.data_path, fname)) @@ -266,7 +305,9 @@ def standardize_layout(data): @staticmethod def resize_fn(image, shape, order, mode): - return resize(image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False) + return resize( + image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False + ) def resample_anisotrophic_image(self, image, shape): resized_channels = []