Commit 9b3e7b5

Add sample data, CI, and auto-scaling for segmentation CLI
1 parent 414b4cb commit 9b3e7b5

File tree

7 files changed: +272, -39 lines changed

.github/workflows/run_tests.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: test
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - "v*"  # Push events to matching v*, i.e. v1.0, v20.15.10
+  pull_request:  # run CI on commits to any open PR
+  workflow_dispatch:  # can manually trigger CI from GitHub actions tab
+
+
+jobs:
+  test:
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 60
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.11"]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Setup micromamba
+        uses: mamba-org/setup-micromamba@v1
+        with:
+          environment-file: environment_cpu.yaml
+          create-args: >-
+            python=${{ matrix.python-version }}
+
+      - name: Install SynapseNet
+        shell: bash -l {0}
+        run: pip install --no-deps -e .
+
+      - name: Run tests
+        shell: bash -l {0}
+        run: python -m unittest discover -s test -v

synaptic_reconstruction/inference/util.py

Lines changed: 62 additions & 8 deletions
@@ -11,6 +11,7 @@
 
 import imageio.v3 as imageio
 import elf.parallel as parallel
+import mrcfile
 import numpy as np
 import torch
 import torch_em
@@ -131,7 +132,7 @@ def get_prediction(
     # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
     if model_path.endswith("best.pt"):
         model_path = os.path.split(model_path)[0]
-    print(f"tiling {tiling}")
+    # print(f"tiling {tiling}")
     # Create updated_tiling with the same structure
     updated_tiling = {
         "tile": {},
@@ -140,7 +141,7 @@ def get_prediction(
     # Update tile dimensions
     for dim in tiling["tile"]:
         updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
-    print(f"updated_tiling {updated_tiling}")
+    # print(f"updated_tiling {updated_tiling}")
     pred = get_prediction_torch_em(
         input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
     )
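
Editor's note: a worked example of the tile/halo arithmetic above. The numbers are the 2d defaults added later in this commit, not values from this hunk:

    # get_prediction shrinks each tile by twice the halo, so that a shrunken
    # tile plus its halo covers exactly the originally requested tile shape.
    tiling = {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}
    updated_tile = {dim: tiling["tile"][dim] - 2 * tiling["halo"][dim] for dim in tiling["tile"]}
    print(updated_tile)  # {'x': 512, 'y': 512, 'z': 1}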
@@ -252,6 +253,33 @@ def _load_input(img_path, extra_files, i):
     return input_volume
 
 
+def _derive_scale(img_path, model_resolution):
+    try:
+        with mrcfile.open(img_path, "r") as f:
+            voxel_size = f.voxel_size
+            if len(model_resolution) == 2:
+                voxel_size = [voxel_size.y, voxel_size.x]
+            else:
+                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]
+
+        assert len(voxel_size) == len(model_resolution)
+        # The voxel size is given in Angstrom and we need to translate it to nanometer.
+        voxel_size = [vsize / 10 for vsize in voxel_size]
+
+        # Compute the correct scale factor.
+        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
+        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
+
+    except Exception:
+        warnings.warn(
+            f"The voxel size could not be read from the data for {img_path}. "
+            "This data will not be scaled for prediction."
+        )
+        scale = None
+
+    return scale
+
+
 def inference_helper(
     input_path: str,
     output_root: str,
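
Editor's note: to make the unit handling in _derive_scale concrete, here is the same arithmetic with hypothetical numbers (a 13.24 Å voxel size and a 0.868 nm training resolution; neither value appears in this commit):

    voxel_size_angstrom = [13.24, 13.24]   # (y, x) as read from a 2d mrc header
    model_resolution = [0.868, 0.868]      # training voxel size in nm
    voxel_size_nm = [vsize / 10 for vsize in voxel_size_angstrom]  # Angstrom -> nm: [1.324, 1.324]
    scale = tuple(vsize / res for vsize, res in zip(voxel_size_nm, model_resolution))
    print(scale)  # (1.5253..., 1.5253...) -> the input is upsampled ~1.5x before prediction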
@@ -263,6 +291,8 @@ def inference_helper(
     mask_input_ext: str = ".tif",
     force: bool = False,
     output_key: Optional[str] = None,
+    model_resolution: Optional[Tuple[float, float, float]] = None,
+    scale: Optional[Tuple[float, float, float]] = None,
 ) -> None:
     """Helper function to run segmentation for mrc files.
 
@@ -282,7 +312,13 @@ def inference_helper(
         mask_input_ext: File extension for the mask inputs (by default .tif).
         force: Whether to rerun segmentation for output files that are already present.
         output_key: Output key for the prediction. If none will write an hdf5 file.
+        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
+            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
+        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
     """
+    if (scale is not None) and (model_resolution is not None):
+        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
+
     # Get the input files. If input_path is a folder then this will load all
     # the mrc files beneath it. Otherwise we assume this is an mrc file already
     # and just return the path to this mrc file.
@@ -333,8 +369,18 @@ def inference_helper(
         # Load the mask (if given).
         mask = None if mask_files is None else imageio.imread(mask_files[i])
 
+        # Determine the scale factor:
+        # If neither the 'scale' nor the 'model_resolution' argument was passed, set it to None.
+        if scale is None and model_resolution is None:
+            this_scale = None
+        elif scale is not None:  # If 'scale' was passed, use it.
+            this_scale = scale
+        else:  # Otherwise 'model_resolution' was passed; use it to derive the scaling from the data.
+            assert model_resolution is not None
+            this_scale = _derive_scale(img_path, model_resolution)
+
         # Run the segmentation.
-        segmentation = segmentation_function(input_volume, mask=mask)
+        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
 
         # Write the result to tif or h5.
         os.makedirs(os.path.split(output_path)[0], exist_ok=True)
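
Editor's note: the two mutually exclusive ways of requesting scaling from inference_helper, sketched with placeholder paths and resolution values, and a segmentation_function assumed to be defined elsewhere:

    # Derive the factor per file from each mrc header's voxel size:
    inference_helper("tomograms/", "output/", segmentation_function, model_resolution=(0.868, 0.868, 0.868))

    # Or force one fixed factor for all inputs:
    inference_helper("tomograms/", "output/", segmentation_function, scale=(1.5, 1.5, 1.5))

    # Passing both raises the ValueError added at the top of the function.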
@@ -348,15 +394,21 @@ def inference_helper(
     print(f"Saved segmentation to {output_path}.")
 
 
-def get_default_tiling() -> Dict[str, Dict[str, int]]:
+def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
     """Determine the tile shape and halo depending on the available VRAM.
 
+    Args:
+        is_2d: Whether to return tiling settings for 2d inference.
+
     Returns:
         The default tiling settings for the available computational resources.
     """
-    if torch.cuda.is_available():
-        print("Determining suitable tiling")
+    if is_2d:
+        tile = {"x": 768, "y": 768, "z": 1}
+        halo = {"x": 128, "y": 128, "z": 0}
+        return {"tile": tile, "halo": halo}
 
+    if torch.cuda.is_available():
         # We always use the same default halo.
         halo = {"x": 64, "y": 64, "z": 16}  # before 64,64,8
@@ -390,19 +442,21 @@ def get_default_tiling() -> Dict[str, Dict[str, int]]:
 
 def parse_tiling(
     tile_shape: Tuple[int, int, int],
-    halo: Tuple[int, int, int]
+    halo: Tuple[int, int, int],
+    is_2d: bool = False,
 ) -> Dict[str, Dict[str, int]]:
     """Helper function to parse tiling parameter input from the command line.
 
     Args:
         tile_shape: The tile shape. If None the default tile shape is used.
         halo: The halo. If None the default halo is used.
+        is_2d: Whether to return tiling for a 2d model.
 
     Returns:
         The tiling specification.
     """
 
-    default_tiling = get_default_tiling()
+    default_tiling = get_default_tiling(is_2d=is_2d)
 
     if tile_shape is None:
         tile_shape = default_tiling["tile"]
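
Editor's note: a minimal sketch of how the new is_2d flag falls through to the defaults when no tiling is passed on the command line:

    tiling = parse_tiling(tile_shape=None, halo=None, is_2d=True)
    # -> {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}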
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import os
+import pooch
+
+
+def get_sample_data(name: str) -> str:
+    """Get the filepath to SynapseNet sample data, stored as mrc file.
+
+    Args:
+        name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
+
+    Returns:
+        The filepath to the downloaded sample data.
+    """
+    registry = {
+        "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
+    }
+    urls = {
+        "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
+    }
+    key = f"{name}.mrc"
+
+    if key not in registry:
+        valid_names = [k[:-4] for k in registry.keys()]
+        raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.")
+
+    cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
+    data_registry = pooch.create(
+        path=os.path.join(cache_dir, "sample_data"),
+        base_url="",
+        registry=registry,
+        urls=urls,
+    )
+    file_path = data_registry.fetch(key)
+    return file_path
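
Editor's note: usage sketch for the new sample-data helper. The module path is not visible in this diff, so the import below is an assumption:

    from synaptic_reconstruction.sample_data import get_sample_data  # hypothetical import path

    path = get_sample_data("tem_2d")  # downloads once via pooch, afterwards served from the cache
    print(path)  # .../synapse-net/sample_data/tem_2d.mrc
    # Any other name raises a ValueError that lists the valid sample names.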

synaptic_reconstruction/tools/cli.py

Lines changed: 61 additions & 23 deletions
@@ -1,36 +1,47 @@
 import argparse
 from functools import partial
 
-from .util import run_segmentation, get_model
+from .util import (
+    run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model
+)
 from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
 from ..inference.util import inference_helper, parse_tiling
 
 
 def imod_point_cli():
-    parser = argparse.ArgumentParser(description="")
+    parser = argparse.ArgumentParser(
+        description="Convert a vesicle segmentation to an IMOD point model, "
+        "corresponding to a sphere for each vesicle in the segmentation."
+    )
     parser.add_argument(
         "--input_path", "-i", required=True,
         help="The filepath to the mrc file or the directory containing the tomogram data."
     )
     parser.add_argument(
         "--segmentation_path", "-s", required=True,
-        help="The filepath to the tif file or the directory containing the segmentations."
+        help="The filepath to the file or the directory containing the segmentations."
     )
     parser.add_argument(
         "--output_path", "-o", required=True,
         help="The filepath to directory where the segmentations will be saved."
     )
     parser.add_argument(
-        "--segmentation_key", "-k", help=""
+        "--segmentation_key", "-k",
+        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
+        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
     parser.add_argument(
-        "--min_radius", type=float, default=10.0, help=""
+        "--min_radius", type=float, default=10.0,
+        help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be excluded from the export."  # noqa
    )
     parser.add_argument(
-        "--radius_factor", type=float, default=1.0, help="",
+        "--radius_factor", type=float, default=1.0,
+        help="A factor for scaling the sphere radius for the export. "
+        "This can be used to fit the size of segmented vesicles to the best matching spheres.",
    )
     parser.add_argument(
-        "--force", action="store_true", help="",
+        "--force", action="store_true",
+        help="Whether to overwrite already present export results."
    )
     args = parser.parse_args()
 
@@ -51,24 +62,29 @@ def imod_point_cli():
 
 
 def imod_object_cli():
-    parser = argparse.ArgumentParser(description="")
+    parser = argparse.ArgumentParser(
+        description="Convert segmented objects to closed contour IMOD models."
+    )
     parser.add_argument(
         "--input_path", "-i", required=True,
         help="The filepath to the mrc file or the directory containing the tomogram data."
     )
     parser.add_argument(
         "--segmentation_path", "-s", required=True,
-        help="The filepath to the tif file or the directory containing the segmentations."
+        help="The filepath to the file or the directory containing the segmentations."
     )
     parser.add_argument(
         "--output_path", "-o", required=True,
         help="The filepath to directory where the segmentations will be saved."
     )
     parser.add_argument(
-        "--segmentation_key", "-k", help=""
+        "--segmentation_key", "-k",
+        help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif. "
+        "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
    )
     parser.add_argument(
-        "--force", action="store_true", help="",
+        "--force", action="store_true",
+        help="Whether to overwrite already present export results."
    )
     args = parser.parse_args()
     export_helper(
@@ -82,8 +98,6 @@ def imod_object_cli():
 
 
 # TODO: handle kwargs
-# TODO: add custom model path
-# TODO: enable autoscaling from input resolution
 def segmentation_cli():
     parser = argparse.ArgumentParser(description="Run segmentation.")
     parser.add_argument(
@@ -94,9 +108,11 @@ def segmentation_cli():
         "--output_path", "-o", required=True,
         help="The filepath to directory where the segmentations will be saved."
     )
-    # TODO: list the availabel models here by parsing the keys of the model registry
+    model_names = list(get_model_registry().urls.keys())
+    model_names = ", ".join(model_names)
     parser.add_argument(
-        "--model", "-m", required=True, help="The model type."
+        "--model", "-m", required=True,
+        help=f"The model type. The following models are currently available: {model_names}"
    )
     parser.add_argument(
         "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
@@ -119,23 +135,45 @@ def segmentation_cli():
         "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
     )
     parser.add_argument(
-        "--segmentation_key", "-s", help=""
+        "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
    )
-    # TODO enable autoscaling
     parser.add_argument(
-        "--scale", type=float, default=None, help=""
+        "--segmentation_key", "-s",
+        help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
+    )
+    parser.add_argument(
+        "--scale", type=float,
+        help="The factor for rescaling the data before inference. "
+        "By default, the scaling factor will be derived from the voxel size of the input data. "
+        "If this parameter is given it will override the default behavior."
    )
     args = parser.parse_args()
 
-    model = get_model(args.model)
-    tiling = parse_tiling(args.tile_shape, args.halo)
-    scale = None if args.scale is None else 3 * (args.scale,)
+    if args.checkpoint is None:
+        model = get_model(args.model)
+    else:
+        model = load_custom_model(args.checkpoint)
+        assert model is not None, f"The model from {args.checkpoint} could not be loaded."
+
+    is_2d = "2d" in args.model
+    tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)
+
+    # If the scale argument is not passed, then we get the average training resolution for the model.
+    # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
+    if args.scale is None:
+        model_resolution = get_model_training_resolution(args.model)
+        model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
+        scale = None
+    # Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
+    else:
+        model_resolution = None
+        scale = (2 if is_2d else 3) * (args.scale,)
 
     segmentation_function = partial(
-        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, scale=scale
+        run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
    )
     inference_helper(
         args.input_path, args.output_path, segmentation_function,
         mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
-        output_key=args.segmentation_key,
+        output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
    )
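
Editor's note: the new options end to end. The console-script name is not part of this diff, so this drives the Python entry point directly; 'vesicles_2d' and all paths are placeholders:

    import sys
    from synaptic_reconstruction.tools.cli import segmentation_cli

    # Auto-scaling: without --scale, the factor is derived from the mrc voxel size
    # via get_model_training_resolution and _derive_scale.
    sys.argv = ["segmentation_cli", "-i", "tomo.mrc", "-o", "out/", "-m", "vesicles_2d"]
    segmentation_cli()

    # Custom checkpoint plus a fixed scale factor (expanded to 2d or 3d internally):
    sys.argv = ["segmentation_cli", "-i", "tomo.mrc", "-o", "out/", "-m", "vesicles_2d",
                "-c", "checkpoints/best.pt", "--scale", "1.5"]
    segmentation_cli()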
