Skip to content

Commit 8599c8e

Browse files
committed
bugfixes for roi_subset mr
1 parent bfd2b85 commit 8599c8e

File tree

7 files changed

+42
-12
lines changed

7 files changed

+42
-12
lines changed
Binary file not shown.

tests/test_end_to_end.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def test_prediction_liver_roi_subset(self):
5858
dice = dice_score_multilabel(img_ref, img_new)
5959
images_equal = dice > 0.99
6060
self.assertTrue(images_equal, f"roi subset prediction not correct (dice: {dice:.6f})")
61+
62+
def test_prediction_liver_roi_subset_mr(self):
63+
img_ref = nib.load("tests/reference_files/example_seg_roi_subset_mr.nii.gz").get_fdata()
64+
img_new = nib.load("tests/unittest_prediction_roi_subset_mr.nii.gz").get_fdata()
65+
dice = dice_score_multilabel(img_ref, img_new)
66+
images_equal = dice > 0.99
67+
self.assertTrue(images_equal, f"roi subset MR prediction not correct (dice: {dice:.6f})")
6168

6269
def test_preview(self):
6370
preview_exists = os.path.exists("tests/unittest_prediction_fast/preview_total.png")

tests/tests.sh

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,22 @@ set -e
44
# ./tests/tests.sh
55

66

7-
# Test multilabel prediction
7+
# Test - multilabel prediction
88
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu
99
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel
1010

11-
# Test organ prediction - roi subset
11+
# Test - roi subset
1212
# 2 cpus:
1313
# example_ct_sm.nii.gz: 34s, 3.0GB
1414
# example_ct.nii.gz: 36s, 3.0GB
1515
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction_roi_subset.nii.gz --ml -rs liver brain -d cpu
1616
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_liver_roi_subset
1717

18-
# Test organ predictions - fast - statistics
18+
# Test - roi subset - MR
19+
TotalSegmentator -i tests/reference_files/example_mr_sm.nii.gz -o tests/unittest_prediction_roi_subset_mr.nii.gz -ta total_mr --ml -rs liver brain -d cpu
20+
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_liver_roi_subset_mr
21+
22+
# Test - fast - statistics
1923
# 2 cpus: (statistics <1s)
2024
# example_ct_sm.nii.gz: 13s, 4.1GB
2125
# example_ct.nii.gz: 16s, 4.1GB
@@ -24,15 +28,15 @@ pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_fast
2428
pytest -v tests/test_end_to_end.py::test_end_to_end::test_statistics
2529
pytest -v tests/test_end_to_end.py::test_end_to_end::test_preview
2630

27-
# Test organ predictions - fast - multilabel
31+
# Test - fast - multilabel
2832
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction_fast.nii.gz --fast --ml -d cpu
2933
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast
3034

31-
# Test organ predictions - fast - multilabel - force split
35+
# Test - fast - multilabel - force split
3236
TotalSegmentator -i tests/reference_files/example_ct.nii.gz -o tests/unittest_prediction_fast_force_split.nii.gz --fast --ml -fs -d cpu
3337
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast_force_split
3438

35-
# Test organ predictions - fast - multilabel - body_seg
39+
# Test - fast - multilabel - body_seg
3640
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction_fast_body_seg.nii.gz --fast --ml -bs -d cpu
3741
pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel_fast_body_seg
3842

tests/update_test_files.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ set -e
88
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/reference_files/example_seg.nii.gz -bs --ml -d cpu
99
TotalSegmentator -i tests/reference_files/example_mr_sm.nii.gz -o tests/reference_files/example_seg_mr.nii.gz -ta total_mr --ml -d cpu
1010
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/reference_files/example_seg_roi_subset.nii.gz --ml -rs liver brain -d cpu
11+
TotalSegmentator -i tests/reference_files/example_mr_sm.nii.gz -o tests/reference_files/example_seg_roi_subset_mr.nii.gz -ta total_mr --ml -rs liver brain -d cpu
1112
# TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/reference_files/example_seg_fast --fast --statistics -sii -p -d cpu
1213
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/reference_files/example_seg_fast.nii.gz --fast --ml -d cpu
1314
TotalSegmentator -i tests/reference_files/example_ct.nii.gz -o tests/reference_files/example_seg_fast_force_split.nii.gz --fast --ml -fs -d cpu

totalsegmentator/map_to_binary.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,14 @@
363363
"face": {
364364
1: "face"
365365
},
366+
"face_mr": {
367+
1: "face"
368+
},
369+
# those classes need to be removed
370+
"face_mr_auxiliary": {
371+
2: "brain",
372+
3: "liver"
373+
},
366374
"test": {
367375
1: "carpal"
368376
}

totalsegmentator/preview.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@
130130
"face": [
131131
["face"]
132132
],
133+
"face_mr": [
134+
["face"]
135+
],
133136
# "aortic_branches_test": [
134137
# ["brachiocephalic_trunk", "subclavian_artery_right", "subclavian_artery_left", "aorta",
135138
# "common_carotid_artery_right", "common_carotid_artery_left"],

totalsegmentator/python_api.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
305305
else:
306306
download_pretrained_weights(task_id)
307307

308+
# For MR always run 3mm model for roi_subset, because 6mm too bad results
309+
# (runtime for 3mm still very good for MR)
310+
if task.endswith("_mr") and roi_subset is not None:
311+
roi_subset_robust = roi_subset
312+
308313
if roi_subset_robust is not None:
309314
roi_subset = roi_subset_robust
310315
robust_rs = True
@@ -313,10 +318,10 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
313318

314319
if roi_subset is not None and type(roi_subset) is not list:
315320
raise ValueError("roi_subset must be a list of strings")
316-
if roi_subset is not None and task != "total":
317-
raise ValueError("roi_subset only works with task 'total'")
321+
if roi_subset is not None and not task.startswith("total"):
322+
raise ValueError("roi_subset only works with task 'total' or 'total_mr'")
318323

319-
if task == "total_mr" or task == "tissue_types_mr":
324+
if task.endswith("_mr"):
320325
if body_seg:
321326
body_seg = False
322327
print("INFO: For MR models the argument '--body_seg' is not supported and will be ignored.")
@@ -335,13 +340,15 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa
335340
else:
336341
crop_model_task = 733 if task == "total_mr" else 298
337342
crop_spacing = 6.0
343+
crop_task = "total_mr" if task == "total_mr" else "total"
344+
crop_trainer = "nnUNetTrainer_DASegOrd0_NoMirroring" if task == "total_mr" else "nnUNetTrainer_4000epochs_NoMirroring"
338345
organ_seg, _, _ = nnUNet_predict_image(input, None, crop_model_task, model="3d_fullres", folds=[0],
339-
trainer="nnUNetTrainer_4000epochs_NoMirroring", tta=False, multilabel_image=True, resample=crop_spacing,
340-
crop=None, crop_path=None, task="total", nora_tag="None", preview=False,
346+
trainer=crop_trainer, tta=False, multilabel_image=True, resample=crop_spacing,
347+
crop=None, crop_path=None, task_name=crop_task, nora_tag="None", preview=False,
341348
save_binary=False, nr_threads_resampling=nr_thr_resamp, nr_threads_saving=1,
342349
crop_addon=crop_addon, output_type=output_type, statistics=False,
343350
quiet=quiet, verbose=verbose, test=0, skip_saving=False, device=device)
344-
class_map_inv = {v: k for k, v in class_map["total"].items()}
351+
class_map_inv = {v: k for k, v in class_map[crop_task].items()}
345352
crop_mask = np.zeros(organ_seg.shape, dtype=np.uint8)
346353
organ_seg_data = organ_seg.get_fdata()
347354
# roi_subset_crop = [map_to_total[roi] if roi in map_to_total else roi for roi in roi_subset]

0 commit comments

Comments
 (0)