diff --git a/scripts/inner_ear/check_results.py b/scripts/inner_ear/check_results.py index 7bf3e5d..c9f2c82 100644 --- a/scripts/inner_ear/check_results.py +++ b/scripts/inner_ear/check_results.py @@ -1,3 +1,4 @@ +import json import os import sys from glob import glob @@ -204,10 +205,11 @@ def visualize_folder(folder, segmentation_version, visualize_distances, binning) for name, seg in segmentations.items(): # The function signature of the label layer has recently changed, # and we still need to support both versions. + visible = name != "vesicles" try: - v.add_labels(seg, name=name, color=colors.get(name, None), scale=scale) + v.add_labels(seg, name=name, color=colors.get(name, None), scale=scale, visible=visible) except TypeError: - v.add_labels(seg, name=name, colormap=colors.get(name, None), scale=scale) + v.add_labels(seg, name=name, colormap=colors.get(name, None), scale=scale, visible=visible) for name, lines in distance_lines.items(): v.add_shapes(lines, shape_type="line", name=name, visible=False, scale=scale) @@ -250,7 +252,7 @@ def visualize_all_data( data_root, table, segmentation_version=None, check_micro=None, visualize_distances=False, skip_iteration=None, - binning="auto", val_table=None, + binning="auto", val_table=None, tomo_list=None, ): from parse_table import check_val_table @@ -264,7 +266,13 @@ def visualize_all_data( if skip_iteration is not None and i < skip_iteration: continue - if val_table is not None: + if tomo_list is not None: + tomo_name = os.path.relpath( + folder, os.path.join(data_root, "Electron-Microscopy-Susi/Analyse") + ) + if tomo_name not in tomo_list: + continue + elif val_table is not None: is_complete = check_val_table(val_table, row) if is_complete: continue @@ -303,6 +311,7 @@ def main(): parser.add_argument("-d", "--visualize_distances", action="store_false") parser.add_argument("-b", "--binning", default="auto") parser.add_argument("-s", "--show_finished", action="store_true") + parser.add_argument("--tomos") # Optional list of tomograms. 
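+    # The --tomos argument expects the path to a JSON file with a list of
+    # tomogram names, e.g.: python check_results.py --tomos tomo_issues.json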
args = parser.parse_args() assert args.microscope in (None, "both", "old", "new") @@ -324,11 +333,16 @@ def main(): val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") val_table = pandas.read_excel(val_table_path) + tomo_list = args.tomos + if tomo_list is not None: + with open(tomo_list) as f: + tomo_list = json.load(f) + visualize_all_data( data_root, table, segmentation_version=segmentation_version, check_micro=args.microscope, visualize_distances=args.visualize_distances, skip_iteration=args.iteration, - binning=binning, val_table=val_table + binning=binning, val_table=val_table, tomo_list=tomo_list ) diff --git a/scripts/inner_ear/clear_tomos_with_issues.py b/scripts/inner_ear/clear_tomos_with_issues.py new file mode 100644 index 0000000..fa076ef --- /dev/null +++ b/scripts/inner_ear/clear_tomos_with_issues.py @@ -0,0 +1,16 @@ +import json +import os + +root = "/home/pape/Work/data/moser/em-synapses/Electron-Microscopy-Susi/Analyse" +with open("tomo_issues.json", "r") as f: + tomos = json.load(f) + +for name in tomos: + path = os.path.join(root, name, "Korrektur", "measurements.xlsx") + if not os.path.exists(path): + path = os.path.join(root, name, "korrektur", "measurements.xlsx") + if os.path.exists(path): + print("Removing", path) + os.remove(path) + else: + print("Skipping", path) diff --git a/scripts/inner_ear/export_reformatted_results.py b/scripts/inner_ear/export_reformatted_results.py new file mode 100644 index 0000000..c23d6f0 --- /dev/null +++ b/scripts/inner_ear/export_reformatted_results.py @@ -0,0 +1,120 @@ +import os +import argparse +import pandas as pd + + +def aggregate_statistics(vesicle_table, morpho_table): + boundary_dists = {"All": [], "RA-V": [], "MP-V": [], "Docked-V": []} + pd_dists = {"All": [], "RA-V": [], "MP-V": [], "Docked-V": []} + ribbon_dists = {"All": [], "RA-V": [], "MP-V": [], "Docked-V": []} + radii = {"All": [], "RA-V": [], "MP-V": [], "Docked-V": []} + + tomo_names = [] + n_ravs, n_mpvs, n_dockeds = [], [], [] + n_vesicles, ves_per_surfs, ribbon_ids = [], [], [] + + tomograms = pd.unique(vesicle_table.tomogram) + for tomo in tomograms: + tomo_table = vesicle_table[vesicle_table.tomogram == tomo] + this_ribbon_ids = pd.unique(tomo_table.ribbon_id) + + for ribbon_id in this_ribbon_ids: + + ribbon_table = tomo_table[tomo_table.ribbon_id == ribbon_id] + # FIXME we need the ribbon_id for the morpho table + this_morpho_table = morpho_table[morpho_table.tomogram == tomo] + + rav_mask = ribbon_table.pool == "RA-V" + mpv_mask = ribbon_table.pool == "MP-V" + docked_mask = ribbon_table.pool == "Docked-V" + + masks = {"All": ribbon_table.pool != "", "RA-V": rav_mask, "MP-V": mpv_mask, "Docked-V": docked_mask} + + for pool, mask in masks.items(): + pool_table = ribbon_table[mask] + radii[pool].append(pool_table["radius [nm]"].mean()) + ribbon_dists[pool].append(pool_table["ribbon_distance [nm]"].mean()) + pd_dists[pool].append(pool_table["pd_distance [nm]"].mean()) + boundary_dists[pool].append(pool_table["boundary_distance [nm]"].mean()) + + tomo_names.append(tomo) + ribbon_ids.append(ribbon_id) + n_rav = rav_mask.sum() + n_mpv = mpv_mask.sum() + n_docked = docked_mask.sum() + + n_ves = n_rav + n_mpv + n_docked + ribbon_surface = this_morpho_table[this_morpho_table.structure == "ribbon"]["surface [nm^2]"].values[0] + ves_per_surface = n_ves / ribbon_surface + + n_ravs.append(n_rav) + n_mpvs.append(n_mpv) + n_dockeds.append(n_docked) + n_vesicles.append(n_ves) + ves_per_surfs.append(ves_per_surface) + + 
summary = { + "tomogram": tomo_names, + "ribbon_id": ribbon_ids, + "N_RA-V": n_ravs, + "N_MP-V": n_mpvs, + "N_Docked-V": n_dockeds, + "N_Vesicles": n_vesicles, + "Vesicles / Surface [1 / nm^2]": ves_per_surfs, + } + summary.update({f"{pool}: radius [nm]": dists for pool, dists in radii.items()}) + summary.update({f"{pool}: ribbon_distance [nm]": dists for pool, dists in ribbon_dists.items()}) + summary.update({f"{pool}: pd_distance [nm]": dists for pool, dists in pd_dists.items()}) + summary.update({f"{pool}: boundary_distance [nm]": dists for pool, dists in boundary_dists.items()}) + summary = pd.DataFrame(summary) + return summary + + +# TODO +# - add ribbon id to the morphology table! +def export_reformatted_results(input_path, output_path): + vesicle_table = pd.read_excel(input_path, sheet_name="vesicles") + morpho_table = pd.read_excel(input_path, sheet_name="morphology") + + vesicle_table["stimulation"] = vesicle_table["tomogram"].apply(lambda x: x.split("/")[0]) + # Separating by mouse is currently not required, but we leave in the column for now. + vesicle_table["mouse"] = vesicle_table["tomogram"].apply(lambda x: x.split("/")[-3]) + vesicle_table["pil_v_mod"] = vesicle_table["tomogram"].apply(lambda x: x.split("/")[-2]) + + # For now: export only the vesicle pools per tomogram. + for stim in ("WT strong stim", "WT control"): + for condition in ("pillar", "modiolar"): + condition_table = vesicle_table[ + (vesicle_table.stimulation == stim) & (vesicle_table.pil_v_mod == condition) + ] + + this_tomograms = pd.unique(condition_table.tomogram) + this_morpho_table = morpho_table[morpho_table.tomogram.isin(this_tomograms)] + condition_table = aggregate_statistics(condition_table, this_morpho_table) + + # Simpler aggregation for just the number of vesicles. 
+ # condition_table = condition_table.pivot_table( + # index=["tomogram", "ribbon_id"], columns="pool", aggfunc="size", fill_value=0 + # ).reset_index() + + sheet_name = f"{stim}-{condition}" + + if os.path.exists(output_path): + with pd.ExcelWriter(output_path, engine="openpyxl", mode="a") as writer: + condition_table.to_excel(writer, sheet_name=sheet_name, index=False) + else: + condition_table.to_excel(output_path, sheet_name=sheet_name, index=False) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_path", "-i", + default="results/20250221_1/automatic_analysis_results.xlsx" + ) + parser.add_argument( + "--output_path", "-o", + default="results/vesicle_pools_automatic.xlsx" + ) + args = parser.parse_args() + export_reformatted_results(args.input_path, args.output_path) diff --git a/scripts/inner_ear/processing/filter_objects.py b/scripts/inner_ear/processing/filter_objects.py new file mode 100644 index 0000000..258ec58 --- /dev/null +++ b/scripts/inner_ear/processing/filter_objects.py @@ -0,0 +1,134 @@ +import os +from pathlib import Path +from tqdm import tqdm + +import h5py +import imageio.v3 as imageio +import numpy as np +from skimage.measure import label +from skimage.segmentation import relabel_sequential + +from synapse_net.file_utils import get_data_path +from parse_table import parse_table, get_data_root, _match_correction_folder, _match_correction_file + + +def _load_segmentation(seg_path): + ext = Path(seg_path).suffix + assert ext in (".h5", ".tif"), ext + if ext == ".tif": + seg = imageio.imread(seg_path) + else: + with h5py.File(seg_path, "r") as f: + seg = f["segmentation"][:] + return seg + + +def _save_segmentation(seg_path, seg): + ext = Path(seg_path).suffix + assert ext in (".h5", ".tif"), ext + if ext == ".tif": + imageio.imwrite(seg_path, seg, compression="zlib") + else: + with h5py.File(seg_path, "a") as f: + f.create_dataset("segmentation", data=seg, compression="gzip") + return seg + + +def _filter_n_objects(segmentation, num_objects): + # Create individual objects for all disconnected pieces. + segmentation = label(segmentation) + # Find object ids and sizes, excluding background. + ids, sizes = np.unique(segmentation, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + # Only keep the biggest 'num_objects' objects. + keep_ids = ids[np.argsort(sizes)[::-1]][:num_objects] + segmentation[~np.isin(segmentation, keep_ids)] = 0 + # Relabel the segmentation sequentially. + segmentation, _, _ = relabel_sequential(segmentation) + # Ensure that we have the correct number of objects. + n_ids = int(segmentation.max()) + assert n_ids == num_objects + return segmentation + + +def process_tomogram(folder, num_ribbon, num_pd): + data_path = get_data_path(folder) + output_folder = os.path.join(folder, "automatisch", "v2") + fname = Path(data_path).stem + + correction_folder = _match_correction_folder(folder) + + ribbon_path = _match_correction_file(correction_folder, "ribbon") + if not os.path.exists(ribbon_path): + ribbon_path = os.path.join(output_folder, f"{fname}_ribbon.h5") + assert os.path.exists(ribbon_path), ribbon_path + ribbon = _load_segmentation(ribbon_path) + + pd_path = _match_correction_file(correction_folder, "PD") + if not os.path.exists(pd_path): + pd_path = os.path.join(output_folder, f"{fname}_pd.h5") + assert os.path.exists(pd_path), pd_path + PD = _load_segmentation(pd_path) + + # Filter the ribbon and the PD. 
+ print("Filtering number of ribbons:", num_ribbon) + ribbon = _filter_n_objects(ribbon, num_ribbon) + bkp_path_ribbon = ribbon_path + ".bkp" + os.rename(ribbon_path, bkp_path_ribbon) + _save_segmentation(ribbon_path, ribbon) + + print("Filtering number of PDs:", num_pd) + PD = _filter_n_objects(PD, num_pd) + bkp_path_pd = pd_path + ".bkp" + os.rename(pd_path, bkp_path_pd) + _save_segmentation(pd_path, PD) + + +def filter_objects(table, version): + for i, row in tqdm(table.iterrows(), total=len(table)): + folder = row["Local Path"] + if folder == "": + continue + + # We have to handle the segmentation without ribbon separately. + if row["PD vorhanden? "] == "nein": + continue + + n_pds = row["Anzahl PDs"] + if n_pds == "unklar": + n_pds = 1 + + n_pds = int(n_pds) + n_ribbons = int(row["Anzahl Ribbons"]) + if (n_ribbons == 2 and n_pds == 1): + print(f"The tomogram {folder} has {n_ribbons} ribbons and {n_pds} PDs.") + print("The structure post-processing for this case is not yet implemented and will be skipped.") + continue + + micro = row["EM alt vs. Neu"] + if micro == "beides": + process_tomogram(folder, n_ribbons, n_pds) + + folder_new = os.path.join(folder, "Tomo neues EM") + if not os.path.exists(folder_new): + folder_new = os.path.join(folder, "neues EM") + assert os.path.exists(folder_new), folder_new + process_tomogram(folder_new, n_ribbons, n_pds) + + elif micro == "alt": + process_tomogram(folder, n_ribbons, n_pds) + + elif micro == "neu": + process_tomogram(folder, n_ribbons, n_pds) + + +def main(): + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Übersicht.xlsx") + table = parse_table(table_path, data_root) + version = 2 + filter_objects(table, version) + + +if __name__ == "__main__": + main() diff --git a/scripts/inner_ear/processing/run_analyis.py b/scripts/inner_ear/processing/run_analyis.py index ed6ccdd..afafcb9 100644 --- a/scripts/inner_ear/processing/run_analyis.py +++ b/scripts/inner_ear/processing/run_analyis.py @@ -54,6 +54,7 @@ def _load_segmentation(seg_path, tomo_shape): def compute_distances(segmentation_paths, save_folder, resolution, force, tomo_shape, use_corrected_vesicles=True): + print(save_folder) os.makedirs(save_folder, exist_ok=True) vesicles = None @@ -78,6 +79,15 @@ def _require_vesicles(): ribbon_path = segmentation_paths["ribbon"] ribbon = _load_segmentation(ribbon_path, tomo_shape) + import napari + print("Tomo :", tomo_shape) + print("Ribbon :", ribbon.shape) + print("Vesicles:", vesicles.shape) + v = napari.Viewer() + v.add_labels(ribbon) + v.add_labels(vesicles) + napari.run() + if ribbon is None or ribbon.sum() == 0: print("The ribbon segmentation at", segmentation_paths["ribbon"], "is empty. 
Skipping analysis.") return None, True @@ -102,6 +112,8 @@ def _require_vesicles(): mem_path = segmentation_paths["membrane"] membrane = _load_segmentation(mem_path, tomo_shape) + if membrane is None: + return None, True try: measure_segmentation_to_object_distances( vesicles, membrane, save_path=membrane_save, resolution=resolution @@ -469,7 +481,7 @@ def main(): version = 2 force = False - use_corrected_vesicles = False + use_corrected_vesicles = True # val_table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Validierungs-Tabelle-v3.xlsx") # val_table = pandas.read_excel(val_table_path) diff --git a/scripts/inner_ear/run_final_analysis_pipeline.py b/scripts/inner_ear/run_final_analysis_pipeline.py index fb549e8..51fc24f 100644 --- a/scripts/inner_ear/run_final_analysis_pipeline.py +++ b/scripts/inner_ear/run_final_analysis_pipeline.py @@ -8,6 +8,7 @@ from add_summary_stats_to_table import add_summary_stats from combine_measurements import combine_automatic_results, combine_manual_results from compare_pool_assignments import create_manual_assignment, compare_assignments, update_measurements +from export_reformatted_results import export_reformatted_results sys.path.append("processing") @@ -66,9 +67,16 @@ def main(): update_measurements(data_root, tomograms, manual_res_path, output_path=full_manual_res_path) # Step 4: Add summary stats to all the tables. - for tab in [automatic_res_path, manual_res_path, full_manual_res_path]: + result_tables = [automatic_res_path, manual_res_path, full_manual_res_path] + for tab in result_tables: add_summary_stats(tab) + # Step 5: Export the reformatted results. + for tab in result_tables: + name = os.path.basename(tab) + out_tab = os.path.join(output_folder, f"summary_{name}") + export_reformatted_results(tab, out_tab) + if __name__ == "__main__": main() diff --git a/scripts/inner_ear/tomo_issues.json b/scripts/inner_ear/tomo_issues.json new file mode 100644 index 0000000..5d03944 --- /dev/null +++ b/scripts/inner_ear/tomo_issues.json @@ -0,0 +1,10 @@ +[ + "WT strong stim/Mouse 1/modiolar/5", + "WT strong stim/Mouse 1/modiolar/6", + "WT strong stim/Mouse 1/modiolar/17", + "WT strong stim/Mouse 2/modiolar/5", + "WT strong stim/Mouse 2/modiolar/7", + "WT strong stim/Mouse 2/modiolar/15", + "WT control/Mouse 2/modiolar/1", + "WT control/Mouse 2/modiolar/15" +] diff --git a/scripts/otoferlin/.gitignore b/scripts/otoferlin/.gitignore new file mode 100644 index 0000000..59d5ea8 --- /dev/null +++ b/scripts/otoferlin/.gitignore @@ -0,0 +1,4 @@ +data/ +sync_segmentation.sh +segmentation/ +results/ diff --git a/scripts/otoferlin/README.md b/scripts/otoferlin/README.md new file mode 100644 index 0000000..a96eda3 --- /dev/null +++ b/scripts/otoferlin/README.md @@ -0,0 +1,9 @@ +# Otoferlin Analysis + + +## Notes on improvements: + +- Try less than 20 exclude slices +- Update boundary post-proc (not robust when PD not found and selects wrong objects) +- Can we find fiducials with ilastik and mask them out? They are interfering with Ribbon finding. + - Alternative: just restrict the processing to a center crop by default. 
diff --git a/scripts/otoferlin/automatic_processing.py b/scripts/otoferlin/automatic_processing.py new file mode 100644 index 0000000..531a38c --- /dev/null +++ b/scripts/otoferlin/automatic_processing.py @@ -0,0 +1,201 @@ +import os + +import h5py +import numpy as np + +from skimage.measure import label +from skimage.segmentation import relabel_sequential + +from synapse_net.distance_measurements import measure_segmentation_to_object_distances +from synapse_net.file_utils import read_mrc +from synapse_net.inference.vesicles import segment_vesicles +from synapse_net.tools.util import get_model, compute_scale_from_voxel_size, _segment_ribbon_AZ +from tqdm import tqdm + +from common import get_all_tomograms, get_seg_path, get_adapted_model, load_segmentations + +# These are tomograms for which the sophisticated membrane processing fails. +# In this case, we just select the largest boundary piece. +SIMPLE_MEM_POSTPROCESSING = [ + "Otof_TDAKO1blockA_GridN5_2_rec.mrc", "Otof_TDAKO2blockC_GridF5_1_rec.mrc", "Otof_TDAKO2blockC_GridF5_2_rec.mrc", + "Bl6_NtoTDAWT1_blockH_GridF3_1_rec.mrc", "Bl6_NtoTDAWT1_blockH_GridG2_3_rec.mrc", "Otof_TDAKO1blockA_GridN5_5_rec.mrc", + "Otof_TDAKO2blockC_GridE2_1_rec.mrc", "Otof_TDAKO2blockC_GridE2_2_rec.mrc", + +] + + +def _get_center_crop(input_): + halo_xy = (600, 600) + bb_xy = tuple( + slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(input_.shape[1:], halo_xy) + ) + bb = (np.s_[:],) + bb_xy + return bb, input_.shape + + +def _get_tiling(): + # tile = {"x": 768, "y": 768, "z": 48} + tile = {"x": 512, "y": 512, "z": 48} + halo = {"x": 128, "y": 128, "z": 8} + return {"tile": tile, "halo": halo} + + +def process_vesicles(mrc_path, output_path, process_center_crop): + key = "segmentation/vesicles" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if key in f: + return + + input_, voxel_size = read_mrc(mrc_path) + if process_center_crop: + bb, full_shape = _get_center_crop(input_) + input_ = input_[bb] + + model = get_adapted_model() + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + print("Rescaling volume for vesicle segmentation with factor:", scale) + tiling = _get_tiling() + segmentation = segment_vesicles(input_, model=model, scale=scale, tiling=tiling) + + if process_center_crop: + full_seg = np.zeros(full_shape, dtype=segmentation.dtype) + full_seg[bb] = segmentation + segmentation = full_seg + + with h5py.File(output_path, "a") as f: + f.create_dataset(key, data=segmentation, compression="gzip") + + +def _simple_membrane_postprocessing(membrane_prediction): + seg = label(membrane_prediction) + ids, sizes = np.unique(seg, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + return (seg == ids[np.argmax(sizes)]).astype("uint8") + + +def process_ribbon_structures(mrc_path, output_path, process_center_crop): + key = "segmentation/ribbon" + with h5py.File(output_path, "r") as f: + if key in f: + return + vesicles = f["segmentation/vesicles"][:] + + input_, voxel_size = read_mrc(mrc_path) + if process_center_crop: + bb, full_shape = _get_center_crop(input_) + input_, vesicles = input_[bb], vesicles[bb] + assert input_.shape == vesicles.shape + + model_name = "ribbon" + model = get_model(model_name) + scale = compute_scale_from_voxel_size(voxel_size, model_name) + tiling = _get_tiling() + + segmentations, predictions = _segment_ribbon_AZ( + input_, model, tiling=tiling, scale=scale, verbose=True, extra_segmentation=vesicles, + return_predictions=True, n_slices_exclude=5, + ) + + # The distance based 
post-processing for membranes fails for some tomograms.
+    # In these cases, just choose the largest membrane piece.
+    fname = os.path.basename(mrc_path)
+    if fname in SIMPLE_MEM_POSTPROCESSING:
+        segmentations["membrane"] = _simple_membrane_postprocessing(predictions["membrane"])
+
+    if process_center_crop:
+        for name, seg in segmentations.items():
+            full_seg = np.zeros(full_shape, dtype=seg.dtype)
+            full_seg[bb] = seg
+            segmentations[name] = full_seg
+        for name, pred in predictions.items():
+            full_pred = np.zeros(full_shape, dtype=pred.dtype)
+            full_pred[bb] = pred
+            predictions[name] = full_pred
+
+    with h5py.File(output_path, "a") as f:
+        for name, seg in segmentations.items():
+            f.create_dataset(f"segmentation/{name}", data=seg, compression="gzip")
+            f.create_dataset(f"prediction/{name}", data=predictions[name], compression="gzip")
+
+
+def postprocess_vesicles(
+    mrc_path, output_path, process_center_crop, force=False
+):
+    key = "segmentation/veiscles_postprocessed"
+    with h5py.File(output_path, "r") as f:
+        if key in f and not force:
+            return
+        vesicles = f["segmentation/vesicles"][:]
+    if process_center_crop:
+        bb, full_shape = _get_center_crop(vesicles)
+        vesicles = vesicles[bb]
+    else:
+        bb = np.s_[:]
+
+    segs = load_segmentations(output_path)
+    ribbon = segs["ribbon"][bb]
+    membrane = segs["membrane"][bb]
+
+    # Filter out small vesicle fragments.
+    min_size = 5000
+    ids, sizes = np.unique(vesicles, return_counts=True)
+    ids, sizes = ids[1:], sizes[1:]
+    filter_ids = ids[sizes < min_size]
+    vesicles[np.isin(vesicles, filter_ids)] = 0
+
+    input_, voxel_size = read_mrc(mrc_path)
+    voxel_size = tuple(voxel_size[ax] for ax in "zyx")
+    input_ = input_[bb]
+
+    # Filter out all vesicles farther than 120 nm from the membrane or ribbon.
+    max_dist = 120
+    seg = (ribbon + membrane) > 0
+    distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, seg, resolution=voxel_size)
+    filter_ids = seg_ids[distances > max_dist]
+    vesicles[np.isin(vesicles, filter_ids)] = 0
+
+    vesicles, _, _ = relabel_sequential(vesicles)
+
+    if process_center_crop:
+        full_seg = np.zeros(full_shape, dtype=vesicles.dtype)
+        full_seg[bb] = vesicles
+        vesicles = full_seg
+    with h5py.File(output_path, "a") as f:
+        if key in f:
+            f[key][:] = vesicles
+        else:
+            f.create_dataset(key, data=vesicles, compression="gzip")
+
+
+def process_tomogram(mrc_path):
+    output_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(output_path)[0]
+    os.makedirs(output_folder, exist_ok=True)
+
+    process_center_crop = True
+
+    process_vesicles(mrc_path, output_path, process_center_crop)
+    process_ribbon_structures(mrc_path, output_path, process_center_crop)
+    postprocess_vesicles(mrc_path, output_path, process_center_crop)
+
+
+def main():
+    tomograms = get_all_tomograms()
+    # for tomogram in tqdm(tomograms, desc="Process tomograms"):
+    #     process_tomogram(tomogram)
+
+    # Update the membrane postprocessing for the tomograms where this went wrong.
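+    # This overwrites the membrane segmentation stored in the h5 file with the
+    # result of the simple largest-component post-processing.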
+ for tomo in tqdm(tomograms, desc="Fix membrame postprocesing"): + if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING: + continue + seg_path = get_seg_path(tomo) + with h5py.File(seg_path, "r") as f: + pred = f["prediction/membrane"][:] + seg = _simple_membrane_postprocessing(pred) + with h5py.File(seg_path, "a") as f: + f["segmentation/membrane"][:] = seg + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/check_automatic_result.py b/scripts/otoferlin/check_automatic_result.py new file mode 100644 index 0000000..4c4c46c --- /dev/null +++ b/scripts/otoferlin/check_automatic_result.py @@ -0,0 +1,60 @@ +import os + +import h5py +import napari +import numpy as np + +from synapse_net.file_utils import read_mrc +from skimage.exposure import equalize_adapthist +from tqdm import tqdm + +from common import get_all_tomograms, get_seg_path, get_colormaps + + +def check_automatic_result(mrc_path, version, use_clahe=False, center_crop=True, segmentation_group="segmentation"): + tomogram, _ = read_mrc(mrc_path) + if center_crop: + halo = (50, 512, 512) + bb = tuple( + slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(tomogram.shape, halo) + ) + tomogram = tomogram[bb] + else: + bb = np.s_[:] + + if use_clahe: + print("Run CLAHE ...") + tomogram = equalize_adapthist(tomogram, clip_limit=0.03) + print("... done") + + seg_path = get_seg_path(mrc_path, version) + segmentations, colormaps = {}, {} + if os.path.exists(seg_path): + with h5py.File(seg_path, "r") as f: + g = f[segmentation_group] + for name, ds in g.items(): + segmentations[name] = ds[bb] + colormaps[name] = get_colormaps().get(name, None) + + v = napari.Viewer() + v.add_image(tomogram) + for name, seg in segmentations.items(): + v.add_labels(seg, name=name, colormap=colormaps.get(name)) + v.title = os.path.basename(mrc_path) + napari.run() + + +def main(): + version = 2 + tomograms = get_all_tomograms() + for i, tomogram in tqdm( + enumerate(tomograms), total=len(tomograms), desc="Visualize automatic segmentation results" + ): + print("Checking tomogram", tomogram) + check_automatic_result(tomogram, version) + # check_automatic_result(tomogram, version, segmentation_group="vesicles") + # check_automatic_result(tomogram, version, segmentation_group="prediction") + + +if __name__: + main() diff --git a/scripts/otoferlin/check_structure_postprocessing.py b/scripts/otoferlin/check_structure_postprocessing.py new file mode 100644 index 0000000..6b8d4de --- /dev/null +++ b/scripts/otoferlin/check_structure_postprocessing.py @@ -0,0 +1,56 @@ +import os + +import h5py +import napari +import numpy as np + +from synapse_net.file_utils import read_mrc +from tqdm import tqdm + +from common import get_seg_path, get_all_tomograms, get_colormaps, STRUCTURE_NAMES + + +def check_structure_postprocessing(mrc_path, center_crop=True): + tomogram, _ = read_mrc(mrc_path) + if center_crop: + halo = (50, 512, 512) + bb = tuple( + slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(tomogram.shape, halo) + ) + tomogram = tomogram[bb] + else: + bb = np.s_[:] + + seg_path = get_seg_path(mrc_path) + assert os.path.exists(seg_path) + + segmentations, predictions, colormaps = {}, {}, {} + with h5py.File(seg_path, "r") as f: + g = f["segmentation"] + for name in STRUCTURE_NAMES: + segmentations[f"seg/{name}"] = g[name][bb] + colormaps[name] = get_colormaps().get(name, None) + + g = f["prediction"] + for name in STRUCTURE_NAMES: + predictions[f"pred/{name}"] = g[name][bb] + + v = napari.Viewer() + 
+    v.add_image(tomogram)
+    for name, seg in segmentations.items():
+        v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]))
+    for name, pred in predictions.items():
+        v.add_labels(pred, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
+    v.title = os.path.basename(mrc_path)
+    napari.run()
+
+
+def main():
+    tomograms = get_all_tomograms()
+    for i, tomogram in tqdm(enumerate(tomograms), total=len(tomograms), desc="Check structure postproc"):
+        print(tomogram)
+        check_structure_postprocessing(tomogram)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/common.py b/scripts/otoferlin/common.py
new file mode 100644
index 0000000..9dd0ca7
--- /dev/null
+++ b/scripts/otoferlin/common.py
@@ -0,0 +1,116 @@
+import os
+from glob import glob
+
+import imageio.v3 as imageio
+import h5py
+import pandas as pd
+from synapse_net.tools.util import load_custom_model
+
+
+# These are the files just for the test data.
+# INPUT_ROOT = "/home/ag-wichmann/data/test-data/tomograms"
+# OUTPUT_ROOT = "/home/ag-wichmann/data/test-data/segmentation"
+
+# These are the otoferlin tomograms.
+INPUT_ROOT = "/home/ag-wichmann/data/otoferlin/tomograms"
+OUTPUT_ROOT = "./segmentation"
+
+STRUCTURE_NAMES = ("ribbon", "PD", "membrane")
+
+# The version of the automatic segmentation. We have:
+# - version 1: using the default models for all structures and the initial version of the post-processing.
+# - version 2: using the model adapted to the otoferlin data for vesicles and an updated post-processing.
+VERSION = 2
+
+
+def get_adapted_model():
+    # Path on nhr.
+    # model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/otoferlin/domain_adaptation/checkpoints/otoferlin_da.pt"  # noqa
+    # Path on the Workstation.
+    model_path = "/home/ag-wichmann/Downloads/otoferlin_da.pt"
+    model = load_custom_model(model_path)
+    return model
+
+
+def get_folders():
+    if os.path.exists(INPUT_ROOT):
+        return INPUT_ROOT, OUTPUT_ROOT
+    root_in = "./data/tomograms"
+    assert os.path.exists(root_in)
+    return root_in, OUTPUT_ROOT
+
+
+def load_table():
+    table_path = "overview Otoferlin samples.xlsx"
+    table_mut = pd.read_excel(table_path, sheet_name="Mut")
+    table_wt = pd.read_excel(table_path, sheet_name="Wt")
+    table = pd.concat([table_mut, table_wt])
+    table = table[table["Einschluss? "] == "ja"]
+    return table
+
+
+def get_all_tomograms(restrict_to_good_tomos=False, restrict_to_nachgeb=False):
+    root, _ = get_folders()
+    tomograms = glob(os.path.join(root, "**", "*.mrc"), recursive=True)
+    tomograms += glob(os.path.join(root, "**", "*.rec"), recursive=True)
+    tomograms = sorted(tomograms)
+    if restrict_to_good_tomos:
+        table = load_table()
+        if restrict_to_nachgeb:
+            table = table[table["nachgebessert"] == "ja"]
+        fnames = [os.path.basename(row["File name"]) for _, row in table.iterrows()]
+        tomograms = [tomo for tomo in tomograms if os.path.basename(tomo) in fnames]
+        # assert len(tomograms) == len(table), f"{len(tomograms), len(table)}"
+    return tomograms
+
+
+def get_seg_path(mrc_path, version=VERSION):
+    input_root, output_root = get_folders()
+    rel_path = os.path.relpath(mrc_path, input_root)
+    rel_folder, fname = os.path.split(rel_path)
+    fname = os.path.splitext(fname)[0]
+    seg_path = os.path.join(output_root, f"v{version}", rel_folder, f"{fname}.h5")
+    return seg_path
+
+
+def get_colormaps():
+    pool_map = {
+        "RA-V": (0, 0.33, 0),
+        "MP-V": (1.0, 0.549, 0.0),
+        "Docked-V": (1, 1, 0),
+        None: "gray",
+    }
+    ribbon_map = {1: "red", 2: "red", None: (0, 0, 0, 0), 0: (0, 0, 0, 0)}
+    membrane_map = {1: "purple", None: (0, 0, 0, 0)}
+    pd_map = {1: "magenta", 2: "magenta", None: (0, 0, 0, 0)}
+    return {"pools": pool_map, "membrane": membrane_map, "PD": pd_map, "ribbon": ribbon_map}
+
+
+def load_segmentations(seg_path, verbose=True):
+    # Keep the typo in the name, as these are the hdf5 keys!
+    seg_names = {"vesicles": "veiscles_postprocessed"}
+    seg_names.update({name: name for name in STRUCTURE_NAMES})
+
+    segmentations = {}
+    correction_folder = os.path.join(os.path.split(seg_path)[0], "correction")
+    with h5py.File(seg_path, "r") as f:
+        g = f["segmentation"]
+        for out_name, name in seg_names.items():
+            correction_path = os.path.join(correction_folder, f"{name}.tif")
+            if os.path.exists(correction_path):
+                if verbose:
+                    print("Loading corrected", name, "segmentation from", correction_path)
+                segmentations[out_name] = imageio.imread(correction_path)
+            else:
+                segmentations[out_name] = g[f"{name}"][:]
+    return segmentations
+
+
+def to_condition(mrc_path):
+    fname = os.path.basename(mrc_path)
+    return "TDA KO" if fname.startswith("Otof") else "TDA WT"
+
+
+if __name__ == "__main__":
+    tomos = get_all_tomograms(restrict_to_good_tomos=True, restrict_to_nachgeb=True)
+    print("We have", len(tomos), "tomograms")
diff --git a/scripts/otoferlin/compare_vesicle_segmentation.py b/scripts/otoferlin/compare_vesicle_segmentation.py
new file mode 100644
index 0000000..555947d
--- /dev/null
+++ b/scripts/otoferlin/compare_vesicle_segmentation.py
@@ -0,0 +1,58 @@
+import os
+
+import h5py
+
+from skimage.exposure import equalize_adapthist
+from synapse_net.inference.vesicles import segment_vesicles
+from synapse_net.file_utils import read_mrc
+from synapse_net.tools.util import get_model, compute_scale_from_voxel_size, load_custom_model
+from tqdm import tqdm
+
+from common import get_all_tomograms, get_seg_path
+
+
+def compare_vesicles(tomo_path):
+    seg_path = get_seg_path(tomo_path)
+    seg_folder = os.path.split(seg_path)[0]
+    os.makedirs(seg_folder, exist_ok=True)
+
+    model_paths = {
+        "adapted_v1": "/mnt/vast-nhr/home/pape41/u12086/inner-ear-da.pt",
+        "adapted_v2": "./domain_adaptation/checkpoints/otoferlin_da.pt"
+    }
+    for model_type in ("vesicles_3d", "adapted_v1", "adapted_v2"):
+        for use_clahe in (False, True):
+            seg_key = f"vesicles/{model_type}"
+            if use_clahe:
+= "_clahe" + + if os.path.exists(seg_path): + with h5py.File(seg_path, "r") as f: + if seg_key in f: + continue + + tomogram, voxel_size = read_mrc(tomo_path) + if use_clahe: + tomogram = equalize_adapthist(tomogram, clip_limit=0.03) + + if model_type == "vesicles_3d": + model = get_model(model_type) + scale = compute_scale_from_voxel_size(voxel_size, model_type) + else: + model_path = model_paths[model_type] + model = load_custom_model(model_path) + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + + seg = segment_vesicles(tomogram, model=model, scale=scale) + with h5py.File(seg_path, "a") as f: + f.create_dataset(seg_key, data=seg, compression="gzip") + + +def main(): + tomograms = get_all_tomograms() + for tomo in tqdm(tomograms): + compare_vesicles(tomo) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/correct_structure_segmentation.py b/scripts/otoferlin/correct_structure_segmentation.py new file mode 100644 index 0000000..5c2ff15 --- /dev/null +++ b/scripts/otoferlin/correct_structure_segmentation.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import napari + +from synapse_net.file_utils import read_mrc +from common import get_all_tomograms, get_seg_path, load_segmentations, get_colormaps + + +def correct_structure_segmentation(mrc_path): + seg_path = get_seg_path(mrc_path) + + data, _ = read_mrc(mrc_path) + segmentations = load_segmentations(seg_path) + color_maps = get_colormaps() + + v = napari.Viewer() + v.add_image(data) + for name, seg in segmentations.items(): + if name == "vesicles": + name = "veiscles_postprocessed" + v.add_labels(seg, name=name, colormap=color_maps.get(name, None)) + fname = Path(mrc_path).stem + v.title = fname + napari.run() + + +def main(): + tomograms = get_all_tomograms() + for tomo in tomograms: + correct_structure_segmentation(tomo) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/correct_vesicle_pools.py b/scripts/otoferlin/correct_vesicle_pools.py new file mode 100644 index 0000000..e99d7a2 --- /dev/null +++ b/scripts/otoferlin/correct_vesicle_pools.py @@ -0,0 +1,162 @@ +import os + +import imageio.v3 as imageio +import napari +import numpy as np +import pandas as pd +from magicgui import magicgui + +from synapse_net.file_utils import read_mrc +from synapse_net.distance_measurements import load_distances +from skimage.measure import regionprops +from common import load_segmentations, get_seg_path, get_all_tomograms, get_colormaps, STRUCTURE_NAMES +from tqdm import tqdm + +import warnings +warnings.filterwarnings("ignore") + + +# FIXME: adding vesicles to pool doesn't work / messes with color map +def _create_pool_layer(seg, assignment_path): + assignments = pd.read_csv(assignment_path) + pools = np.zeros_like(seg) + + pool_colors = get_colormaps()["pools"] + colormap = {None: "gray", 0: (0, 0, 0, 0)} + + # Sorting of floats and ints by np.unique is weird. We better don't trust unique here + # It should not matter if one of the pools is empty. 
+ pool_names = ["RA-V", "MP-V", "Docked-V"] + + for pool_id, pool_name in enumerate(pool_names, 1): + if not isinstance(pool_name, str) and np.isnan(pool_name): + continue + pool_vesicle_ids = assignments[assignments.pool == pool_name].vesicle_id.values + pool_mask = np.isin(seg, pool_vesicle_ids) + pools[pool_mask] = pool_id + colormap[pool_id] = pool_colors[pool_name] + + return pools, colormap, assignments + + +def _update_assignments(vesicles, pool_correction, assignment_path): + old_assignments = pd.read_csv(assignment_path) + props = regionprops(vesicles, pool_correction) + + val_to_pool = {0: 0, 1: "RA-V", 2: "MP-V", 3: "Docked-V", 4: None} + corrected_pools = {prop.label: val_to_pool[int(prop.max_intensity)] for prop in props} + + new_assignments = [] + for _, row in old_assignments.iterrows(): + vesicle_id = row.vesicle_id + corrected_pool = corrected_pools[vesicle_id] + if corrected_pool != 0: + row.pool = corrected_pool + new_assignments.append(row) + new_assignments = pd.DataFrame(new_assignments) + new_assignments.to_csv(assignment_path, index=False) + + +def _create_outlier_mask(assignments, vesicles, output_folder): + distances = {} + for name in STRUCTURE_NAMES: + dist, _, _, ids = load_distances(os.path.join(output_folder, "distances", f"{name}.npz")) + distances[name] = {vid: dist for vid, dist in zip(ids, dist)} + + pool_criteria = { + "RA-V": {"ribbon": 80}, + "MP-V": {"PD": 100, "membrane": 50}, + "Docked-V": {"PD": 100, "membrane": 2}, + } + + vesicle_ids = assignments.vesicle_id.values + outlier_ids = [] + for pool in ("RA-V", "MP-V", "Docked-V"): + pool_ids = assignments[assignments.pool == pool].vesicle_id.values + for name in STRUCTURE_NAMES: + min_dist = pool_criteria[pool].get(name) + if min_dist is None: + continue + dist = distances[name] + assert len(dist) == len(vesicle_ids) + pool_outliers = [vid for vid in pool_ids if dist[vid] > min_dist] + if pool_outliers: + print("Pool:", pool, ":", name, ":", len(pool_outliers)) + outlier_ids.extend(pool_outliers) + + outlier_ids = np.unique(outlier_ids) + outlier_mask = np.isin(vesicles, outlier_ids).astype("uint8") + return outlier_mask + + +def correct_vesicle_pools(mrc_path, show_outliers, skip_if_no_outlier=False): + seg_path = get_seg_path(mrc_path) + + output_folder = os.path.split(seg_path)[0] + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + if not os.path.exists(assignment_path): + print("Skip", seg_path, "due to missing assignments") + return + + data, _ = read_mrc(mrc_path) + segmentations = load_segmentations(seg_path, verbose=False) + vesicles = segmentations["vesicles"] + + colormaps = get_colormaps() + pool_colors = colormaps["pools"] + correction_colors = { + 1: pool_colors["RA-V"], 2: pool_colors["MP-V"], 3: pool_colors["Docked-V"], 4: "Gray", None: "Gray" + } + + vesicle_pools, pool_colors, assignments = _create_pool_layer(vesicles, assignment_path) + if show_outliers: + outlier_mask = _create_outlier_mask(assignments, vesicles, output_folder) + else: + outlier_mask = None + + if skip_if_no_outlier and outlier_mask.sum() == 0: + return + + pool_correction_path = os.path.join(output_folder, "correction", "pool_correction.tif") + os.makedirs(os.path.join(output_folder, "correction"), exist_ok=True) + if os.path.exists(pool_correction_path): + pool_correction = imageio.imread(pool_correction_path) + else: + pool_correction = np.zeros_like(vesicles) + + v = napari.Viewer() + v.add_image(data) + v.add_labels(vesicle_pools, colormap=pool_colors) + v.add_labels(pool_correction, 
colormap=correction_colors) + v.add_labels(vesicles, visible=False) + for name in STRUCTURE_NAMES: + # v.add_labels(segmentations[name], name=name, visible=False, colormap=colormaps[name]) + v.add_labels(segmentations[name], name=name, visible=False) + + if outlier_mask is not None: + v.add_labels(outlier_mask) + + @magicgui(call_button="Update Pools") + def update_pools(viewer: napari.Viewer): + pool_data = viewer.layers["vesicle_pools"].data + vesicles = viewer.layers["vesicles"].data + pool_correction = viewer.layers["pool_correction"].data + _update_assignments(vesicles, pool_correction, assignment_path) + pool_data, pool_colors, _ = _create_pool_layer(vesicles, assignment_path) + viewer.layers["vesicle_pools"].data = pool_data + viewer.layers["vesicle_pools"].colormap = pool_colors + + v.window.add_dock_widget(update_pools) + v.title = os.path.basename(mrc_path) + + napari.run() + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomo in tqdm(tomograms): + correct_vesicle_pools(tomo, show_outliers=True, skip_if_no_outlier=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py b/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py new file mode 100644 index 0000000..99b32f7 --- /dev/null +++ b/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py @@ -0,0 +1,66 @@ +import os +from glob import glob + +import h5py + +from synapse_net.file_utils import read_mrc +from synapse_net.training.domain_adaptation import mean_teacher_adaptation +from synapse_net.tools.util import compute_scale_from_voxel_size +from synapse_net.inference.util import _Scaler + + +# Apply rescaling, depending on what the segmentation comparison shows. +def preprocess_training_data(): + root = "../data/tomograms" + tomograms = glob(os.path.join(root, "**", "*.mrc"), recursive=True) + tomograms += glob(os.path.join(root, "**", "*.rec"), recursive=True) + tomograms = sorted(tomograms) + + train_folder = "./train_data" + os.makedirs(train_folder, exist_ok=True) + + all_paths = [] + for i, tomo_path in enumerate(tomograms): + out_path = os.path.join(train_folder, f"tomo{i}.h5") + if os.path.exists(out_path): + all_paths.append(out_path) + continue + + data, voxel_size = read_mrc(tomo_path) + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + print("Scale factor:", scale) + scaler = _Scaler(scale, verbose=True) + data = scaler.scale_input(data) + + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=data, compression="gzip") + all_paths.append(out_path) + + train_paths, val_paths = all_paths[:-1], all_paths[-1:] + return train_paths, val_paths + + +def train_domain_adaptation(train_paths, val_paths): + model_path = "/mnt/vast-nhr/home/pape41/u12086/inner-ear-da.pt" + model_name = "adapted_otoferlin" + + patch_shape = [48, 384, 384] + mean_teacher_adaptation( + name=model_name, + unsupervised_train_paths=train_paths, + unsupervised_val_paths=val_paths, + raw_key="raw", + patch_shape=patch_shape, + source_checkpoint=model_path, + confidence_threshold=0.75, + n_iterations=int(2.5*1e4), + ) + + +def main(): + train_paths, val_paths = preprocess_training_data() + train_domain_adaptation(train_paths, val_paths) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/ensure_labeled_all_vesicles.py b/scripts/otoferlin/ensure_labeled_all_vesicles.py new file mode 100644 index 0000000..c32f8b9 --- /dev/null +++ b/scripts/otoferlin/ensure_labeled_all_vesicles.py @@ -0,0 +1,20 
@@ +from common import get_all_tomograms, get_seg_path, load_segmentations +from tqdm import tqdm +from skimage.measure import label +import numpy as np + + +def ensure_labeled(vesicles): + n_ids = len(np.unique(vesicles)) + n_ids_labeled = len(np.unique(label(vesicles))) + assert n_ids == n_ids_labeled, f"{n_ids}, {n_ids_labeled}" + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomogram in tqdm(tomograms, desc="Process tomograms"): + segmentations = load_segmentations(get_seg_path(tomogram)) + ensure_labeled(segmentations["vesicles"]) + + +main() diff --git a/scripts/otoferlin/export_results.py b/scripts/otoferlin/export_results.py new file mode 100644 index 0000000..f83ff3c --- /dev/null +++ b/scripts/otoferlin/export_results.py @@ -0,0 +1,133 @@ +import os +from datetime import datetime + +import numpy as np +import pandas as pd +from common import get_all_tomograms, get_seg_path, to_condition + +from synapse_net.distance_measurements import load_distances + + +def get_output_folder(): + output_root = "./results" + date = datetime.now().strftime("%Y%m%d") + + version = 1 + output_folder = os.path.join(output_root, f"{date}_{version}") + while os.path.exists(output_folder): + version += 1 + output_folder = os.path.join(output_root, f"{date}_{version}") + + os.makedirs(output_folder) + return output_folder + + +def _export_results(tomograms, result_path, result_extraction): + results = {} + for tomo in tomograms: + condition = to_condition(tomo) + res = result_extraction(tomo) + if condition in results: + results[condition].append(res) + else: + results[condition] = [res] + + for condition, res in results.items(): + res = pd.concat(res) + if os.path.exists(result_path): + with pd.ExcelWriter(result_path, engine="openpyxl", mode="a") as writer: + res.to_excel(writer, sheet_name=condition, index=False) + else: + res.to_excel(result_path, sheet_name=condition, index=False) + + +def load_measures(measure_path, min_radius=5): + measures = pd.read_csv(measure_path).dropna() + measures = measures[measures.radius > min_radius] + return measures + + +def count_vesicle_pools(measures, ribbon_id, tomo): + ribbon_measures = measures[measures.ribbon_id == ribbon_id] + pool_names, counts = np.unique(ribbon_measures.pool.values, return_counts=True) + pool_names, counts = pool_names.tolist(), counts.tolist() + pool_names.append("MP-V_all") + counts.append(counts[pool_names.index("MP-V")] + counts[pool_names.index("Docked-V")]) + res = {"tomogram": [os.path.basename(tomo)], "ribbon": ribbon_id} + res.update({k: v for k, v in zip(pool_names, counts)}) + return pd.DataFrame(res) + + +def export_vesicle_pools(tomograms, result_path): + + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + ribbon_ids = pd.unique(measures.ribbon_id) + + results = [] + for ribbon_id in ribbon_ids: + res = count_vesicle_pools(measures, ribbon_id, tomo) + results.append(res) + return pd.concat(results) + + _export_results(tomograms, result_path, result_extraction) + + +def export_distances(tomograms, result_path): + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + + measures = measures[measures.pool.isin(["MP-V", "Docked-V"])][["vesicle_id", "pool"]] + + # Load the distances to PD. 
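+        # Each .npz archive stores the measured distances together with the ids
+        # of the corresponding vesicles, so we can map them via a dictionary.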
+ pd_distances, _, _, seg_ids = load_distances(os.path.join(folder, "distances", "PD.npz")) + pd_distances = {sid: dist for sid, dist in zip(seg_ids, pd_distances)} + measures["distance-to-pd"] = [pd_distances[vid] for vid in measures.vesicle_id.values] + + # Load the distances to membrane. + mem_distances, _, _, seg_ids = load_distances(os.path.join(folder, "distances", "membrane.npz")) + mem_distances = {sid: dist for sid, dist in zip(seg_ids, mem_distances)} + measures["distance-to-membrane"] = [mem_distances[vid] for vid in measures.vesicle_id.values] + + measures = measures.drop(columns=["vesicle_id"]) + measures.insert(0, "tomogram", len(measures) * [os.path.basename(tomo)]) + + return measures + + _export_results(tomograms, result_path, result_extraction) + + +def export_diameter(tomograms, result_path): + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + + measures = measures[measures.pool.isin(["MP-V", "Docked-V"])][["pool", "diameter"]] + measures.insert(0, "tomogram", len(measures) * [os.path.basename(tomo)]) + + return measures + + _export_results(tomograms, result_path, result_extraction) + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + result_folder = get_output_folder() + + result_path = os.path.join(result_folder, "vesicle_pools.xlsx") + export_vesicle_pools(tomograms, result_path) + + result_path = os.path.join(result_folder, "distances.xlsx") + export_distances(tomograms, result_path) + + result_path = os.path.join(result_folder, "diameter.xlsx") + export_diameter(tomograms, result_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/export_to_imod.py b/scripts/otoferlin/export_to_imod.py new file mode 100644 index 0000000..35e5a72 --- /dev/null +++ b/scripts/otoferlin/export_to_imod.py @@ -0,0 +1,92 @@ +import os +from glob import glob + +from pathlib import Path +from subprocess import run + +import numpy as np +import pandas as pd + +from tqdm import tqdm +from synapse_net.imod.to_imod import write_segmentation_to_imod, write_segmentation_to_imod_as_points +from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, load_segmentations + + +def check_imod(mrc_path, mod_path): + run(["imod", mrc_path, mod_path]) + + +def export_tomogram(mrc_path, force): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + + # export_folder = os.path.join(output_folder, "imod") + tomo_name = Path(mrc_path).stem + export_folder = os.path.join(f"./imod/{tomo_name}") + if os.path.exists(export_folder) and not force: + return + + segmentations = load_segmentations(seg_path) + vesicles = segmentations["vesicles"] + + os.makedirs(export_folder, exist_ok=True) + + # Load the pool assignments and export the pools to IMOD. 
+ assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + assignments = pd.read_csv(assignment_path) + + colors = { + "Docked-V": (255, 170, 127), # (1, 0.666667, 0.498039) + "RA-V": (0, 85, 0), # (0, 0.333333, 0) + "MP-V": (255, 170, 0), # (1, 0.666667, 0) + "ribbon": (255, 0, 0), + "PD": (255, 0, 255), # (1, 0, 1) + "membrane": (255, 170, 255), # 1, 0.666667, 1 + } + + pools = ['Docked-V', 'RA-V', 'MP-V'] + radius_factor = 0.85 + for pool in pools: + export_path = os.path.join(export_folder, f"{pool}.mod") + pool_ids = assignments[assignments.pool == pool].vesicle_id + pool_seg = vesicles.copy() + pool_seg[~np.isin(pool_seg, pool_ids)] = 0 + write_segmentation_to_imod_as_points( + mrc_path, pool_seg, export_path, min_radius=5, radius_factor=radius_factor, + color=colors.get(pool), name=pool, + ) + # check_imod(mrc_path, export_path) + + # Export the structures to IMOD. + for name in STRUCTURE_NAMES: + export_path = os.path.join(export_folder, f"{name}.mod") + color = colors.get(name) + write_segmentation_to_imod(mrc_path, segmentations[name], export_path, color=color) + # check_imod(mrc_path, export_path) + + # Join the model + all_mod_files = sorted(glob(os.path.join(export_folder, "*.mod"))) + export_path = os.path.join(export_folder, f"{tomo_name}.mod") + join_cmd = ["imodjoin"] + all_mod_files + [export_path] + run(join_cmd) + check_imod(mrc_path, export_path) + + +def main(): + force = True + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + tomograms_for_vis = [ + "Bl6_NtoTDAWT1_blockH_GridE4_1_rec.mrc", + "Otof_TDAKO1blockA_GridN5_6_rec.mrc", + ] + for tomogram in tqdm(tomograms, desc="Process tomograms"): + fname = os.path.basename(tomogram) + if fname not in tomograms_for_vis: + continue + print("Exporting:", fname) + export_tomogram(tomogram, force) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/filter_objects_and_measure.py b/scripts/otoferlin/filter_objects_and_measure.py new file mode 100644 index 0000000..1479bbe --- /dev/null +++ b/scripts/otoferlin/filter_objects_and_measure.py @@ -0,0 +1,81 @@ +import os +from tqdm import tqdm + +import numpy as np +from skimage.measure import label +from skimage.segmentation import relabel_sequential +from common import get_all_tomograms, get_seg_path, load_table, load_segmentations, STRUCTURE_NAMES +from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances +from synapse_net.file_utils import read_mrc + + +def _filter_n_objects(segmentation, num_objects): + # Create individual objects for all disconnected pieces. + segmentation = label(segmentation) + # Find object ids and sizes, excluding background. + ids, sizes = np.unique(segmentation, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + # Only keep the biggest 'num_objects' objects. + keep_ids = ids[np.argsort(sizes)[::-1]][:num_objects] + segmentation[~np.isin(segmentation, keep_ids)] = 0 + # Relabel the segmentation sequentially. + segmentation, _, _ = relabel_sequential(segmentation) + # Ensure that we have the correct number of objects. + n_ids = int(segmentation.max()) + assert n_ids == num_objects + return segmentation + + +def filter_and_measure(mrc_path, seg_path, output_folder, force): + result_folder = os.path.join(output_folder, "distances") + if os.path.exists(result_folder) and not force: + return + + # Load the table to find out how many ribbons / PDs we have here. 
+    table = load_table()
+    table = table[table["File name"] == os.path.basename(mrc_path)]
+    assert len(table) == 1
+
+    num_ribbon = int(table["#ribbons"].values[0])
+    num_pd = int(table["PD?"].values[0])
+
+    segmentations = load_segmentations(seg_path)
+    vesicles = segmentations["vesicles"]
+    structures = {name: segmentations[name] for name in STRUCTURE_NAMES}
+
+    # Filter the ribbon and the PD.
+    print("Filtering number of ribbons:", num_ribbon)
+    structures["ribbon"] = _filter_n_objects(structures["ribbon"], num_ribbon)
+    print("Filtering number of PDs:", num_pd)
+    structures["PD"] = _filter_n_objects(structures["PD"], num_pd)
+
+    _, resolution = read_mrc(mrc_path)
+    resolution = [resolution[ax] for ax in "zyx"]
+
+    # Measure all the object distances.
+    os.makedirs(result_folder, exist_ok=True)
+    for name in ("ribbon", "PD"):
+        seg = structures[name]
+        assert seg.sum() != 0, name
+        print("Compute vesicle distances to", name)
+        save_path = os.path.join(result_folder, f"{name}.npz")
+        measure_segmentation_to_object_distances(vesicles, seg, save_path=save_path, resolution=resolution)
+
+
+def process_tomogram(mrc_path, force):
+    seg_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(seg_path)[0]
+    assert os.path.exists(output_folder)
+
+    # Measure the distances.
+    filter_and_measure(mrc_path, seg_path, output_folder, force)
+
+
+def main():
+    force = True
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for tomogram in tqdm(tomograms, desc="Process tomograms"):
+        process_tomogram(tomogram, force)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/handle_ribbon_assignments.py b/scripts/otoferlin/handle_ribbon_assignments.py
new file mode 100644
index 0000000..8ac3586
--- /dev/null
+++ b/scripts/otoferlin/handle_ribbon_assignments.py
@@ -0,0 +1,57 @@
+import os
+import pandas as pd
+from synapse_net.distance_measurements import load_distances
+
+from common import get_all_tomograms, get_seg_path, load_table
+
+
+def _add_one_to_assignment(mrc_path):
+    seg_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(seg_path)[0]
+    assert os.path.exists(output_folder)
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+
+    assignments = pd.read_csv(assignment_path)
+    assignments["ribbon_id"] = len(assignments) * [1]
+    assignments.to_csv(assignment_path, index=False)
+
+
+def _update_assignments(mrc_path, num_ribbon):
+    print(mrc_path)
+    seg_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(seg_path)[0]
+    assert os.path.exists(output_folder)
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+    distance_path = os.path.join(output_folder, "distances", "ribbon.npz")
+
+    _, _, _, seg_ids, object_ids = load_distances(distance_path, return_object_ids=True)
+    assert all(obj in range(1, num_ribbon + 1) for obj in object_ids)
+
+    assignments = pd.read_csv(assignment_path)
+    assert len(assignments) == len(object_ids)
+    assert (seg_ids == assignments.vesicle_id.values).all()
+    assignments["ribbon_id"] = object_ids
+    assignments.to_csv(assignment_path, index=False)
+
+
+def process_tomogram(mrc_path):
+    table = load_table()
+    table = table[table["File name"] == os.path.basename(mrc_path)]
+    assert len(table) == 1
+    num_ribbon = int(table["#ribbons"].values[0])
+    assert num_ribbon in (1, 2)
+
+    if num_ribbon == 1:
+        _add_one_to_assignment(mrc_path)
+    else:
+        _update_assignments(mrc_path, num_ribbon)
+
+
+def main():
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for tomogram in tomograms:
+        process_tomogram(tomogram)
+
+
+if __name__ ==
"__main__": + main() diff --git a/scripts/otoferlin/make_figure_napari.py b/scripts/otoferlin/make_figure_napari.py new file mode 100644 index 0000000..ef515d1 --- /dev/null +++ b/scripts/otoferlin/make_figure_napari.py @@ -0,0 +1,79 @@ +import os + +import napari +import numpy as np +import pandas as pd + +from synapse_net.file_utils import read_mrc +from common import get_all_tomograms, get_seg_path, load_segmentations, STRUCTURE_NAMES + + +colors = { + "Docked-V": (255, 170, 127), # (1, 0.666667, 0.498039) + "RA-V": (0, 85, 0), # (0, 0.333333, 0) + "MP-V": (255, 170, 0), # (1, 0.666667, 0) + "ribbon": (255, 0, 0), + "PD": (255, 0, 255), # (1, 0, 1) + "membrane": (255, 170, 255), # 1, 0.666667, 1 +} + + +def plot_napari(mrc_path, rotate=False): + data, voxel_size = read_mrc(mrc_path) + voxel_size = tuple(voxel_size[ax] for ax in "zyx") + + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + segmentations = load_segmentations(seg_path) + vesicles = segmentations["vesicles"] + + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + assignments = pd.read_csv(assignment_path) + + pools = np.zeros_like(vesicles) + pool_names = ["RA-V", "MP-V", "Docked-V"] + + pool_colors = {None: (0, 0, 0)} + for pool_id, pool_name in enumerate(pool_names, 1): + pool_vesicle_ids = assignments[assignments.pool == pool_name].vesicle_id.values + pool_mask = np.isin(vesicles, pool_vesicle_ids) + pools[pool_mask] = pool_id + color = colors.get(pool_name) + color = tuple(c / float(255) for c in color) + pool_colors[pool_id] = color + + if rotate: + data = np.rot90(data, k=3, axes=(1, 2)) + pools = np.rot90(pools, k=3, axes=(1, 2)) + segmentations = {name: np.rot90(segmentations[name], k=3, axes=(1, 2)) for name in STRUCTURE_NAMES} + + v = napari.Viewer() + v.add_image(data, scale=voxel_size) + v.add_labels(pools, colormap=pool_colors, scale=voxel_size) + for name in STRUCTURE_NAMES: + color = colors[name] + color = tuple(c / float(255) for c in color) + cmap = {1: color, None: (0, 0, 0)} + seg = (segmentations[name] > 0).astype("uint8") + v.add_labels(seg, colormap=cmap, scale=voxel_size, name=name) + v.scale_bar.visible = True + v.scale_bar.unit = "nm" + v.scale_bar.font_size = 18 + v.title = os.path.basename(mrc_path) + napari.run() + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + tomograms_for_vis = [ + "Bl6_NtoTDAWT1_blockH_GridE4_1_rec.mrc", + "Otof_TDAKO1blockA_GridN5_6_rec.mrc", + ] + for tomogram in tomograms: + fname = os.path.basename(tomogram) + if fname not in tomograms_for_vis: + continue + plot_napari(tomogram, rotate=fname.startswith("Otof")) + + +main() diff --git a/scripts/otoferlin/overview Otoferlin samples.xlsx b/scripts/otoferlin/overview Otoferlin samples.xlsx new file mode 100644 index 0000000..6380dfb Binary files /dev/null and b/scripts/otoferlin/overview Otoferlin samples.xlsx differ diff --git a/scripts/otoferlin/pool_assignments_and_measurements.py b/scripts/otoferlin/pool_assignments_and_measurements.py new file mode 100644 index 0000000..90a2d2e --- /dev/null +++ b/scripts/otoferlin/pool_assignments_and_measurements.py @@ -0,0 +1,125 @@ +import os + +import numpy as np +import pandas as pd + +from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances +from synapse_net.file_utils import read_mrc +from synapse_net.imod.to_imod import convert_segmentation_to_spheres +from skimage.measure import label +from tqdm import tqdm + +from common import STRUCTURE_NAMES, 
diff --git a/scripts/otoferlin/overview Otoferlin samples.xlsx b/scripts/otoferlin/overview Otoferlin samples.xlsx
new file mode 100644
index 0000000..6380dfb
Binary files /dev/null and b/scripts/otoferlin/overview Otoferlin samples.xlsx differ
diff --git a/scripts/otoferlin/pool_assignments_and_measurements.py b/scripts/otoferlin/pool_assignments_and_measurements.py
new file mode 100644
index 0000000..90a2d2e
--- /dev/null
+++ b/scripts/otoferlin/pool_assignments_and_measurements.py
@@ -0,0 +1,125 @@
+import os
+
+import numpy as np
+import pandas as pd
+
+from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances
+from synapse_net.file_utils import read_mrc
+from synapse_net.imod.to_imod import convert_segmentation_to_spheres
+from skimage.measure import label
+from tqdm import tqdm
+
+from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, load_segmentations
+
+
+def ensure_labeled(vesicles):
+    n_ids = len(np.unique(vesicles))
+    n_ids_labeled = len(np.unique(label(vesicles)))
+    assert n_ids == n_ids_labeled, f"{n_ids}, {n_ids_labeled}"
+
+
+def measure_distances(mrc_path, seg_path, output_folder, force):
+    result_folder = os.path.join(output_folder, "distances")
+    if os.path.exists(result_folder) and not force:
+        return
+
+    # Get the voxel size.
+    _, voxel_size = read_mrc(mrc_path)
+    resolution = tuple(voxel_size[ax] for ax in "zyx")
+
+    # Load the segmentations.
+    segmentations = load_segmentations(seg_path)
+    vesicles = segmentations["vesicles"]
+    ensure_labeled(vesicles)
+    structures = {name: segmentations[name] for name in STRUCTURE_NAMES}
+
+    # Measure all the object distances.
+    os.makedirs(result_folder, exist_ok=True)
+    for name, seg in structures.items():
+        if seg.sum() == 0:
+            print(name, "was not found, skipping the distance computation.")
+            continue
+        print("Compute vesicle distances to", name)
+        save_path = os.path.join(result_folder, f"{name}.npz")
+        measure_segmentation_to_object_distances(vesicles, seg, save_path=save_path, resolution=resolution)
+
+
+def _measure_radii(seg_path):
+    segmentations = load_segmentations(seg_path)
+    vesicles = segmentations["vesicles"]
+    # The radius factor of 0.85 yields the best fit to vesicles in IMOD.
+    _, radii = convert_segmentation_to_spheres(vesicles, radius_factor=0.85)
+    return np.array(radii)
+
+
+def assign_vesicle_pools_and_measure_radii(seg_path, output_folder, force):
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+    if os.path.exists(assignment_path) and not force:
+        return
+
+    distance_folder = os.path.join(output_folder, "distances")
+    distance_paths = {name: os.path.join(distance_folder, f"{name}.npz") for name in STRUCTURE_NAMES}
+    if not all(os.path.exists(path) for path in distance_paths.values()):
+        print("Skip vesicle pool assignment, because some distances are missing.")
+        print("This probably means that the corresponding structures were not found.")
+        return
+    distances = {name: load_distances(path) for name, path in distance_paths.items()}
+
+    # The distance criteria.
+    rav_ribbon_distance = 80  # nm
+    mpv_pd_distance = 100  # nm
+    mpv_mem_distance = 50  # nm
+    docked_pd_distance = 100  # nm
+    docked_mem_distance = 2  # nm
+
+    rav_distances, seg_ids = distances["ribbon"][0], np.array(distances["ribbon"][-1])
+    rav_ids = seg_ids[rav_distances < rav_ribbon_distance]
+
+    pd_distances, mem_distances = distances["PD"][0], distances["membrane"][0]
+    assert len(pd_distances) == len(mem_distances) == len(rav_distances)
+
+    mpv_ids = seg_ids[np.logical_and(pd_distances < mpv_pd_distance, mem_distances < mpv_mem_distance)]
+    docked_ids = seg_ids[np.logical_and(pd_distances < docked_pd_distance, mem_distances < docked_mem_distance)]
+
+    # Create a dictionary to map vesicle ids to their corresponding pool.
+    # (RA-V gets overwritten by MP-V, which is correct.)
+ pool_assignments = {vid: "RA-V" for vid in rav_ids} + pool_assignments.update({vid: "MP-V" for vid in mpv_ids}) + pool_assignments.update({vid: "Docked-V" for vid in docked_ids}) + + pool_values = [pool_assignments.get(vid, None) for vid in seg_ids] + radii = _measure_radii(seg_path) + assert len(radii) == len(pool_values) + + pool_assignments = pd.DataFrame({ + "vesicle_id": seg_ids, + "pool": pool_values, + "radius": radii, + "diameter": 2 * radii, + }) + pool_assignments.to_csv(assignment_path, index=False) + + +def process_tomogram(mrc_path, force): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + + # Measure the distances. + measure_distances(mrc_path, seg_path, output_folder, force) + + # Assign the vesicle pools. + assign_vesicle_pools_and_measure_radii(seg_path, output_folder, force) + + # The surface area / volume for ribbon and PD will be done in a separate script. + + +def main(): + force = True + tomograms = get_all_tomograms(restrict_to_good_tomos=True, restrict_to_nachgeb=True) + for tomogram in tqdm(tomograms, desc="Process tomograms"): + process_tomogram(tomogram, force) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/postprocess_vesicles.py b/scripts/otoferlin/postprocess_vesicles.py new file mode 100644 index 0000000..370d38c --- /dev/null +++ b/scripts/otoferlin/postprocess_vesicles.py @@ -0,0 +1,73 @@ +import os +from pathlib import Path +from shutil import copyfile + +import imageio.v3 as imageio +import napari +import h5py + +from skimage.measure import label +from tqdm import tqdm + +from common import get_all_tomograms, get_seg_path +from synapse_net.file_utils import read_mrc +from automatic_processing import postprocess_vesicles + +TOMOS = [ + "Otof_TDAKO2blockC_GridE2_1_rec", + "Otof_TDAKO1blockA_GridN5_3_rec", + "Otof_TDAKO1blockA_GridN5_5_rec", + "Bl6_NtoTDAWT1_blockH_GridG2_3_rec", +] + + +def postprocess(mrc_path, process_center_crop): + output_path = get_seg_path(mrc_path) + copyfile(output_path, output_path + ".bkp") + postprocess_vesicles( + mrc_path, output_path, process_center_crop=process_center_crop, force=True + ) + + tomo, _ = read_mrc(mrc_path) + with h5py.File(output_path, "r") as f: + ves = f["segmentation/veiscles_postprocessed"][:] + + v = napari.Viewer() + v.add_image(tomo) + v.add_labels(ves) + napari.run() + + +# Postprocess vesicles in specific tomograms, where this initially +# failed due to wrong structure segmentations. 
diff --git a/scripts/otoferlin/postprocess_vesicles.py b/scripts/otoferlin/postprocess_vesicles.py
new file mode 100644
index 0000000..370d38c
--- /dev/null
+++ b/scripts/otoferlin/postprocess_vesicles.py
@@ -0,0 +1,73 @@
+import os
+from pathlib import Path
+from shutil import copyfile
+
+import imageio.v3 as imageio
+import napari
+import h5py
+
+from skimage.measure import label
+from tqdm import tqdm
+
+from common import get_all_tomograms, get_seg_path
+from synapse_net.file_utils import read_mrc
+from automatic_processing import postprocess_vesicles
+
+TOMOS = [
+    "Otof_TDAKO2blockC_GridE2_1_rec",
+    "Otof_TDAKO1blockA_GridN5_3_rec",
+    "Otof_TDAKO1blockA_GridN5_5_rec",
+    "Bl6_NtoTDAWT1_blockH_GridG2_3_rec",
+]
+
+
+def postprocess(mrc_path, process_center_crop):
+    output_path = get_seg_path(mrc_path)
+    copyfile(output_path, output_path + ".bkp")
+    postprocess_vesicles(
+        mrc_path, output_path, process_center_crop=process_center_crop, force=True
+    )
+
+    tomo, _ = read_mrc(mrc_path)
+    with h5py.File(output_path, "r") as f:
+        ves = f["segmentation/veiscles_postprocessed"][:]
+
+    v = napari.Viewer()
+    v.add_image(tomo)
+    v.add_labels(ves)
+    napari.run()
+
+
+# Postprocess vesicles in specific tomograms, where this initially
+# failed due to wrong structure segmentations.
+def redo_initial_postprocessing():
+    tomograms = get_all_tomograms()
+    for tomogram in tqdm(tomograms, desc="Process tomograms"):
+        fname = Path(tomogram).stem
+        if fname not in TOMOS:
+            continue
+        print("Postprocessing", fname)
+        postprocess(tomogram, process_center_crop=True)
+
+
+def label_all_vesicles():
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for mrc_path in tqdm(tomograms, desc="Process tomograms"):
+        output_path = get_seg_path(mrc_path)
+        output_folder = os.path.split(output_path)[0]
+        vesicle_path = os.path.join(output_folder, "correction", "veiscles_postprocessed.tif")
+        assert os.path.exists(vesicle_path), vesicle_path
+        copyfile(vesicle_path, vesicle_path + ".bkp")
+        vesicles = imageio.imread(vesicle_path)
+        vesicles = label(vesicles)
+        imageio.imwrite(vesicle_path, vesicles, compression="zlib")
+
+
+def main():
+    # redo_initial_postprocessing()
+    # Label all vesicle corrections to make sure every vesicle has its own id.
+    label_all_vesicles()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/update_radius_measurements.py b/scripts/otoferlin/update_radius_measurements.py
new file mode 100644
index 0000000..5c78b6e
--- /dev/null
+++ b/scripts/otoferlin/update_radius_measurements.py
@@ -0,0 +1,31 @@
+import os
+import pandas as pd
+from pool_assignments_and_measurements import _measure_radii
+
+from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, load_segmentations
+from tqdm import tqdm
+
+
+def update_radii(mrc_path):
+    seg_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(seg_path)[0]
+    assert os.path.exists(output_folder)
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+    radii = _measure_radii(seg_path)
+
+    pool_assignments = pd.read_csv(assignment_path)
+    assert len(radii) == len(pool_assignments)
+    pool_assignments["radius"] = radii
+    pool_assignments["diameter"] = 2 * radii
+
+    pool_assignments.to_csv(assignment_path, index=False)
+
+
+def main():
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for tomogram in tqdm(tomograms, desc="Process tomograms"):
+        update_radii(tomogram)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse_net/distance_measurements.py b/synapse_net/distance_measurements.py
index 4cf3181..ac074f7 100644
--- a/synapse_net/distance_measurements.py
+++ b/synapse_net/distance_measurements.py
@@ -1,5 +1,6 @@
 import os
 import multiprocessing as mp
+from itertools import product
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -226,6 +227,7 @@ def measure_segmentation_to_object_distances(
     resolution: Optional[Tuple[int, int, int]] = None,
     save_path: Optional[os.PathLike] = None,
     verbose: bool = False,
+    return_object_ids: bool = False,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Compute the distance between all objects in a segmentation and another object.
@@ -238,6 +240,7 @@
         resolution: The resolution / pixel size of the data.
         save_path: Path for saving the measurement results in numpy zipped format.
         verbose: Whether to print the progress of the distance computation.
+        return_object_ids: Whether to also return the object ids.
 
     Returns:
         The segmentation to object distances.
@@ -262,7 +265,10 @@ def measure_segmentation_to_object_distances(
         seg_ids=seg_ids,
         object_ids=object_ids,
     )
-    return distances, endpoints1, endpoints2, seg_ids
+    if return_object_ids:
+        return distances, endpoints1, endpoints2, seg_ids, object_ids
+    else:
+        return distances, endpoints1, endpoints2, seg_ids
 
 
 def _extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_duplicates=True):
@@ -292,12 +298,13 @@
 
 
 def load_distances(
-    measurement_path: os.PathLike
+    measurement_path: os.PathLike, return_object_ids: bool = False,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Load the saved distances from a zipped numpy file.
 
     Args:
         measurement_path: The path where the distances were saved.
+        return_object_ids: Whether to also return the object ids.
 
     Returns:
         The segmentation to object distances.
@@ -308,7 +315,11 @@
     auto_dists = np.load(measurement_path)
     distances, seg_ids = auto_dists["distances"], list(auto_dists["seg_ids"])
     endpoints1, endpoints2 = auto_dists["endpoints1"], auto_dists["endpoints2"]
-    return distances, endpoints1, endpoints2, seg_ids
+    if return_object_ids:
+        object_ids = auto_dists["object_ids"]
+        return distances, endpoints1, endpoints2, seg_ids, object_ids
+    else:
+        return distances, endpoints1, endpoints2, seg_ids
 
 
 def create_pairwise_distance_lines(
@@ -497,6 +508,46 @@
     return filtered_pairs
 
 
+# The current implementation only works for 3D data.
+def dilate_coordinates(
+    coords: Tuple[np.ndarray, np.ndarray, np.ndarray], dilation_radius: int, shape: Tuple[int, int, int]
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Expand coordinates similar to binary dilation without explicitly creating a mask.
+
+    Args:
+        coords: Coordinates to expand, given as a tuple of (z, row, col) index arrays,
+            e.g. as returned by skimage.draw.line_nd or np.nonzero.
+        dilation_radius: Radius of dilation (in pixels).
+        shape: Shape (depth, rows, cols) to clip the expanded coordinates.
+
+    Returns:
+        The expanded coordinates, again as a tuple of index arrays.
+    """
+    coords = np.vstack(coords).T
+
+    # Generate a cubic structuring element (neighborhood).
+    offset_range = np.arange(-dilation_radius, dilation_radius + 1)
+    offsets = np.array(list(product(offset_range, offset_range, offset_range)))
+
+    # Apply the offsets to all coordinates (broadcasted).
+    expanded = (coords[:, None, :] + offsets[None, :, :]).reshape(-1, 3)
+
+    # Clip the coordinates to the image shape.
+    valid_mask = (
+        (expanded[:, 0] >= 0) & (expanded[:, 0] < shape[0]) &
+        (expanded[:, 1] >= 0) & (expanded[:, 1] < shape[1]) &
+        (expanded[:, 2] >= 0) & (expanded[:, 2] < shape[2])
+    )
+    expanded = expanded[valid_mask]
+
+    # Remove duplicates.
+    expanded = np.unique(expanded, axis=0)
+
+    # Convert back to the tuple-of-arrays format.
+    expanded_coords = tuple(expanded.T)
+
+    return expanded_coords
+
+
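Since the new function operates on index tuples rather than masks, it can be exercised directly on the output of `line_nd`; a small sketch with synthetic shape and endpoints, assuming this patch is installed:

    import numpy as np
    from skimage.draw import line_nd
    from synapse_net.distance_measurements import dilate_coordinates

    shape = (8, 64, 64)
    line = line_nd((0, 5, 5), (0, 15, 20), endpoint=True)
    thick_line = dilate_coordinates(line, dilation_radius=2, shape=shape)
    # The result is again a tuple of (z, y, x) index arrays and can be used
    # for fancy indexing, e.g. segmentation[thick_line], without ever
    # allocating a full volume like the previous binary_dilation version did.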
""" distance_lines, properties = create_object_distance_lines( - distances, endpoints1, endpoints2, seg_ids, scale=scale + distances, endpoints1, endpoints2, seg_ids, filter_seg_ids=filter_seg_ids, scale=scale ) all_seg_ids = properties["id"] filtered_ids = [] - for seg_id, line in tqdm(zip(all_seg_ids, distance_lines), total=len(distance_lines), disable=not verbose): - if (seg_ids is not None) and (seg_id not in seg_ids): + for seg_id, line in tqdm( + zip(all_seg_ids, distance_lines), total=len(distance_lines), disable=not verbose, desc="Filter blocked objects" + ): + if (filter_seg_ids is not None) and (seg_id not in filter_seg_ids): continue start, stop = line line = line_nd(start, stop, endpoint=True) if line_dilation > 0: - # TODO make this more efficient, ideally by dilating the mask coordinates - # instead of dilating the actual mask. - # We turn the line into a binary mask and dilate it to have some tolerance. - line_vol = np.zeros_like(segmentation) - line_vol[line] = 1 - line_vol = binary_dilation(line_vol, iterations=line_dilation) - else: - line_vol = line + line = dilate_coordinates(line, line_dilation, segmentation.shape) # Check if we cross any other segments: # Extract the unique ids in the segmentation that overlap with the segmentation. # We count this as a direct distance if no other object overlaps with the line. - line_seg_ids = np.unique(segmentation[line_vol]) + line_seg_ids = np.unique(segmentation[line]) line_seg_ids = np.setdiff1d(line_seg_ids, [0, seg_id]) if len(line_seg_ids) == 0: # No other objet is overlapping, we keep the line. diff --git a/synapse_net/ground_truth/shape_refinement.py b/synapse_net/ground_truth/shape_refinement.py index 8c357ae..26e8e56 100644 --- a/synapse_net/ground_truth/shape_refinement.py +++ b/synapse_net/ground_truth/shape_refinement.py @@ -203,6 +203,7 @@ def refine_individual_vesicle_shapes( edge_map: np.ndarray, foreground_erosion: int = 4, background_erosion: int = 8, + compactness: float = 0.5, ) -> np.ndarray: """Refine vesicle shapes by fitting vesicles to a boundary map. @@ -215,6 +216,8 @@ def refine_individual_vesicle_shapes( You can use `edge_filter` to compute this based on the tomogram. foreground_erosion: By how many pixels the foreground should be eroded in the seeds. background_erosion: By how many pixels the background should be eroded in the seeds. + compactness: The compactness parameter passed to the watershed function. + Higher compactness leads to more regular sized vesicles. Returns: The refined vesicles. """ @@ -250,7 +253,7 @@ def fit_vesicle(prop): # Run seeded watershed to fit the shapes. seeds = fg_seed + 2 * bg_seed - seg[z] = watershed(hmap[z], seeds) == 1 + seg[z] = watershed(hmap[z], seeds, compactness=compactness) == 1 # import napari # v = napari.Viewer() diff --git a/synapse_net/imod/to_imod.py b/synapse_net/imod/to_imod.py index 5832213..99f2407 100644 --- a/synapse_net/imod/to_imod.py +++ b/synapse_net/imod/to_imod.py @@ -37,6 +37,7 @@ def write_segmentation_to_imod( segmentation: Union[str, np.ndarray], output_path: str, segmentation_key: Optional[str] = None, + color: Optional[Tuple[int, int, int]] = None, ) -> None: """Write a segmentation to a mod file as closed contour object(s). @@ -45,6 +46,7 @@ def write_segmentation_to_imod( segmentation: The segmentation (either as numpy array or filepath to a .tif file). output_path: The output path where the mod file will be saved. segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files. 
diff --git a/synapse_net/imod/to_imod.py b/synapse_net/imod/to_imod.py
index 5832213..99f2407 100644
--- a/synapse_net/imod/to_imod.py
+++ b/synapse_net/imod/to_imod.py
@@ -37,6 +37,7 @@ def write_segmentation_to_imod(
     segmentation: Union[str, np.ndarray],
     output_path: str,
     segmentation_key: Optional[str] = None,
+    color: Optional[Tuple[int, int, int]] = None,
 ) -> None:
     """Write a segmentation to a mod file as closed contour object(s).
@@ -45,6 +46,7 @@
         segmentation: The segmentation (either as numpy array or filepath to a .tif file).
         output_path: The output path where the mod file will be saved.
         segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files.
+        color: Optional color for the exported model.
     """
     cmd = "imodauto"
     cmd_path = shutil.which(cmd)
@@ -83,6 +85,10 @@
 
     # Run the command.
     cmd_list = [cmd, "-E", "1", "-u", tmp_path, output_path]
+    if color is not None:
+        assert len(color) == 3
+        r, g, b = [str(co) for co in color]
+        cmd_list += ["-co", f"{r} {g} {b}"]
     run(cmd_list)
@@ -172,6 +178,7 @@ def write_points_to_imod(
     min_radius: Union[float, int],
     output_path: str,
     color: Optional[Tuple[int, int, int]] = None,
+    name: Optional[str] = None,
 ) -> None:
     """Write point annotations to a .mod file for IMOD.
 
     Args:
@@ -182,6 +189,7 @@
         min_radius: Minimum radius for export.
         output_path: Where to save the .mod file.
         color: Optional color for writing out the points.
+        name: Optional name for the exported model.
     """
     cmd = "point2model"
     cmd_path = shutil.which(cmd)
@@ -210,6 +218,8 @@ def _pad(inp, n=3):
         assert len(color) == 3
         r, g, b = [str(co) for co in color]
         cmd += ["-co", f"{r} {g} {b}"]
+    if name is not None:
+        cmd += ["-name", name]
     run(cmd)
@@ -222,6 +232,8 @@ def write_segmentation_to_imod_as_points(
     radius_factor: float = 1.0,
     estimate_radius_2d: bool = True,
     segmentation_key: Optional[str] = None,
+    color: Optional[Tuple[int, int, int]] = None,
+    name: Optional[str] = None,
 ) -> None:
     """Write segmentation results to .mod file with imod point annotations.
 
     Args:
@@ -237,6 +249,8 @@
         the radius will be computed only in 2d rather than in 3d. This can lead to better results
         in case of deformation across the depth axis.
         segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files.
+        color: Optional color for writing out the points.
+        name: Optional name for the exported model.
     """
 
     # Read the resolution information from the mrcfile.
@@ -254,7 +268,7 @@
     )
 
     # Write the point annotations to imod.
-    write_points_to_imod(coordinates, radii, segmentation.shape, min_radius, output_path)
+    write_points_to_imod(coordinates, radii, segmentation.shape, min_radius, output_path, color=color, name=name)
 
 
 def _get_file_paths(input_path, ext=(".mrc", ".rec")):
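A usage sketch for the new export options (file paths are hypothetical, the leading positional arguments are assumed to follow the function's existing signature, and IMOD's point2model must be on the PATH):

    from synapse_net.imod.to_imod import write_segmentation_to_imod_as_points

    # Export the vesicle segmentation as red point annotations named "vesicles".
    write_segmentation_to_imod_as_points(
        "tomogram.mrc", "vesicles.tif", "vesicles.mod",
        min_radius=10, color=(255, 0, 0), name="vesicles",
    )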
diff --git a/synapse_net/tools/util.py b/synapse_net/tools/util.py
index 1495112..a2113c5 100644
--- a/synapse_net/tools/util.py
+++ b/synapse_net/tools/util.py
@@ -59,7 +59,28 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
     return model
 
 
-def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs):
+def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
+    from synapse_net.inference.postprocessing import (
+        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
+    )
+
+    ribbon = segment_ribbon(
+        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
+        max_vesicle_distance=40,
+    )
+    PD = segment_presynaptic_density(
+        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
+    )
+    ref_segmentation = PD if PD.sum() > 0 else ribbon
+    membrane = segment_membrane_distance_based(
+        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
+    )
+
+    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
+    return segmentations
+
+
+def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
     # Parse additional keyword arguments from the kwargs.
     vesicles = kwargs.pop("extra_segmentation")
     threshold = kwargs.pop("threshold", 0.5)
@@ -70,31 +91,21 @@
         image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
     )
 
-    # If the vesicles were passed then run additional post-processing.
+    # Otherwise, just return the predictions.
     if vesicles is None:
-        segmentation = predictions
+        if verbose:
+            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
+        segmentations = predictions
 
-    # Otherwise, just return the predictions.
+    # If the vesicles were passed then run additional post-processing.
     else:
-        from synapse_net.inference.postprocessing import (
-            segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
-        )
+        if verbose:
+            print("Vesicle segmentation was passed, WILL run post-processing.")
+        segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)
 
-        ribbon = segment_ribbon(
-            predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
-            max_vesicle_distance=40,
-        )
-        PD = segment_presynaptic_density(
-            predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
-        )
-        ref_segmentation = PD if PD.sum() > 0 else ribbon
-        membrane = segment_membrane_distance_based(
-            predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
-        )
-
-        segmentation = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
-
-    return segmentation
+    if return_predictions:
+        return segmentations, predictions
+    return segmentations
 
 
 def run_segmentation(