1
1
import argparse
2
2
from functools import partial
3
3
4
- from .util import run_segmentation , get_model
4
+ from .util import (
5
+ run_segmentation , get_model , get_model_registry , get_model_training_resolution , load_custom_model
6
+ )
5
7
from ..imod .to_imod import export_helper , write_segmentation_to_imod_as_points , write_segmentation_to_imod
6
8
from ..inference .util import inference_helper , parse_tiling
7
9
8
10
9
11
def imod_point_cli ():
10
- parser = argparse .ArgumentParser (description = "" )
12
+ parser = argparse .ArgumentParser (
13
+ description = "Convert a vesicle segmentation to an IMOD point model, "
14
+ "corresponding to a sphere for each vesicle in the segmentation."
15
+ )
11
16
parser .add_argument (
12
17
"--input_path" , "-i" , required = True ,
13
18
help = "The filepath to the mrc file or the directory containing the tomogram data."
14
19
)
15
20
parser .add_argument (
16
21
"--segmentation_path" , "-s" , required = True ,
17
- help = "The filepath to the tif file or the directory containing the segmentations."
22
+ help = "The filepath to the file or the directory containing the segmentations."
18
23
)
19
24
parser .add_argument (
20
25
"--output_path" , "-o" , required = True ,
21
26
help = "The filepath to directory where the segmentations will be saved."
22
27
)
23
28
parser .add_argument (
24
- "--segmentation_key" , "-k" , help = ""
29
+ "--segmentation_key" , "-k" ,
30
+ help = "The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
31
+ "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
25
32
)
26
33
parser .add_argument (
27
- "--min_radius" , type = float , default = 10.0 , help = ""
34
+ "--min_radius" , type = float , default = 10.0 ,
35
+ help = "The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export." # noqa
28
36
)
29
37
parser .add_argument (
30
- "--radius_factor" , type = float , default = 1.0 , help = "" ,
38
+ "--radius_factor" , type = float , default = 1.0 ,
39
+ help = "A factor for scaling the sphere radius for the export. "
40
+ "This can be used to fit the size of segmented vesicles to the best matching spheres." ,
31
41
)
32
42
parser .add_argument (
33
- "--force" , action = "store_true" , help = "" ,
43
+ "--force" , action = "store_true" ,
44
+ help = "Whether to over-write already present export results."
34
45
)
35
46
args = parser .parse_args ()
36
47
@@ -51,24 +62,29 @@ def imod_point_cli():
51
62
52
63
53
64
def imod_object_cli ():
54
- parser = argparse .ArgumentParser (description = "" )
65
+ parser = argparse .ArgumentParser (
66
+ description = "Convert segmented objects to close contour IMOD models."
67
+ )
55
68
parser .add_argument (
56
69
"--input_path" , "-i" , required = True ,
57
70
help = "The filepath to the mrc file or the directory containing the tomogram data."
58
71
)
59
72
parser .add_argument (
60
73
"--segmentation_path" , "-s" , required = True ,
61
- help = "The filepath to the tif file or the directory containing the segmentations."
74
+ help = "The filepath to the file or the directory containing the segmentations."
62
75
)
63
76
parser .add_argument (
64
77
"--output_path" , "-o" , required = True ,
65
78
help = "The filepath to directory where the segmentations will be saved."
66
79
)
67
80
parser .add_argument (
68
- "--segmentation_key" , "-k" , help = ""
81
+ "--segmentation_key" , "-k" ,
82
+ help = "The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
83
+ "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
69
84
)
70
85
parser .add_argument (
71
- "--force" , action = "store_true" , help = "" ,
86
+ "--force" , action = "store_true" ,
87
+ help = "Whether to over-write already present export results."
72
88
)
73
89
args = parser .parse_args ()
74
90
export_helper (
@@ -82,8 +98,6 @@ def imod_object_cli():
82
98
83
99
84
100
# TODO: handle kwargs
85
- # TODO: add custom model path
86
- # TODO: enable autoscaling from input resolution
87
101
def segmentation_cli ():
88
102
parser = argparse .ArgumentParser (description = "Run segmentation." )
89
103
parser .add_argument (
@@ -94,9 +108,11 @@ def segmentation_cli():
94
108
"--output_path" , "-o" , required = True ,
95
109
help = "The filepath to directory where the segmentations will be saved."
96
110
)
97
- # TODO: list the availabel models here by parsing the keys of the model registry
111
+ model_names = list (get_model_registry ().urls .keys ())
112
+ model_names = ", " .join (model_names )
98
113
parser .add_argument (
99
- "--model" , "-m" , required = True , help = "The model type."
114
+ "--model" , "-m" , required = True ,
115
+ help = f"The model type. The following models are currently available: { model_names } "
100
116
)
101
117
parser .add_argument (
102
118
"--mask_path" , help = "The filepath to a tif file with a mask that will be used to restrict the segmentation."
@@ -119,23 +135,45 @@ def segmentation_cli():
119
135
"--data_ext" , default = ".mrc" , help = "The extension of the tomogram data. By default .mrc."
120
136
)
121
137
parser .add_argument (
122
- "--segmentation_key " , "-s " , help = ""
138
+ "--checkpoint " , "-c " , help = "Path to a custom model, e.g. from domain adaptation." ,
123
139
)
124
- # TODO enable autoscaling
125
140
parser .add_argument (
126
- "--scale" , type = float , default = None , help = ""
141
+ "--segmentation_key" , "-s" ,
142
+ help = "If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif." ,
143
+ )
144
+ parser .add_argument (
145
+ "--scale" , type = float ,
146
+ help = "The factor for rescaling the data before inference. "
147
+ "By default, the scaling factor will be derived from the voxel size of the input data. "
148
+ "If this parameter is given it will over-ride the default behavior. "
127
149
)
128
150
args = parser .parse_args ()
129
151
130
- model = get_model (args .model )
131
- tiling = parse_tiling (args .tile_shape , args .halo )
132
- scale = None if args .scale is None else 3 * (args .scale ,)
152
+ if args .checkpoint is None :
153
+ model = get_model (args .model )
154
+ else :
155
+ model = load_custom_model (args .checkpoint )
156
+ assert model is not None , f"The model from { args .checkpoint } could not be loaded."
157
+
158
+ is_2d = "2d" in args .model
159
+ tiling = parse_tiling (args .tile_shape , args .halo , is_2d = is_2d )
160
+
161
+ # If the scale argument is not passed, then we get the average training resolution for the model.
162
+ # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
163
+ if args .scale is None :
164
+ model_resolution = get_model_training_resolution (args .model )
165
+ model_resolution = tuple (model_resolution [ax ] for ax in ("yx" if is_2d else "zyx" ))
166
+ scale = None
167
+ # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
168
+ else :
169
+ model_resolution = None
170
+ scale = (2 if is_2d else 3 ) * (args .scale ,)
133
171
134
172
segmentation_function = partial (
135
- run_segmentation , model = model , model_type = args .model , verbose = False , tiling = tiling , scale = scale
173
+ run_segmentation , model = model , model_type = args .model , verbose = False , tiling = tiling ,
136
174
)
137
175
inference_helper (
138
176
args .input_path , args .output_path , segmentation_function ,
139
177
mask_input_path = args .mask_path , force = args .force , data_ext = args .data_ext ,
140
- output_key = args .segmentation_key ,
178
+ output_key = args .segmentation_key , model_resolution = model_resolution , scale = scale ,
141
179
)
0 commit comments