Skip to content

Commit 3054b2b

Browse files
committed
udpate phase prediction model
1 parent f8dac08 commit 3054b2b

File tree

7 files changed

+60
-20
lines changed

7 files changed

+60
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## Master
22
* add brain structures
33
* add liver vessels
4+
* greatly improved phase classification model
45

56

67
## Release 2.3.0

README.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,6 @@ If you want to reduce memory consumption you can use the following options:
133133
* `--nr_thr_saving 1`: Saving big images with several threads will take a lot of memory
134134

135135

136-
### Train/validation/test split
137-
The exact split of the dataset can be found in the file `meta.csv` inside of the [dataset](https://doi.org/10.5281/zenodo.6802613). This was used for the validation in our paper.
138-
The exact numbers of the results for the high-resolution model (1.5mm) can be found [here](resources/results_all_classes_v1.json). The paper shows these numbers in the supplementary materials Figure 11.
139-
140-
141-
### Retrain model and run evaluation
142-
See [here](resources/train_nnunet.md) for more info on how to train a nnU-Net yourself on the TotalSegmentator dataset, how to split the data into train/validation/test set as in our paper, and how to run the same evaluation as in our paper.
143-
144-
145136
### Python API
146137
You can run totalsegmentator via Python:
147138
```python
@@ -159,13 +150,12 @@ if __name__ == "__main__":
159150
```
160151
You can see all available arguments [here](https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/python_api.py). Running from within the main environment should avoid some multiprocessing issues.
161152

162-
The segmentation image contains the names of the classes in the extended header. If you want to load this additional header information you can use the following code:
153+
The segmentation image contains the names of the classes in the extended header. If you want to load this additional header information you can use the following code (requires `pip install xmltodict`):
163154
```python
164155
from totalsegmentator.nifti_ext_header import load_multilabel_nifti
165156

166157
segmentation_nifti_img, label_map_dict = load_multilabel_nifti(image_path)
167158
```
168-
The above code requires `pip install xmltodict`.
169159

170160

171161
### Install latest master branch (contains latest bug fixes)
@@ -175,6 +165,11 @@ pip install git+https://github.com/wasserth/TotalSegmentator.git
175165

176166

177167
### Other commands
168+
If you want to know which contrast phase a CT image is you can use the following command (requires `pip install xgboost`). More details can be found [here](resources/contrast_phase_prediction.md):
169+
```
170+
totalseg_get_phase -i ct.nii.gz -o contrast_phase.json
171+
```
172+
178173
If you want to combine some subclasses (e.g. lung lobes) into one binary mask (e.g. entire lung) you can use the following command:
179174
```
180175
totalseg_combine_masks -i totalsegmentator_output_dir -o combined_mask.nii.gz -m lungcomm
@@ -191,6 +186,15 @@ totalseg_set_license -l aca_12345678910
191186
```
192187

193188

189+
### Train/validation/test split
190+
The exact split of the dataset can be found in the file `meta.csv` inside of the [dataset](https://doi.org/10.5281/zenodo.6802613). This was used for the validation in our paper.
191+
The exact numbers of the results for the high-resolution model (1.5mm) can be found [here](resources/results_all_classes_v1.json). The paper shows these numbers in the supplementary materials Figure 11.
192+
193+
194+
### Retrain model and run evaluation
195+
See [here](resources/train_nnunet.md) for more info on how to train a nnU-Net yourself on the TotalSegmentator dataset, how to split the data into train/validation/test set as in our paper, and how to run the same evaluation as in our paper.
196+
197+
194198
### Typical problems
195199

196200
**ITK loading Error**
-1.58 MB
Binary file not shown.
Binary file not shown.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Details on how the prediction of the contrast phase is done
2+
3+
TotalSegmentator is used to predict the following structures:
4+
```python
5+
["liver", "pancreas", "urinary_bladder", "gallbladder",
6+
"heart", "aorta", "inferior_vena_cava", "portal_vein_and_splenic_vein",
7+
"iliac_vena_left", "iliac_vena_right", "iliac_artery_left", "iliac_artery_right",
8+
"pulmonary_vein", "brain", "colon", "small_bowel",
9+
"internal_carotid_artery_right", "internal_carotid_artery_left",
10+
"internal_jugular_vein_right", "internal_jugular_vein_left"]
11+
```
12+
Then the median intensity (HU value) of each structure is used as feature for a xgboost classifier
13+
to predict the post injection time (pi_time). The pi_time can be mapped to the contrast phase
14+
then. It classifies into `native`, `arterial_early`, `arterial_late`, and `portal_venous` phase.
15+
The classifier was trained on the TotalSegmentator dataset and therefore works with all sorts
16+
of different CT images.
17+
18+
Results on 5-fold cross validation:
19+
20+
- Mean absolute error (MAE): 5.55s
21+
- F1 scores for each class:
22+
- native: 0.980
23+
- arterial_early+late: 0.915
24+
- portal: 0.940
25+
26+
The results contain a probablity for each class which is high if the predicted pi_time is close to the ideal
27+
pi_time for the given phase. Moreover, the classifier is an ensemble of 5 models. The output contains the
28+
standard deviation of the predictions which can be used as a measure of confidence. If it is low the 5 models
29+
give similar predictions which is a good sign.
30+

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
python_requires='>=3.9',
1212
license='Apache 2.0',
1313
packages=find_packages(),
14-
package_data={"totalsegmentator": ["resources/totalsegmentator_snomed_mapping.csv"]},
14+
package_data={"totalsegmentator":
15+
["resources/totalsegmentator_snomed_mapping.csv",
16+
"resources/contrast_phase_classifiers_2024_07_19.pkl"]
17+
},
1518
install_requires=[
1619
'torch>=2.0.0',
1720
'numpy<2',

totalsegmentator/bin/totalseg_get_phase.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image, model_file: Path = None):
6868
# print(f"ts took: {time.time()-st:.2f}s")
6969

7070
if stats["brain"]["volume"] > 100:
71-
# print(f"Brain in image, therefore also running headneck model.")
71+
# print("Brain in image, therefore also running headneck model.")
7272
st = time.time()
7373
seg_img_hn, stats_hn = totalsegmentator(ct_img, None, ml=True, fast=False, statistics=True,
7474
task="headneck_bones_vessels",
@@ -85,11 +85,9 @@ def get_ct_contrast_phase(ct_img: nib.Nifti1Image, model_file: Path = None):
8585
features.append(stats_hn[organ]["intensity"])
8686

8787
if model_file is None:
88-
# weights from longitudinalliver dataset
89-
classifier_path = Path(__file__).parents[2] / "resources" / "contrast_phase_classifiers.pkl"
90-
else:
91-
# weights from megaseg dataset
92-
# classifier_path = "/mnt/nor/wasserthalj_data/classifiers_megaseg.pkl"
88+
classifier_path = Path(__file__).parents[2] / "resources" / "contrast_phase_classifiers_2024_07_19.pkl"
89+
else:
90+
# manually set model file
9391
classifier_path = model_file
9492
clfs = pickle.load(open(classifier_path, "rb"))
9593

@@ -136,13 +134,17 @@ def main():
136134
parser.add_argument("-m", metavar="filepath", dest="model_file",
137135
help="path to classifier model",
138136
type=lambda p: Path(p).absolute(), required=False, default=None)
137+
138+
parser.add_argument("-q", dest="quiet", action="store_true",
139+
help="Print no output to stdout", default=False)
139140

140141
args = parser.parse_args()
141142

142143
res = get_ct_contrast_phase(nib.load(args.input_file), args.model_file)
143144

144-
print("Result:")
145-
pprint(res)
145+
if not args.quiet:
146+
print("Result:")
147+
pprint(res)
146148

147149
with open(args.output_file, "w") as f:
148150
f.write(json.dumps(res, indent=4))

0 commit comments

Comments
 (0)