Commit afeb8e7

Merge pull request #74 from computational-cell-analytics/software-updates
Software updates
2 parents: 8e8c25d + 333b24a

File tree

12 files changed: +344 -59 lines changed

.github/workflows/run_tests.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: test
+
+on:
+  push:
+    branches:
+      - main
+    tags:
+      - "v*"  # Push events to matching v*, i.e. v1.0, v20.15.10
+  pull_request:  # run CI on commits to any open PR
+  workflow_dispatch:  # can manually trigger CI from GitHub actions tab
+
+
+jobs:
+  test:
+    name: ${{ matrix.os }} ${{ matrix.python-version }}
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 60
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.11"]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Setup micromamba
+        uses: mamba-org/setup-micromamba@v1
+        with:
+          environment-file: environment_cpu.yaml
+          create-args: >-
+            python=${{ matrix.python-version }}
+
+      - name: Install SynapseNet
+        shell: bash -l {0}
+        run: pip install --no-deps -e .
+
+      - name: Run tests
+        shell: bash -l {0}
+        run: python -m unittest discover -s test -v
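
For reference, the CI's test step can also be reproduced locally. A minimal sketch, assuming the repository root as the working directory and an activated environment created from `environment_cpu.yaml`:

```python
# Local equivalent of the workflow's "python -m unittest discover -s test -v".
import unittest

# Discover all tests below the "test" folder, exactly as the CI step does.
suite = unittest.TestLoader().discover(start_dir="test")
unittest.TextTestRunner(verbosity=2).run(suite)
```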

doc/start_page.md

Lines changed: 41 additions & 11 deletions
@@ -14,24 +14,53 @@ especially through the [domain adaptation](domain-adaptation) functionality.
 SynapseNet offers a [napari plugin](napari-plugin), [command line interface](command-line-interface), and [python library](python-library).
 Please cite our [bioRxiv preprint](TODO) if you use it in your research.
 
-**The rest of the documentation will be updated in the coming days!**
 
 ## Requirements & Installation
 
-- Requirements: Tested on Linux but should work on Mac/Windows.
-  - GPU needed to use 3d segmentation networks
-- Installation via conda and local pip install
-  - GPU support
+SynapseNet was developed and tested on Linux. It should be possible to install and use it on Mac or Windows, but we have not tested this.
+Furthermore, SynapseNet requires a GPU for segmentation of 3D volumes.
+
+You need a [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) or [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html) installation. Follow the instructions at the respective links if you have installed neither. We assume you have `conda` for the rest of the instructions; after installing it, you can use the `conda` command.
+
+To install SynapseNet, follow these steps:
+- First, download the SynapseNet repository via
+```bash
+git clone https://github.com/computational-cell-analytics/synapse-net
+```
+- Then, enter the `synapse-net` folder:
+```bash
+cd synapse-net
+```
+- Now you can install the environment for SynapseNet with `conda` from the environment file we provide:
+```bash
+conda env create -f environment.yaml
+```
+- You will need to confirm this step. It will take a while. Afterwards you can activate the environment:
+```bash
+conda activate synapse-net
+```
+- Finally, install SynapseNet itself into the environment:
+```bash
+pip install -e .
+```
+
+Now you can use all SynapseNet features. From now on, just activate the environment via
+```
+conda activate synapse-net
+```
+to use them.
+
+> Note: If you use `mamba` instead of conda, just replace `conda` in the commands above with `mamba`.
+
+> Note: We also provide an environment for a CPU version of SynapseNet. You can install it by replacing `environment.yaml` with `environment_cpu.yaml` in the respective command above. This version can be used for 2D vesicle segmentation, but it does not work for 3D segmentation.
+
+> Note: If you have issues with the CUDA version then install a PyTorch that matches your nvidia drivers. See [pytorch.org](https://pytorch.org/) for details.
 
-- Make sure conda or mamba is installed.
-  - If you don't have a conda installation yet we recommend [micromamba](https://mamba.readthedocs.io/en/latest/installation/micromamba-installation.html)
-- Create the environment with all required dependencies: `mamba env create -f environment.yaml`
-- Activate the environment: `mamba activate synaptic-reconstruction`
-- Install the package: `pip install -e .`
 
 ## Napari Plugin
 
-lorem ipsum
+**The rest of the documentation will be updated in the coming days!**
+
 
 ## Command Line Functionality
 
@@ -40,6 +69,7 @@ lorem ipsum
 - vesicles / spheres
 - objects
 
+
 ## Python Library
 
 - segmentation functions
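
A quick way to confirm that the `pip install -e .` step from the instructions above worked is to import the package from the activated environment. A minimal sketch; the import name `synaptic_reconstruction` is taken from the package folder elsewhere in this commit:

```python
# If the editable install succeeded, the package resolves to the cloned repository.
import synaptic_reconstruction
print(synaptic_reconstruction.__file__)
```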

environment.yaml

Lines changed: 10 additions & 5 deletions
@@ -1,17 +1,22 @@
 channels:
+  - pytorch
+  - nvidia
   - conda-forge
 name:
-  synaptic-reconstruction
+  synapse-net
 dependencies:
-  - python-elf
+  - bioimageio.core
+  - kornia
+  - magicgui
   - napari
   - pip
   - pyqt
-  - magicgui
+  - python-elf
   - pytorch
-  - bioimageio.core
-  - kornia
+  - pytorch-cuda=12.4
   - tensorboard
+  - torch_em
+  - torchvision
   - trimesh
   - pip:
     - napari-skimage-regionprops
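
Since the environment now pins `pytorch-cuda=12.4`, it may be worth checking that the GPU build of PyTorch resolved correctly after creating the environment. A minimal sketch:

```python
# Sanity check for the GPU environment: the CUDA build of PyTorch should
# report an available device on a machine with suitable nvidia drivers.
import torch

print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```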

environment_cpu.yaml

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+channels:
+  - conda-forge
+name:
+  synapse-net
+dependencies:
+  - bioimageio.core
+  - kornia
+  - magicgui
+  - napari
+  - pip
+  - pyqt
+  - python-elf
+  - pytorch
+  - tensorboard
+  - torch_em
+  - trimesh
+  - pip:
+    - napari-skimage-regionprops

plot_distances.sh

Lines changed: 0 additions & 1 deletion
This file was deleted.

synaptic_reconstruction/inference/util.py

Lines changed: 62 additions & 8 deletions
@@ -11,6 +11,7 @@
 
 import imageio.v3 as imageio
 import elf.parallel as parallel
+import mrcfile
 import numpy as np
 import torch
 import torch_em
@@ -131,7 +132,7 @@ def get_prediction(
     # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
     if model_path.endswith("best.pt"):
         model_path = os.path.split(model_path)[0]
-    print(f"tiling {tiling}")
+    # print(f"tiling {tiling}")
     # Create updated_tiling with the same structure
     updated_tiling = {
         "tile": {},
@@ -140,7 +141,7 @@
     # Update tile dimensions
     for dim in tiling["tile"]:
         updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
-    print(f"updated_tiling {updated_tiling}")
+    # print(f"updated_tiling {updated_tiling}")
     pred = get_prediction_torch_em(
         input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
     )
@@ -252,6 +253,33 @@ def _load_input(img_path, extra_files, i):
     return input_volume
 
 
+def _derive_scale(img_path, model_resolution):
+    try:
+        with mrcfile.open(img_path, "r") as f:
+            voxel_size = f.voxel_size
+            if len(model_resolution) == 2:
+                voxel_size = [voxel_size.y, voxel_size.x]
+            else:
+                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]
+
+        assert len(voxel_size) == len(model_resolution)
+        # The voxel size is given in Angstrom and we need to translate it to nanometer.
+        voxel_size = [vsize / 10 for vsize in voxel_size]
+
+        # Compute the correct scale factor.
+        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
+        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
+
+    except Exception:
+        warnings.warn(
+            f"The voxel size could not be read from the data for {img_path}. "
+            "This data will not be scaled for prediction."
+        )
+        scale = None
+
+    return scale
+
+
 def inference_helper(
     input_path: str,
     output_root: str,
@@ -263,6 +291,8 @@
     mask_input_ext: str = ".tif",
     force: bool = False,
     output_key: Optional[str] = None,
+    model_resolution: Optional[Tuple[float, float, float]] = None,
+    scale: Optional[Tuple[float, float, float]] = None,
 ) -> None:
     """Helper function to run segmentation for mrc files.
 
@@ -282,7 +312,13 @@ def inference_helper(
         mask_input_ext: File extension for the mask inputs (by default .tif).
         force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If none will write an hdf5 file.
+        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
+            If given, the scaling factor will automatically be determined based on the voxel size of the input data.
+        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
     """
+    if (scale is not None) and (model_resolution is not None):
+        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
+
     # Get the input files. If input_path is a folder then this will load all
     # the mrc files beneath it. Otherwise we assume this is an mrc file already
     # and just return the path to this mrc file.
@@ -333,8 +369,18 @@
         # Load the mask (if given).
         mask = None if mask_files is None else imageio.imread(mask_files[i])
 
+        # Determine the scale factor:
+        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
+        if scale is None and model_resolution is None:
+            this_scale = None
+        elif scale is not None:  # If 'scale' was passed then use it.
+            this_scale = scale
+        else:  # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data.
+            assert model_resolution is not None
+            this_scale = _derive_scale(img_path, model_resolution)
+
         # Run the segmentation.
-        segmentation = segmentation_function(input_volume, mask=mask)
+        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
 
         # Write the result to tif or h5.
         os.makedirs(os.path.split(output_path)[0], exist_ok=True)
@@ -348,15 +394,21 @@
         print(f"Saved segmentation to {output_path}.")
 
 
-def get_default_tiling() -> Dict[str, Dict[str, int]]:
+def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
     """Determine the tile shape and halo depending on the available VRAM.
 
+    Args:
+        is_2d: Whether to return tiling settings for 2d inference.
+
     Returns:
         The default tiling settings for the available computational resources.
     """
-    if torch.cuda.is_available():
-        print("Determining suitable tiling")
+    if is_2d:
+        tile = {"x": 768, "y": 768, "z": 1}
+        halo = {"x": 128, "y": 128, "z": 0}
+        return {"tile": tile, "halo": halo}
 
+    if torch.cuda.is_available():
         # We always use the same default halo.
         halo = {"x": 64, "y": 64, "z": 16}  # before 64,64,8
 
@@ -390,19 +442,21 @@ def get_default_tiling() -> Dict[str, Dict[str, int]]:
 
 def parse_tiling(
     tile_shape: Tuple[int, int, int],
-    halo: Tuple[int, int, int]
+    halo: Tuple[int, int, int],
+    is_2d: bool = False,
 ) -> Dict[str, Dict[str, int]]:
     """Helper function to parse tiling parameter input from the command line.
 
     Args:
         tile_shape: The tile shape. If None the default tile shape is used.
         halo: The halo. If None the default halo is used.
+        is_2d: Whether to return tiling for a 2d model.
 
     Returns:
         The tiling specification.
     """
 
-    default_tiling = get_default_tiling()
+    default_tiling = get_default_tiling(is_2d=is_2d)
 
     if tile_shape is None:
         tile_shape = default_tiling["tile"]
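
A sketch of how the new scaling arguments of `inference_helper` might be used. The paths and the `segment` callable below are hypothetical, and the keyword name `segmentation_function` is assumed from its use in the function body; only the `model_resolution` and `scale` keywords come from the diff above:

```python
from synaptic_reconstruction.inference.util import inference_helper

# Hypothetical segmentation callable. inference_helper now forwards a 'scale'
# keyword to it, so it must accept one.
def segment(input_volume, mask=None, scale=None):
    raise NotImplementedError

# Option 1: derive the scale per input from the mrc voxel size (Angstrom,
# converted to nanometer) relative to the model's training resolution.
inference_helper(
    input_path="/path/to/tomograms",   # hypothetical folder with mrc files
    output_root="/path/to/output",     # hypothetical output folder
    segmentation_function=segment,
    model_resolution=(1.554, 1.554, 1.554),  # hypothetical training voxel size in nm
)

# Option 2: pass a fixed scale factor instead. Passing both keywords raises a ValueError.
# inference_helper(
#     input_path="/path/to/tomograms", output_root="/path/to/output",
#     segmentation_function=segment, scale=(0.5, 0.5, 0.5),
# )
```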
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import os
+import pooch
+
+
+def get_sample_data(name: str) -> str:
+    """Get the filepath to SynapseNet sample data, stored as mrc file.
+
+    Args:
+        name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
+
+    Returns:
+        The filepath to the downloaded sample data.
+    """
+    registry = {
+        "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
+    }
+    urls = {
+        "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
+    }
+    key = f"{name}.mrc"
+
+    if key not in registry:
+        valid_names = [k[:-4] for k in registry.keys()]
+        raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.")
+
+    cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
+    data_registry = pooch.create(
+        path=os.path.join(cache_dir, "sample_data"),
+        base_url="",
+        registry=registry,
+        urls=urls,
+    )
+    file_path = data_registry.fetch(key)
+    return file_path
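
Fetching the sample is then a one-liner. A sketch; the module path in the import is hypothetical, since the filename of this new file is not shown above:

```python
import mrcfile
from synaptic_reconstruction.sample_data import get_sample_data  # hypothetical module path

# Downloads the file on first use, then serves it from the pooch cache.
path = get_sample_data("tem_2d")
with mrcfile.open(path, "r") as f:
    print(f.data.shape)
```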
