Skip to content

Commit 851e5ba

Browse files
committed
nnunet post-processing times
1 parent 8d6d221 commit 851e5ba

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

evaluate_baselines.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def evaluate_nnunet(result_dir='/share/data_rechenknecht03_2/students/kaftan/Fis
191191
test_hd = torch.zeros_like(test_assd)
192192
test_hd95 = torch.zeros_like(test_assd)
193193
test_missing_percent = torch.zeros_like(test_assd)
194+
all_times = []
194195
for fold in range(n_folds):
195196
fold_dir = os.path.join(result_dir, f'fold_{fold}')
196197
files = sorted(glob(os.path.join(fold_dir, 'validation_raw_postprocessed', '*.nii.gz')))
@@ -219,7 +220,8 @@ def evaluate_nnunet(result_dir='/share/data_rechenknecht03_2/students/kaftan/Fis
219220
labelmap_predict, _ = lobes_to_fissures(labelmap_predict, mask=ds.get_lung_mask(img_index))
220221

221222
if mode == 'surface':
222-
_, predicted_meshes = poisson_reconstruction(labelmap_predict, ds.get_lung_mask(img_index))
223+
_, predicted_meshes, times = poisson_reconstruction(labelmap_predict, ds.get_lung_mask(img_index), return_times=True)
224+
all_times.append(torch.tensor(times))
223225
# TODO: compare Poisson to Marching Cubes mesh generation
224226
for i, m in enumerate(predicted_meshes):
225227
# save reconstructed mesh
@@ -241,6 +243,8 @@ def evaluate_nnunet(result_dir='/share/data_rechenknecht03_2/students/kaftan/Fis
241243
# img.CopyInformation(fissures_predict)
242244
# sitk.WriteImage(img, f'./results/nnunet_pred_skeletonized_{case}_{sequence}.nii.gz')
243245

246+
print(f'Current mean time per fissure: {torch.stack(all_times).sum(1).mean():4f}s +- {torch.stack(all_times).sum(1).std():4f}s')
247+
244248
# compute surface distances
245249
mean_assd, std_assd, mean_sdsd, std_sdsd, mean_hd, std_hd, mean_hd95, std_hd95, percent_missing = compute_mesh_metrics(
246250
all_predictions, all_targ_meshes, ids=ids, show=show, spacings=spacings, plot_folder=plot_dir)
@@ -281,12 +285,13 @@ def evaluate_nnunet(result_dir='/share/data_rechenknecht03_2/students/kaftan/Fis
281285
n_fissures = 3
282286

283287
data_dir = '../TotalSegmentator/ThoraxCrop'
284-
evaluate_voxel2mesh(data_dir, show=False)
285288

286-
# nnu_task = "Task503_FissuresTotalSeg"
287-
# nnu_trainer = "nnUNetTrainerV2_200ep"
288-
# nnu_result_dir = f'/share/data_rechenknecht03_2/students/kaftan/FissureSegmentation/nnUNet_baseline/nnu_results/nnUNet/3d_fullres/{nnu_task}/{nnu_trainer}__nnUNetPlansv2.1'
289-
# evaluate_nnunet(nnu_result_dir, my_data_dir=data_dir, mode='surface', show=False)
289+
# evaluate_voxel2mesh(data_dir, show=False)
290+
291+
nnu_task = "Task503_FissuresTotalSeg"
292+
nnu_trainer = "nnUNetTrainerV2_200ep"
293+
nnu_result_dir = f'/share/data_rechenknecht03_2/students/kaftan/FissureSegmentation/nnUNet_baseline/nnu_results/nnUNet/3d_fullres/{nnu_task}/{nnu_trainer}__nnUNetPlansv2.1'
294+
evaluate_nnunet(nnu_result_dir, my_data_dir=data_dir, mode='surface', show=False)
290295
# evaluate_nnunet(nnu_result_dir, my_data_dir=data_dir, mode='voxels', show=False)
291296

292297
# lobes_nnunet = '/share/data_rechenknecht03_2/students/kaftan/FissureSegmentation/nnUNet_baseline/nnu_results/nnUNet/3d_fullres/Task502_LobesCOPDEMPIRE/nnUNetTrainerV2_200ep__nnUNetPlansv2.1'

0 commit comments

Comments
 (0)