diff --git a/.gitignore b/.gitignore index e5a747b..203ef2a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,9 +7,14 @@ models/*/ *.egg-info/ run_sbatch.sbatch slurm/ +slurm_revision/ scripts/cooper/evaluation_results/ scripts/cooper/training/copy_testset.py scripts/rizzoli/upsample_data.py scripts/cooper/training/find_rec_testset.py synapse-net-models/ scripts/portal/upscale_tomo.py +analysis_results/ +scripts/cooper/revision/evaluation_results/ +scripts/cooper/revision/export_tif_to_h5.py +scripts/cooper/revision/copy_path.py \ No newline at end of file diff --git a/environment.yaml b/environment.yaml index e85fc3c..c4fd63b 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,7 +1,7 @@ channels: - conda-forge name: - synapse-net + synapse-net-cpu dependencies: - bioimageio.core - kornia diff --git a/models_az_thin/split-chemical_fixation.json b/models_az_thin/split-chemical_fixation.json new file mode 100644 index 0000000..c700fc7 --- /dev/null +++ b/models_az_thin/split-chemical_fixation.json @@ -0,0 +1 @@ +{"train": ["20180305_06_MS.h5", "20171113_01_MS.h5", "20171113_05_MS.h5", "20171113_04_MS.h5", "20180305_05_MS.h5", "20171113_3.2_MS.h5", "20180305_09_MS.h5", "20180305_07_MS.h5", "20180305_04_MS.h5", "20171006_2_2_MS.h5", "20171006_05_MS.h5", "20180305_10_MS.h5", "20171006_03_MS.h5", "20171006_2_3_MS.h5"], "val": ["20171013_01_MS.h5", "20171013_1.2_MS.h5", "20171006_3_2_MS.h5", "20180305_08_MS.h5", "20180305_02_MS.h5"], "test": ["20180305_03_MS.h5", "20171113_06_MS.h5", "20171113_07_MS.h5", "20171113_02_MS.h5", "20180305_01_MS.h5"]} \ No newline at end of file diff --git a/models_az_thin/split-endbulb_of_held.json b/models_az_thin/split-endbulb_of_held.json new file mode 100644 index 0000000..06e93b5 --- /dev/null +++ b/models_az_thin/split-endbulb_of_held.json @@ -0,0 +1 @@ +{"train": ["M2_eb7_model.h5", "Wt22_eb1_10K_model2.h5", "WT40_eb5_model.h5", "KO8_eb4_model.h5", "M3_eb4_model.h5", "WT20_eb11_model2.h5", "WT13_syn6_model2.h5", "WT20_eb7_AZ1_model2.h5", "M2_eb10_model.h5", "WT11_syn1_model2.h5", "M8_eb8_model.h5", "WT22_eb10_model2.h5", "WT13_eb3_model2.h5", "WT20_syn7_model2.h5", "M7_eb6_model.h5", "WT20_eb7_AZ2_model2.h5", "WT13_syn5_model2.h5", "WT13_syn9_model2.h5", "M2_eb3_model.h5", "WT20_syn3_model2.h5", "WT20_syn1_model2.h5", "WT21_syn4_model2.h5", "WT20_eb4_model2.h5", "M2_eb9_model.h5", "M7_eb3_model.h5", "WT21_eb5_model2.h5", "WT22_eb6_AZ1_model2.h5", "WT13_syn10_model2.h5", "WT19_syn6_model2.h5", "M2_eb2_AZ2_model.h5", "WT41_eb4_model.h5", "WT13_syn4_model2.h5", "WT40_eb10_model.h5", "M1_eb8_model.h5", "WT19_syn9_model2.h5", "WT22_eb5_model2.h5", "WT39_eb7_model.h5", "KO9_eb13_model.h5", "WT39_eb5_model.h5", "WT11_eb5_model2.h5", "M7_eb15_model.h5", "M7_eb2_model.h5", "M7_eb9_model.h5", "WT22_syn7_model2.h5", "M1_eb7_model.h5", "WT11_syn6_model2.h5", "M7_eb5_model.h5", "WT22_syn5_model2.h5", "WT21_eb3_model2.h5", "WT19_syn3_model2.h5", "WT22_syn9_model2.h5", "M5_eb3_model.h5", "WT22_syn6_model2.h5", "WT39_eb9_model.h5", "WT13_eb5_model2.h5", "WT20_eb9_model2.h5", "WT20_eb2_AZ2_12K_model2.h5", "M7_eb12_model.h5", "M1_eb1_model.h5", "WT40_eb1_model.h5", "M2_eb6_model.h5", "M8_eb14_model.h5", "KO9_eb1_model.h5", "WT20_eb1_AZ2_12K_model2.h5", "M7_eb7_model.h5", "M8_eb9_model.h5", "WT40_eb11_model.h5", "M1_eb3_model.h5", "M8_eb12_model.h5", "M7_eb11_model.h5", "KO9_eb10_model.h5", "KO9_eb4_model.h5", "KO8_eb2_model.h5", "WT39_eb4_model.h5", "M1_eb5_model.h5", "M10_eb9_model.h5"], "val": ["WT41_eb6_model.h5", "Wt22_syn2_10K_model2.h5", "M1_eb6_model.h5", 
"KO9_eb11_model.h5", "WT20_eb5_model2.h5", "WT16_syn2_model2.h5", "KO9_eb9_model.h5", "WT21_eb7_model2.h5", "M7_eb4_model.h5", "M1_eb9_model.h5", "WT40_eb8_model.h5", "M2_eb5_model.h5", "WT13_eb4_model2.h5", "WT39_eb10_model.h5", "WT39_eb8_model.h5", "WT20_syn6_model2.h5", "WT13_syn11_model2.h5", "WT41_eb2_model.h5", "WT39_eb2_model.h5", "M2_eb14_model.h5", "M2_eb8_model.h5", "WT20_syn5_model2.h5", "M10_eb12_model.h5", "M5_eb1_model.h5", "KO9_eb12_model.h5"], "test": ["WT11_eb2_model2.h5", "WT21_syn5_model2.h5", "WT20_eb8_AZ2_model2.h5", "M10_eb8_model.h5", "WT21_eb9_model2.h5", "M2_eb1_model.h5", "WT22_syn10_model2.h5", "WT11_syn3_model2.h5", "WT11_eb1_model2.h5", "WT13_syn7_model2.h5", "WT21_eb4_model2.h5", "WT40_eb9_model.h5", "M6_eb2_model.h5", "WT22_syn1_10K_model2.h5", "WT19_syn1_model2.h5", "M7_eb10_model.h5", "KO9_eb6_model.h5", "WT11_eb7_model2.h5", "WT40_eb3_model.h5", "KO9_eb14_model.h5", "WT20_syn2_model2.h5", "WT22_eb9_model2.h5", "WT13_syn1_model2.h5", "WT39_eb3_model.h5", "WT21_syn3_model2.h5", "M8_eb6_model.h5"]} \ No newline at end of file diff --git a/models_az_thin/split-endbulb_of_held_cropped.json b/models_az_thin/split-endbulb_of_held_cropped.json new file mode 100644 index 0000000..3aef9c5 --- /dev/null +++ b/models_az_thin/split-endbulb_of_held_cropped.json @@ -0,0 +1 @@ +{"train": ["Wt22_eb1_10K_model2_cropped.h5", "WT22_eb6_AZ1_model2_cropped.h5", "WT21_eb3_model2_cropped.h5", "M7_eb11_model_cropped.h5", "WT20_syn3_model2_cropped.h5", "WT13_syn7_model2_cropped.h5", "WT13_eb4_model2_cropped.h5", "WT39_eb2_model_cropped.h5", "WT20_syn1_model2_cropped.h5", "WT40_eb3_model_cropped.h5", "WT20_eb7_AZ1_model2_cropped.h5", "WT21_syn5_model2_cropped.h5", "WT13_syn4_model2_cropped.h5", "KO9_eb13_model_cropped.h5", "M10_eb9_model_cropped.h5", "WT20_syn5_model2_cropped.h5", "M10_eb8_model_cropped.h5", "M7_eb12_model_cropped.h5", "WT39_eb7_model_cropped.h5", "WT20_eb4_model2_cropped.h5", "M1_eb8_model_cropped.h5", "WT40_eb11_model_cropped.h5", "KO9_eb14_model_cropped.h5", "WT39_eb10_model_cropped.h5", "KO9_eb6_model_cropped.h5", "WT13_syn1_model2_cropped.h5", "WT13_syn9_model2_cropped.h5", "WT13_eb3_model2_cropped.h5", "WT41_eb4_model_cropped.h5", "WT40_eb5_model_cropped.h5", "WT11_eb5_model2_cropped.h5", "WT22_eb9_model2_cropped.h5", "M2_eb2_AZ2_model_cropped.h5", "WT41_eb6_model_cropped.h5", "WT13_eb5_model2_cropped.h5", "WT13_syn11_model2_cropped.h5", "WT22_syn5_model2_cropped.h5", "WT20_syn6_model2_cropped.h5", "WT22_syn9_model2_cropped.h5", "WT11_syn6_model2_cropped.h5", "M8_eb12_model_cropped.h5", "WT39_eb4_model_cropped.h5", "M8_eb8_model_cropped.h5", "WT21_eb9_model2_cropped.h5", "WT39_eb3_model_cropped.h5", "M2_eb1_model_cropped.h5", "M2_eb9_model_cropped.h5", "WT39_eb5_model_cropped.h5", "WT22_eb10_model2_cropped.h5", "M7_eb4_model_cropped.h5", "WT20_eb7_AZ2_model2_cropped.h5", "WT40_eb10_model_cropped.h5", "WT19_syn9_model2_cropped.h5", "WT22_syn6_model2_cropped.h5", "WT11_eb1_model2_cropped.h5", "M10_eb12_model_cropped.h5", "KO9_eb11_model_cropped.h5", "WT19_syn6_model2_cropped.h5", "M7_eb5_model_cropped.h5", "WT39_eb9_model_cropped.h5", "M2_eb14_model_cropped.h5", "Wt22_syn2_10K_model2_cropped.h5", "WT20_syn2_model2_cropped.h5", "M7_eb9_model_cropped.h5", "M5_eb3_model_cropped.h5", "WT22_syn1_10K_model2_cropped.h5", "M1_eb7_model_cropped.h5", "M1_eb6_model_cropped.h5", "M7_eb7_model_cropped.h5", "WT21_eb7_model2_cropped.h5", "M2_eb8_model_cropped.h5", "WT20_eb1_AZ2_12K_model2_cropped.h5", "WT20_eb9_model2_cropped.h5", "WT41_eb2_model_cropped.h5", 
"WT20_eb5_model2_cropped.h5", "KO9_eb12_model_cropped.h5", "M3_eb4_model_cropped.h5", "WT19_syn1_model2_cropped.h5", "M2_eb3_model_cropped.h5", "KO9_eb9_model_cropped.h5", "WT13_syn5_model2_cropped.h5", "M1_eb1_model_cropped.h5", "M2_eb5_model_cropped.h5", "WT20_eb11_model2_cropped.h5", "WT13_syn6_model2_cropped.h5", "KO9_eb10_model_cropped.h5", "M2_eb7_model_cropped.h5", "M1_eb5_model_cropped.h5", "WT13_syn10_model2_cropped.h5", "WT22_eb5_model2_cropped.h5", "KO8_eb2_model_cropped.h5", "M2_eb10_model_cropped.h5", "KO9_eb1_model_cropped.h5", "M7_eb10_model_cropped.h5", "WT21_syn3_model2_cropped.h5", "WT40_eb8_model_cropped.h5"], "val": ["WT20_eb2_AZ2_12K_model2_cropped.h5", "WT19_syn3_model2_cropped.h5", "M6_eb2_model_cropped.h5", "M8_eb9_model_cropped.h5", "KO8_eb4_model_cropped.h5", "WT21_eb4_model2_cropped.h5", "WT22_syn10_model2_cropped.h5", "WT20_syn7_model2_cropped.h5", "M7_eb3_model_cropped.h5", "M7_eb6_model_cropped.h5", "WT16_syn2_model2_cropped.h5", "WT39_eb8_model_cropped.h5", "KO9_eb4_model_cropped.h5", "M1_eb9_model_cropped.h5", "M2_eb6_model_cropped.h5", "WT11_syn1_model2_cropped.h5", "WT11_eb7_model2_cropped.h5", "M1_eb3_model_cropped.h5", "WT21_eb5_model2_cropped.h5", "WT11_syn3_model2_cropped.h5", "M8_eb6_model_cropped.h5", "WT21_syn4_model2_cropped.h5", "M7_eb15_model_cropped.h5", "WT22_syn7_model2_cropped.h5"]} \ No newline at end of file diff --git a/models_az_thin/split-stem.json b/models_az_thin/split-stem.json new file mode 100644 index 0000000..ecaedc3 --- /dev/null +++ b/models_az_thin/split-stem.json @@ -0,0 +1 @@ +{"train": ["36859_H3_SP_10_rec_2kb1dawbp_crop.h5", "36859_H3_SP_01_rec_2kb1dawbp_crop.h5", "36859_H2_SP_02_rec_2Kb1dawbp_crop.h5", "36859_H2_SP_03_rec_2Kb1dawbp_crop.h5"], "val": ["36859_H3_SP_05_rec_2kb1dawbp_crop.h5", "36859_H2_SP_01_rec_2Kb1dawbp_crop.h5"], "test": ["36859_H3_SP_07_rec_2kb1dawbp_crop.h5", "36859_J1_STEM750_66K_SP_03_rec_2kb1dawbp_crop.h5"]} \ No newline at end of file diff --git a/models_az_thin/split-stem_cropped.json b/models_az_thin/split-stem_cropped.json new file mode 100644 index 0000000..a9031ec --- /dev/null +++ b/models_az_thin/split-stem_cropped.json @@ -0,0 +1 @@ +{"train": ["36859_J1_66K_TS_CA3_PS_23_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_15_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_66K_TS_CA3_PS_43_rec_2Kb1dawbp_crop_crop1.h5", "36859_H2_SP_01_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_10_rec_2kb1dawbp_crop_crop1.h5", "36859_H3_SP_05_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_07_rec_2kb1dawbp_crop_crop2.h5", "36859_H2_SP_01_rec_2Kb1dawbp_crop_crop2.h5", "36859_J1_STEM750_66K_SP_08_rec_2kb1dawbp_crop_crop3.h5", "36859_H3_SP_07_rec_2kb1dawbp_crop_crop2.h5", "36859_H3_SP_07_rec_2kb1dawbp_crop_cropped_noAZ.h5", "36859_J1_STEM750_66K_SP_03_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_06_rec_2kb1dawbp_crop_crop1.h5", "36859_H2_SP_03_rec_2Kb1dawbp_crop_crop1.h5", "36859_H3_SP_07_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_02_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_17_rec_2kb1dawbp_crop_crop1.h5", "36859_H2_SP_11_rec_2Kb1dawbp_crop_crop1.h5", "36859_H2_SP_04_rec_2Kb1dawbp_crop_crop1.h5", "36859_H2_SP_04_rec_2Kb1dawbp_crop_crop2.h5", "36859_H3_SP_10_rec_2kb1dawbp_crop_cropped_noAZ.h5", "36859_H2_SP_06_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_14_rec_2kb1dawbp_crop_crop2.h5", "36859_H3_SP_01_rec_2kb1dawbp_crop_crop1.h5", "36859_H2_SP_01_rec_2Kb1dawbp_crop_crop3.h5", 
"36859_J1_STEM750_66K_SP_13_rec_2kb1dawbp_crop_crop2.h5", "36859_H2_SP_10_rec_crop_crop1.h5", "36859_H2_SP_02_rec_2Kb1dawbp_crop_cropped_noAZ.h5", "36859_J1_66K_TS_CA3_PS_32_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_01_rec_2kb1dawbp_crop_crop1.h5", "36859_H2_SP_03_rec_2Kb1dawbp_crop_cropped_noAZ.h5", "36859_H3_SP_10_rec_2kb1dawbp_crop_crop1.h5", "36859_H2_SP_02_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_08_rec_2kb1dawbp_crop_crop2.h5", "36859_J1_STEM750_66K_SP_03_rec_2kb1dawbp_crop_crop2.h5", "36859_H3_SP_05_rec_2kb1dawbp_crop_crop2.h5", "36859_H2_SP_07_rec_2Kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_13_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_14_rec_2kb1dawbp_crop_crop1.h5"], "val": ["36859_J1_STEM750_66K_SP_03_rec_2kb1dawbp_crop_crop3.h5", "36859_H2_SP_01_rec_2Kb1dawbp_crop_cropped_noAZ.h5", "36859_J1_STEM750_66K_SP_08_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_12_rec_2kb1dawbp_crop_crop1.h5", "36859_H3_SP_05_rec_2kb1dawbp_crop_cropped_noAZ.h5", "36859_H3_SP_09_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_11_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_07_rec_2kb1dawbp_crop_crop1.h5", "36859_J1_STEM750_66K_SP_12_rec_2kb1dawbp_crop_crop2.h5", "36859_J1_STEM750_66K_SP_10_rec_2kb1dawbp_crop_crop2.h5"]} \ No newline at end of file diff --git a/models_az_thin/split-tem.json b/models_az_thin/split-tem.json new file mode 100644 index 0000000..f5982cf --- /dev/null +++ b/models_az_thin/split-tem.json @@ -0,0 +1 @@ +{"train": ["20190805_09002_B4_SC_08_SP.h5", "WT_MF_DIV28_3.2_MS_09204_K1.h5", "20190524_09204_F4_SC_09_SP.h5", "20190524_09204_F4_SC_06_SP.h5", "WT_MF_DIV14_6.2_MS_E2_09175_CA3_2.h5", "WT_MF_DIV14_07_MS_C2_09175_CA3.h5", "WT_Unt_SC_09175_C4_04_DIV15_mtk_04.h5", "M13_CTRL_22723_O3_06_DIV29_06_MS.h5", "WT_Unt_SC_09175_E2_04_DIV14_mtk_03.h5", "WT_MF_DIV14_01_MS_E2_09175_CA3.h5", "M13_CTRL_22723_J1_03_DIV29_03_MS.h5", "20190807_23032_D4_SC_03_SP.h5", "M13_DKO_09201_U1_05_DIV31_05_MS.h5", "WT_Unt_SC_09175_E4_02_DIV14_mtk_02.h5", "20190805_09002_B4_SC_01_SP.h5", "WT_Unt_SC_09175_C4_08_DIV15_mtk_08.h5", "WT_MF_DIV14_01_MS_B2_09175_CA3.h5", "20190805_09002_B4_SC_10_SP.h5", "WT_Unt_SC_09175_D4_01_DIV14_mtk_01.h5", "WT_Unt_SC_09175_E2_03_DIV14_mtk_03.h5", "WT_Unt_SC_09175_C4_02_DIV15_mtk_02.h5", "20190805_09002_B4_SC_12_SP.h5", "20190805_09002_B4_SC_02_SP.h5", "M13_DKO_09201_U1_04_DIV31_04_MS.h5", "WT_MF_DIV14_02_MS_D2_09175_CA3.h5", "WT_Unt_SC_09175_E2_05_DIV14_mtk_05.h5", "20190807_23032_D4_SC_07_SP.h5", "M13_CTRL_09201_S2_03_DIV31_03_MS.h5", "WT_MF_DIV14_01.2_MS_D1_09175_CA3.h5", "20190807_23032_D4_SC_01_SP.h5", "WT_Unt_SC_09175_E4_03_DIV14_mtk_03.h5", "WT_MF_DIV28_3.3_MS_09204_K1.h5", "WT_Unt_SC_09175_D5_04_DIV14_mtk_04.h5", "20190524_09204_F4_SC_05_SP.h5", "M13_DKO_22723_A1_03_DIV29_03_MS.h5", "20190805_09002_B4_SC_7.2_SP.h5", "M13_DKO_22723_A4_08_DIV29_08_MS.h5", "WT_Unt_SC_09175_E4_05_DIV14_mtk_05.h5", "M13_CTRL_09201_S1_01_DIV31_01.h5", "WT_Unt_SC_09175_C4_03_DIV15_mtk_03.h5", "WT_MF_DIV28_03_MS_09204_M1.h5", "M13_DKO_09201_U1_5.2_DIV31_5.2_MS.h5", "WT_Unt_SC_09175_E2_01_DIV14_mtk_01.h5", "WT_MF_DIV28_2.3_MS_09002_B1.h5", "WT_MF_DIV28_01_MS_09204_F1.h5", "M13_CTRL_09201_S2_05_DIV31_05_MS.h5", "20190807_23032_D4_SC_04_SP.h5", "M13_DKO_09201_Q1_04_DIV31_04_MS.h5", "WT_MF_DIV28_04_MS_09204_M1.h5", "WT_Unt_SC_09175_D4_04_DIV14_mtk_04.h5", "WT_MF_DIV14_02_MS_B2_09175_CA3.h5", "20190807_23032_D4_SC_10_SP.h5", "M13_DKO_09201_U1_03_DIV31_03_MS.h5", "WT_MF_DIV14_01_MS_D1_09175_CA3.h5", "WT_MF_DIV28_05_MS_09204_F1.h5", 
"WT_Unt_SC_09175_D4_05_DIV14_mtk_05.h5", "WT_MF_DIV14_01_MS_D2_09175_CA3.h5", "WT_MF_DIV28_03_MS_09204_K1.h5", "M13_DKO_09201_O1_01_DIV31_01_MS.h5", "WT_MF_DIV14_04_MS_C2_09175_CA3.h5", "20190524_09204_F4_SC_04_SP.h5", "WT_MF_DIV14_02_MS_E1_09175_CA3.h5", "20190807_23032_D4_SC_09_SP.h5", "M13_CTRL_22723_O2_04_DIV29_04_MS.h5", "WT_MF_DIV28_03_MS_09204_F1.h5", "WT_MF_DIV28_01_MS_09002_B1.h5", "20190524_09204_F4_SC_03_SP.h5", "M13_DKO_22723_A1_4.2_DIV29_4.2_MS.h5", "M13_DKO_23037_K1_01_DIV29_01_MS.h5", "M13_DKO_23037_K1_1.2_DIV29_1.2_MS.h5", "20190524_09204_F4_SC_11_SP.h5", "WT_MF_DIV28_1.2_MS_09204_F1.h5", "M13_DKO_09201_O3_06_DIV31_06_MS.h5", "WT_Unt_SC_09175_C4_01_DIV15_mtk_01.h5", "WT_MF_DIV14_3.2_MS_D2_09175_CA3.h5", "WT_MF_DIV28_07_MS_09002_B2.h5", "WT_MF_DIV28_06_MS_09204_F1.h5", "M13_CTRL_09201_S2_04_DIV31_04_MS.h5", "WT_MF_DIV28_1.2_MS_09002_B1.h5", "WT_MF_DIV28_3.4_MS_09204_K1.h5", "20190524_09204_F4_SC_10_SP.h5", "20190805_09002_B4_SC_11_SP.h5", "20190524_09204_F4_SC_01_SP.h5", "WT_Unt_SC_09175_C4_05_DIV15_mtk_05.h5", "WT_MF_DIV28_04_MS_09002_B2.h5"], "val": ["WT_MF_DIV28_02_MS_09002_B1.h5", "M13_CTRL_09201_S2_06_DIV31_06_MS.h5", "20190805_09002_B4_SC_05_SP.h5", "WT_MF_DIV14_04_MS_B2_09175_CA3.h5", "M13_CTRL_22723_O3_07_DIV29_07_MS.h5", "WT_MF_DIV14_03.2_MS_D1_09175_CA3.h5", "WT_MF_DIV14_04_MS_E1_09175_CA3.h5", "20190524_09204_F4_SC_07_SP.h5", "WT_MF_DIV28_08_MS_09204_F2.h5", "WT_MF_DIV14_06_MS_C2_09175_CA3.h5", "WT_Unt_SC_09175_B5_03_DIV16_mtk_05.h5", "WT_MF_DIV28_2.2_MS_09002_B1.h5", "WT_Unt_SC_09175_D4_02_DIV14_mtk_02.h5", "WT_MF_DIV28_04_MS_09204_F1.h5", "20190805_09002_B4_SC_04_SP.h5", "M13_DKO_09201_Q1_01_DIV31_01_MS.h5", "WT_MF_DIV28_02_MS_09204_M1.h5", "M13_CTRL_22723_O2_05_DIV29_05_MS_.h5", "20190807_23032_D4_SC_08_SP.h5", "WT_Unt_SC_09175_D5_05_DIV14_mtk_05.h5", "WT_Unt_SC_09175_D5_02_DIV14_mtk_02.h5", "WT_MF_DIV14_05_MS_B2_09175_CA3.h5", "WT_Unt_SC_09175_D5_01_DIV14_mtk_01.h5", "20190524_09204_F4_SC_02_SP.h5", "WT_Unt_SC_09175_E4_04_DIV14_mtk_04.h5", "M13_DKO_09201_O3_6.2_DIV31_6.2_MS.h5", "M13_DKO_22723_A1_05_DIV29_05_MS.h5", "WT_MF_DIV14_04_MS_E2_09175_CA3_2.h5", "WT_MF_DIV14_03_MS_C2_09175_CA3.h5"], "test": ["WT_MF_DIV28_4.2_MS_09204_M1.h5", "M13_DKO_09201_Q1_03_DIV31_03_MS.h5", "WT_MF_DIV14_3.1_MS_D2_09175_CA3.h5", "WT_Unt_SC_09175_B5_03_DIV16_mtk_04.h5", "WT_MF_DIV28_10_MS_09002_B3.h5", "WT_MF_DIV14_06_MS_E2_09175_CA3_2.h5", "WT_MF_DIV14_03.3_MS_D1_09175_CA3.h5", "20190805_09002_B4_SC_09_SP.h5", "20190805_09002_B4_SC_7.1_SP.h5", "WT_MF_DIV14_03.1_MS_D1_09175_CA3.h5", "WT_Unt_SC_09175_E4_01_DIV14_mtk_01.h5", "WT_MF_DIV28_09_MS_09002_B3.h5", "WT_Unt_SC_09175_B5_01_DIV16_mtk_01.h5", "M13_DKO_22723_A1_06_DIV29_06_MS.h5", "WT_MF_DIV14_01_MS_orig_C2_09175_CA3.h5", "20190807_23032_D4_SC_05_SP.h5", "M13_DKO_22723_A4_10_DIV29_10_MS.h5", "WT_MF_DIV14_05_MS_C2_09175_CA3.h5", "WT_MF_DIV28_07_MS_09204_F2.h5", "WT_Unt_SC_09175_D5_03_DIV14_mtk_03.h5", "M13_CTRL_22723_O2_05_DIV29_5.2.h5", "WT_MF_DIV14_02_MS_C2_09175_CA3.h5", "WT_Unt_SC_09175_B5_03_DIV16_mtk_03.h5", "WT_MF_DIV14_03_MS_B2_09175_CA3.h5", "WT_Unt_SC_09175_B5_02_DIV16_mtk_02.h5", "WT_Unt_SC_09175_C4_06_DIV15_mtk_06.h5", "WT_MF_DIV14_05_MS_E2_09175_CA3_2.h5", "M13_CTRL_09201_S2_02_DIV31_02_MS.h5", "WT_MF_DIV28_08_MS_09002_B3.h5"]} \ No newline at end of file diff --git a/run_sbatch_revision.sbatch b/run_sbatch_revision.sbatch new file mode 100644 index 0000000..65fc41b --- /dev/null +++ b/run_sbatch_revision.sbatch @@ -0,0 +1,12 @@ +#! 
/bin/bash +#SBATCH -c 4 #4 #8 +#SBATCH --mem 256G #120G #32G #64G #256G +#SBATCH -p grete:shared #grete:shared #grete-h100:shared +#SBATCH -t 4:00:00 #6:00:00 #48:00:00 +#SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1 +#SBATCH --output=/user/muth9/u12095/synapse-net/slurm_revision/slurm-%j.out +#SBATCH -A nim00007 #SBATCH --constraint 80gb + +source ~/.bashrc +conda activate synapse-net +python scripts/cooper/revision/surface_dice.py -i /mnt/ceph-hdd/cold/nim00007/AZ_prediction_new/stem_for_eval/ -gt /mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_for_eval/ -v 7 \ No newline at end of file diff --git a/scripts/baselines/cryo_ves_net/evaluate_cooper.py b/scripts/baselines/cryo_ves_net/evaluate_cooper.py index ed123f4..71e1ff6 100644 --- a/scripts/baselines/cryo_ves_net/evaluate_cooper.py +++ b/scripts/baselines/cryo_ves_net/evaluate_cooper.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd from elf.evaluation.matching import matching +from elf.evaluation.dice import symmetric_best_dice_score from tqdm import tqdm INPUT_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/vesicles_processed_v2/testsets" # noqa @@ -25,11 +26,11 @@ ] -def evaluate_dataset(ds_name): +def evaluate_dataset(ds_name, force): result_folder = "./results/cooper" os.makedirs(result_folder, exist_ok=True) result_path = os.path.join(result_folder, f"{ds_name}.csv") - if os.path.exists(result_path): + if os.path.exists(result_path) and not force: results = pd.read_csv(result_path) return results @@ -44,6 +45,9 @@ def evaluate_dataset(ds_name): mask_key = None pred_files = sorted(glob(os.path.join(OUTPUT_ROOT, ds_name, "**/*.h5"), recursive=True)) + if ds_name == "04": + pred_names = [os.path.basename(path) for path in pred_files] + input_files = [path for path in input_files if os.path.basename(path) in pred_names] assert len(input_files) == len(pred_files), f"{len(input_files)}, {len(pred_files)}" results = { @@ -52,12 +56,13 @@ def evaluate_dataset(ds_name): "precision": [], "recall": [], "f1-score": [], + "sbd-score": [], } for inf, predf in tqdm(zip(input_files, pred_files), total=len(input_files), desc=f"Evaluate {ds_name}"): fname = os.path.basename(inf) sub_res_path = os.path.join(result_folder, f"{ds_name}_{fname}.json") - if os.path.exists(sub_res_path): + if os.path.exists(sub_res_path) and not force: print("Loading scores from", sub_res_path) with open(sub_res_path, "r") as f: scores = json.load(f) @@ -89,6 +94,8 @@ def evaluate_dataset(ds_name): gt[mask == 0] = 0 scores = matching(seg, gt) + sbd_score = symmetric_best_dice_score(seg, gt) + scores["sbd"] = sbd_score with open(sub_res_path, "w") as f: json.dump(scores, f) @@ -98,6 +105,7 @@ def evaluate_dataset(ds_name): results["precision"].append(scores["precision"]) results["recall"].append(scores["recall"]) results["f1-score"].append(scores["f1"]) + results["sbd-score"].append(scores["sbd"]) results = pd.DataFrame(results) results.to_csv(result_path, index=False) @@ -105,9 +113,11 @@ def evaluate_dataset(ds_name): def main(): + force = False + all_results = {} for ds in DATASETS: - result = evaluate_dataset(ds) + result = evaluate_dataset(ds, force=force) all_results[ds] = result groups = { @@ -123,16 +133,24 @@ def main(): } for name, datasets in groups.items(): - f1_scores = [] + f1_scores, sbd_scores = [], [] for ds in datasets: this_f1_scores = all_results[ds]["f1-score"].values.tolist() + this_sbd_scores = all_results[ds]["sbd-score"].values.tolist() f1_scores.extend(this_f1_scores) + 
sbd_scores.extend(this_sbd_scores) mean_f1 = np.mean(f1_scores) std_f1 = np.std(f1_scores) + print("F1-Score") print(name, ":", mean_f1, "+-", std_f1) + mean_sbd = np.mean(sbd_scores) + std_sbd = np.std(sbd_scores) + print("SBD-Score") + print(name, ":", mean_sbd, "+-", std_sbd) + if __name__ == "__main__": main() diff --git a/scripts/baselines/cryo_ves_net/evaluate_cryo.py b/scripts/baselines/cryo_ves_net/evaluate_cryo.py index 45da0d0..968c72b 100644 --- a/scripts/baselines/cryo_ves_net/evaluate_cryo.py +++ b/scripts/baselines/cryo_ves_net/evaluate_cryo.py @@ -4,17 +4,18 @@ import h5py import pandas as pd from elf.evaluation.matching import matching +from elf.evaluation.dice import symmetric_best_dice_score INPUT_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/fernandez-busnadiego/vesicle_gt/v3" # noqa OUTPUT_FOLDER = "./predictions/cryo" -def evaluate_dataset(ds_name="cryo"): +def evaluate_dataset(ds_name="cryo", force=False): result_folder = "./results/cryo" os.makedirs(result_folder, exist_ok=True) result_path = os.path.join(result_folder, f"{ds_name}.csv") - if os.path.exists(result_path): + if os.path.exists(result_path) and not force: results = pd.read_csv(result_path) return results @@ -28,6 +29,7 @@ def evaluate_dataset(ds_name="cryo"): "precision": [], "recall": [], "f1-score": [], + "sbd-score": [], } for inf, predf in zip(input_files, pred_files): fname = os.path.basename(inf) @@ -39,12 +41,15 @@ def evaluate_dataset(ds_name="cryo"): assert gt.shape == seg.shape scores = matching(seg, gt) + sbd_score = symmetric_best_dice_score(seg, gt) + scores["sbd"] = sbd_score results["dataset"].append(ds_name) results["file"].append(fname) results["precision"].append(scores["precision"]) results["recall"].append(scores["recall"]) results["f1-score"].append(scores["f1"]) + results["sbd-score"].append(scores["sbd"]) results = pd.DataFrame(results) results.to_csv(result_path, index=False) @@ -52,9 +57,13 @@ def evaluate_dataset(ds_name="cryo"): def main(): - result = evaluate_dataset() + force = False + result = evaluate_dataset(force=force) print(result) + print("F1-Score") print(result["f1-score"].mean()) + print("SBD-Score") + print(result["sbd-score"].mean()) if __name__ == "__main__": diff --git a/scripts/baselines/cryo_ves_net/evaluate_endbulb.py b/scripts/baselines/cryo_ves_net/evaluate_endbulb.py index c30e4b1..ad44eb3 100644 --- a/scripts/baselines/cryo_ves_net/evaluate_endbulb.py +++ b/scripts/baselines/cryo_ves_net/evaluate_endbulb.py @@ -4,17 +4,19 @@ import h5py import pandas as pd from elf.evaluation.matching import matching +from elf.evaluation.dice import symmetric_best_dice_score +from tqdm import tqdm INPUT_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/wichmann/extracted/endbulb_of_held/Automatische_Segmentierung_Dataset_Validierung" # noqa OUTPUT_FOLDER = "./predictions/endbulb" -def evaluate_dataset(ds_name="endbulb"): +def evaluate_dataset(ds_name="endbulb", force=False): result_folder = "./results/endbulb" os.makedirs(result_folder, exist_ok=True) result_path = os.path.join(result_folder, f"{ds_name}.csv") - if os.path.exists(result_path): + if os.path.exists(result_path) and not force: results = pd.read_csv(result_path) return results @@ -28,8 +30,9 @@ def evaluate_dataset(ds_name="endbulb"): "precision": [], "recall": [], "f1-score": [], + "sbd-score": [], } - for inf, predf in zip(input_files, pred_files): + for inf, predf in tqdm(zip(input_files, pred_files), total=len(input_files), desc="Run 
evaluation"): fname = os.path.basename(inf) with h5py.File(inf, "r") as f: @@ -39,12 +42,14 @@ def evaluate_dataset(ds_name="endbulb"): assert gt.shape == seg.shape scores = matching(seg, gt) + sbd_score = symmetric_best_dice_score(seg, gt) results["dataset"].append(ds_name) results["file"].append(fname) results["precision"].append(scores["precision"]) results["recall"].append(scores["recall"]) results["f1-score"].append(scores["f1"]) + results["sbd-score"].append(sbd_score) results = pd.DataFrame(results) results.to_csv(result_path, index=False) @@ -52,11 +57,14 @@ def evaluate_dataset(ds_name="endbulb"): def main(): - result = evaluate_dataset() + force = False + result = evaluate_dataset(force=force) print(result) print() - print(result["f1-score"].mean()) - print(result["f1-score"].std()) + print("F1-Score") + print(result["f1-score"].mean(), "+-", result["f1-score"].std()) + print("SBD-Score") + print(result["sbd-score"].mean(), "+-", result["sbd-score"].std()) if __name__ == "__main__": diff --git a/scripts/baselines/cryo_ves_net/evaluate_inner_ear.py b/scripts/baselines/cryo_ves_net/evaluate_inner_ear.py index 930cfc9..626ffad 100644 --- a/scripts/baselines/cryo_ves_net/evaluate_inner_ear.py +++ b/scripts/baselines/cryo_ves_net/evaluate_inner_ear.py @@ -4,17 +4,19 @@ import h5py import pandas as pd from elf.evaluation.matching import matching +from elf.evaluation.dice import symmetric_best_dice_score +from tqdm import tqdm INPUT_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/moser/vesicle_gt" # noqa OUTPUT_FOLDER = "./predictions/inner_ear" -def evaluate_dataset(ds_name="inner_ear"): +def evaluate_dataset(ds_name="inner_ear", force=False): result_folder = "./results/inner_ear" os.makedirs(result_folder, exist_ok=True) result_path = os.path.join(result_folder, f"{ds_name}.csv") - if os.path.exists(result_path): + if os.path.exists(result_path) and not force: results = pd.read_csv(result_path) return results @@ -28,8 +30,9 @@ def evaluate_dataset(ds_name="inner_ear"): "precision": [], "recall": [], "f1-score": [], + "sbd-score": [], } - for inf, predf in zip(input_files, pred_files): + for inf, predf in tqdm(zip(input_files, pred_files), total=len(input_files), desc="Run evaluation"): fname = os.path.basename(inf) with h5py.File(inf, "r") as f: @@ -39,12 +42,14 @@ def evaluate_dataset(ds_name="inner_ear"): assert gt.shape == seg.shape scores = matching(seg, gt) + sbd_score = symmetric_best_dice_score(seg, gt) results["dataset"].append(ds_name) results["file"].append(fname) results["precision"].append(scores["precision"]) results["recall"].append(scores["recall"]) results["f1-score"].append(scores["f1"]) + results["sbd-score"].append(sbd_score) results = pd.DataFrame(results) results.to_csv(result_path, index=False) @@ -52,11 +57,14 @@ def evaluate_dataset(ds_name="inner_ear"): def main(): - result = evaluate_dataset() + force = False + result = evaluate_dataset(force=force) print(result) print() - print(result["f1-score"].mean()) - print(result["f1-score"].std()) + print("F1-Score") + print(result["f1-score"].mean(), "+-", result["f1-score"].std()) + print("SBD-Score") + print(result["sbd-score"].mean(), "+-", result["sbd-score"].std()) if __name__ == "__main__": diff --git a/scripts/cooper/revision/README.md b/scripts/cooper/revision/README.md new file mode 100644 index 0000000..aec9a0b --- /dev/null +++ b/scripts/cooper/revision/README.md @@ -0,0 +1,19 @@ +# Improving the AZ model + +Scripts for improving the AZ annotations, training the AZ model, and 
evaluating it. + +The most important scripts are: +- For improving and updating the AZ annotations: + - `prediction.py`: Run prediction of vesicle and boundary model. + - `thin_az_gt.py`: Thin the AZ annotations, so that it aligns only with the presynaptic membrane. This is done by intersecting the annotations with the presynaptic compartment, using predictions from the network used for compartment segmentation. + - `assort_new_az_data.py`: Create a new version of the annotation, renaming the dataset, and creating a cropped version of the endbulb of held data. + - `merge_az.py`: Merge AZ annotations with predictions from model v4, in order to remove some artifacts that resulted from AZ thinning. +- For evaluating the AZ predictions: + - `az_prediction.py`: Run prediction with the AZ model. + - `run_az_evaluation.py`: Evaluate the predictions of an AZ model. + - `evaluate_result.py`: Summarize the evaluation results. +- And for training: `train_az_gt.py`. So far, I have trained: + - v3: Trained on the initial annotations. + - v4: Trained on the thinned annotations. + - v5: Trained on the thinned annotations with an additional distance loss (did not help). + - v6: Trained on the merged annotations. diff --git a/scripts/cooper/revision/assort_new_az_data.py b/scripts/cooper/revision/assort_new_az_data.py new file mode 100644 index 0000000..5edf06e --- /dev/null +++ b/scripts/cooper/revision/assort_new_az_data.py @@ -0,0 +1,359 @@ +import os +from glob import glob + +import h5py +import tifffile +import numpy as np +from tqdm import tqdm +from skimage.transform import resize +from skimage.measure import label +from scipy.ndimage import binary_closing + +ROOT = "/mnt/ceph-hdd/cold/nim00007/AZ_data/training_data" +INTER_ROOT = "/mnt/ceph-hdd/cold/nim00007/AZ_predictions" +OUTPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data" +STEM_INPUT="/mnt/lustre-emmy-hdd/usr/u12095/synaptic_reconstruction/for_revison/postprocessed_AZ" +TIF_INPUT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem/" + + +def _check_data(files, label_folder, check_thinned): + for ff in files: + with h5py.File(ff, "r") as f: + shape = f["raw"].shape + az = f["labels/az"][:] + n_az = az.max() + + if check_thinned: + label_file = os.path.join(label_folder, os.path.basename(ff)) + with h5py.File(label_file, "r") as f: + az_thin = f["labels/az_thin2"][:] + n_az_thin = az_thin.max() + else: + n_az_thin = None + + print(os.path.basename(ff), ":", shape, ":", n_az, ":", n_az_thin) + + +def assort_tem(): + old_name = "01data_withoutInvertedFiles_minusSVseg_corrected" + new_name = "tem" + + raw_folder = os.path.join(ROOT, old_name) + label_folder = os.path.join(INTER_ROOT, old_name) + output_folder = os.path.join(OUTPUT_ROOT, new_name) + os.makedirs(output_folder, exist_ok=True) + + files = glob(os.path.join(raw_folder, "*.h5")) + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + raw = f["raw"][:] + az = f["labels/az"][:] + + label_path = os.path.join(label_folder, os.path.basename(ff)) + with h5py.File(label_path, "r") as f: + az_thin = f["labels/az_thin2"][:] + + z_range1 = np.where(az != 0)[0] + z_range2 = np.where(az != 0)[0] + z_range = slice( + np.min(np.concatenate([z_range1, z_range2])), + np.max(np.concatenate([z_range1, z_range2])) + 1, + ) + raw, az, az_thin = raw[z_range], az[z_range], az_thin[z_range] + + out_path = os.path.join(output_folder, os.path.basename(ff)) + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="lzf") + f.create_dataset("labels/az_thin", data=az_thin, 
compression="lzf") + f.create_dataset("labels/az", data=az, compression="lzf") + + +def assort_chemical_fixation(): + old_name = "12_chemical_fix_cryopreparation_minusSVseg_corrected" + new_name = "chemical_fixation" + + raw_folder = os.path.join(ROOT, old_name) + label_folder = os.path.join(INTER_ROOT, old_name) + output_folder = os.path.join(OUTPUT_ROOT, new_name) + os.makedirs(output_folder, exist_ok=True) + + label_key = "labels/az_thin2" + + files = glob(os.path.join(raw_folder, "*.h5")) + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + raw = f["raw"][:] + az = f["labels/az"][:] + + label_path = os.path.join(label_folder, os.path.basename(ff)) + with h5py.File(label_path, "r") as f: + az_thin = f[label_key][:] + + z_range1 = np.where(az != 0)[0] + z_range2 = np.where(az != 0)[0] + z_range = slice( + np.min(np.concatenate([z_range1, z_range2])), + np.max(np.concatenate([z_range1, z_range2])) + 1, + ) + raw, az, az_thin = raw[z_range], az[z_range], az_thin[z_range] + + out_path = os.path.join(output_folder, os.path.basename(ff)) + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="lzf") + f.create_dataset("labels/az_thin", data=az_thin, compression="lzf") + f.create_dataset("labels/az", data=az, compression="lzf") + + +def assort_stem(): + old_names = [ + "04_hoi_stem_examples_fidi_and_sarah_corrected", + "04_hoi_stem_examples_minusSVseg_cropped_corrected", + "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected", + ] + new_names = ["stem", "stem_cropped", "stem_cropped"] + for old_name, new_name in zip(old_names, new_names): + print(old_name) + raw_folder = os.path.join(ROOT, f"{old_name}_rescaled_tomograms") + label_folder = os.path.join(INTER_ROOT, old_name) + files = glob(os.path.join(raw_folder, "*.h5")) + + # _check_data(files, label_folder, check_thinned=True) + # continue + + output_folder = os.path.join(OUTPUT_ROOT, new_name) + os.makedirs(output_folder, exist_ok=True) + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + raw = f["raw"][:] + az = f["labels/az"][:] + + label_path = os.path.join(label_folder, os.path.basename(ff)) + with h5py.File(label_path, "r") as f: + az_thin = f["labels/az_thin2"][:] + az_thin = resize(az_thin, az.shape, order=0, anti_aliasing=False, preserve_range=True).astype(az_thin.dtype) + assert az_thin.shape == az.shape + + out_path = os.path.join(output_folder, os.path.basename(ff)) + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="lzf") + f.create_dataset("labels/az_thin", data=az_thin, compression="lzf") + f.create_dataset("labels/az", data=az, compression="lzf") + + +def assort_wichmann(): + old_name = "wichmann_withAZ_rescaled_tomograms" + new_name = "endbulb_of_held" + + raw_folder = os.path.join(ROOT, old_name) + output_folder = os.path.join(OUTPUT_ROOT, new_name) + os.makedirs(output_folder, exist_ok=True) + + files = glob(os.path.join(raw_folder, "*.h5")) + + output_folder = os.path.join(OUTPUT_ROOT, new_name) + os.makedirs(output_folder, exist_ok=True) + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + raw = f["raw"][:] + az = f["labels/az"][:] + + output_file = os.path.join(output_folder, os.path.basename(ff)) + with h5py.File(output_file, "a") as f: + f.create_dataset("raw", data=raw, compression="lzf") + f.create_dataset("labels/az", data=az, compression="lzf") + f.create_dataset("labels/az_thin", data=az, compression="lzf") + + +def crop_wichmann(): + input_name = "endbulb_of_held" + output_name = "endbulb_of_held_cropped" + + input_folder = 
os.path.join(OUTPUT_ROOT, input_name) + output_folder = os.path.join(OUTPUT_ROOT, output_name) + os.makedirs(output_folder, exist_ok=True) + files = glob(os.path.join(input_folder, "*.h5")) + + min_shape = (32, 512, 512) + + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + az = f["labels/az"][:] + bb = np.where(az != 0) + bb = tuple(slice(int(b.min()), int(b.max()) + 1) for b in bb) + pad_width = [max(sh - (b.stop - b.start), 0) // 2 for b, sh in zip(bb, min_shape)] + bb = tuple( + slice(max(b.start - pw, 0), min(b.stop + pw, sh)) for b, pw, sh in zip(bb, pad_width, az.shape) + ) + az = az[bb] + raw = f["raw"][bb] + + # import napari + # v = napari.Viewer() + # v.add_image(raw) + # v.add_labels(az) + # v.add_labels(az_thin) + # napari.run() + + output_path = os.path.join(output_folder, os.path.basename(ff).replace(".h5", "_cropped.h5")) + with h5py.File(output_path, "a") as f: + f.create_dataset("raw", data=raw, compression="lzf") + f.create_dataset("labels/az", data=az, compression="lzf") + f.create_dataset("labels/az_thin", data=az, compression="lzf") + +def crop_stem(): + #forgot about 06, added later + input_name = "06_hoi_wt_stem750_fm_minusSVseg"#"04_hoi_stem_examples_minusSVseg" + output_name = "stem_cropped2" + + input_folder = os.path.join(STEM_INPUT, input_name) + output_folder = os.path.join(OUTPUT_ROOT, output_name) + os.makedirs(output_folder, exist_ok=True) + files = glob(os.path.join(input_folder, "*.h5")) + + min_shape = (32, 512, 512) + + for ff in tqdm(files): + with h5py.File(ff, "r") as f: + az = f["labels/az"][:] + raw_full = f["raw"][:] + + # Label connected components in the az volume + labeled = label(az) + num, sizes = np.unique(labeled, return_counts=True) + #print(f"num {num}, sizes {sizes}") + num, sizes = num[1:], sizes[1:] + + #exclude artifacts and background + keep_labels = num[(sizes > 2000) & (num != 0)] + #print(f"keep_labels {keep_labels}") + + #Clean up az annotations + az = np.isin(labeled, keep_labels).astype("uint8") + # Apply binary closing. 
+ az = np.logical_or(az, binary_closing(az, iterations=4)).astype("uint8") + + crop_id = 1 + for l in keep_labels: + + output_path = os.path.join(output_folder, os.path.basename(ff).replace(".h5", f"_crop{crop_id}.h5")) + if os.path.exists(output_path): + print(f"Skipping existing file: {output_path}") + crop_id += 1 + continue + + + mask = labeled == l + bb = np.where(mask) + if not bb[0].size: + continue + bb = tuple(slice(int(b.min()), int(b.max()) + 1) for b in bb) + pad_width = [max(sh - (b.stop - b.start), 0) // 2 for b, sh in zip(bb, min_shape)] + bb = tuple( + slice(max(b.start - pw, 0), min(b.stop + pw, sh)) for b, pw, sh in zip(bb, pad_width, az.shape) + ) + az_crop = az[bb] + raw_crop = raw_full[bb] + + + import napari + v = napari.Viewer() + v.add_image(raw_crop) + v.add_labels(az_crop) + napari.run() + + with h5py.File(output_path, "a") as f: + f.create_dataset("raw", data=raw_crop, compression="lzf") + f.create_dataset("labels/az", data=az_crop, compression="lzf") + crop_id += 1 + +def get_bounding_box_3d(file_path, raw_volume): + volume = tifffile.imread(file_path) + filename = os.path.basename(file_path) + print(f"filename {filename}") + + # Find the z index where the 2D rectangle is located (non-zero slice) + z_indices = np.where(np.any(volume, axis=(1, 2)))[0] + + if len(z_indices) == 0: + raise ValueError("No non-zero 2D rectangle found in the volume.") + + z_rect = z_indices[0] + + # Get the 2D mask from that slice + mask_2d = volume[z_rect] + y_indices, x_indices = np.where(mask_2d) + + if len(x_indices) == 0 or len(y_indices) == 0: + raise ValueError("Found slice has no non-zero pixels.") + + x_min, x_max = x_indices.min(), x_indices.max() + 1 + y_min, y_max = y_indices.min(), y_indices.max() + 1 + + # Determine z_start and z_end based on filename + if filename.endswith("_toend.tif"): + z_start, z_end = z_rect, raw_volume.shape[0] + elif filename.endswith("_tostart.tif"): + z_start, z_end = 0, z_rect + 1 + else: + print("here?") + z_start, z_end = z_rect, z_rect + 1 + + # Return bounding box as slices, usable directly for numpy indexing + return ( + slice(z_start, z_end), + slice(y_min, y_max), + slice(x_min, x_max) + ) + +def neg_crop_stem(): + input_name = "mask_for_neg_example"#"04_hoi_stem_examples_minusSVseg" + output_name = "stem_cropped2" + + input_folder = TIF_INPUT + tif_input_folder = os.path.join(TIF_INPUT, input_name) + output_folder = os.path.join(OUTPUT_ROOT, output_name) + os.makedirs(output_folder, exist_ok=True) + tif_files = glob(os.path.join(tif_input_folder, "*.tif")) + print(f"tif_files {tif_files}") + + for ff in tqdm(tif_files): + input_path = os.path.join(input_folder, os.path.basename(ff).replace('_tostart.tif', '.h5').replace('_toend.tif', '.h5')) + with h5py.File(input_path, "r") as f: + raw_full = f["raw"][:] + + + output_path = os.path.join(output_folder, os.path.basename(ff).replace('_tostart.tif', '_cropped_noAZ.h5').replace('_toend.tif', '_cropped_noAZ.h5')) + if os.path.exists(output_path): + print(f"Skipping existing file: {output_path}") + continue + + + bb = get_bounding_box_3d(ff, raw_full) + print(f"bb {bb}") + + raw_crop = raw_full[bb] + + + import napari + v = napari.Viewer() + v.add_image(raw_crop) + napari.run() + + with h5py.File(output_path, "a") as f: + f.create_dataset("raw", data=raw_crop, compression="lzf") + +def main(): + # assort_tem() + # assort_chemical_fixation() + + # assort_stem() + + # assort_wichmann() + #crop_wichmann() + + #crop_stem() + neg_crop_stem() + + +if __name__ == "__main__": + main() diff --git 
a/scripts/cooper/revision/az_prediction.py b/scripts/cooper/revision/az_prediction.py
new file mode 100644
index 0000000..4fca971
--- /dev/null
+++ b/scripts/cooper/revision/az_prediction.py
@@ -0,0 +1,95 @@
+import argparse
+import os
+from glob import glob
+
+import h5py
+from synapse_net.inference.active_zone import segment_active_zone
+from torch_em.util import load_model
+from tqdm import tqdm
+
+from common import get_file_names, get_split_folder, ALL_NAMES, INPUT_ROOT, OUTPUT_ROOT
+
+
+def run_prediction(model, name, split_folder, version, split_names, in_path):
+    if in_path:
+        file_paths = glob(os.path.join(in_path, name, "*.h5"))
+        file_names = [os.path.basename(path) for path in file_paths]
+    else:
+        file_names = get_file_names(name, split_folder, split_names=split_names)
+
+    output_folder = os.path.join(OUTPUT_ROOT, name)
+    os.makedirs(output_folder, exist_ok=True)
+    output_key = f"predictions/az/v{version}"
+    output_key_seg = f"predictions/az/seg_v{version}"
+
+    for fname in tqdm(file_names):
+        if in_path:
+            input_path = os.path.join(in_path, name, fname)
+        else:
+            input_path = os.path.join(INPUT_ROOT, name, fname)
+        print(f"segmenting {input_path}")
+
+        output_path = os.path.join(output_folder, fname)
+
+        if os.path.exists(output_path):
+            with h5py.File(output_path, "r") as f:
+                if output_key in f and output_key_seg in f:
+                    print(f"skipping, because {output_key} and {output_key_seg} already exist in {output_path}")
+                    continue
+
+        with h5py.File(input_path, "r") as f:
+            raw = f["raw"][:]
+
+        seg, pred = segment_active_zone(raw, model=model, verbose=False, return_predictions=True)
+        with h5py.File(output_path, "a") as f:
+            if output_key in f:
+                print(f"{output_key} already saved")
+            else:
+                f.create_dataset(output_key, data=pred, compression="lzf")
+            if output_key_seg in f:
+                print(f"{output_key_seg} already saved")
+            else:
+                f.create_dataset(output_key_seg, data=seg, compression="lzf")
+
+
+
+def get_model(version):
+    assert version in (3, 4, 5, 6, 7)
+    split_folder = get_split_folder(version)
+    if version == 3:
+        model_path = os.path.join(split_folder, "checkpoints", "3D-AZ-model-TEM_STEM_ChemFix_wichmann-v3")
+    elif version == 6:
+        model_path = "/mnt/ceph-hdd/cold/nim00007/models/AZ/v6/"
+    elif version == 7:
+        model_path = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ/checkpoints/v7/"
+    else:
+        model_path = os.path.join(split_folder, "checkpoints", f"v{version}")
+    model = load_model(model_path)
+    return model
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--version", "-v", type=int)
+    parser.add_argument("--names", nargs="+", default=ALL_NAMES)
+    parser.add_argument("--splits", nargs="+", default=["test"])
+    parser.add_argument("--model_path", default=None)
+    parser.add_argument("--input", "-i", default=None)
+
+    args = parser.parse_args()
+
+    if args.model_path:
+        model = load_model(args.model_path)
+    else:
+        model = get_model(args.version)
+
+    split_folder = get_split_folder(args.version)
+
+    for name in args.names:
+        run_prediction(model, name, split_folder, args.version, args.splits, args.input)
+
+    print("Finished segmenting!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/cooper/revision/check_prediction.py b/scripts/cooper/revision/check_prediction.py
new file mode 100644
index 0000000..04c1175
--- /dev/null
+++ b/scripts/cooper/revision/check_prediction.py
@@ -0,0 +1,47 @@
+import argparse
+import os
+
+import h5py
+import napari
+from common import ALL_NAMES, get_file_names, get_split_folder, get_paths
+
+
+def check_predictions(name, split, version): + split_folder = get_split_folder(version) + file_names = get_file_names(name, split_folder, split_names=[split]) + seg_paths, gt_paths = get_paths(name, file_names) + + for seg_path, gt_path in zip(seg_paths, gt_paths): + + with h5py.File(gt_path, "r") as f: + raw = f["raw"][:] + gt = f["labels/az"][:] if version == 3 else f["labels/az_thin"][:] + + with h5py.File(seg_path) as f: + pred_key = f"predictions/az/v{version}" + seg_key = f"predictions/az/seg_v{version}" + pred = f[pred_key][:] + seg = f[seg_key][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_image(pred, blending="additive") + v.add_labels(gt) + v.add_labels(seg) + v.title = f"{name}/{os.path.basename(seg_path)}" + napari.run() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--version", "-v", type=int, required=True) + parser.add_argument("--split", default="test") + parser.add_argument("--names", nargs="+", default=ALL_NAMES) + args = parser.parse_args() + + for name in args.names: + check_predictions(name, args.split, args.version) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/check_training_data.py b/scripts/cooper/revision/check_training_data.py new file mode 100644 index 0000000..8534cca --- /dev/null +++ b/scripts/cooper/revision/check_training_data.py @@ -0,0 +1,37 @@ +import argparse +import os +from glob import glob + +import napari +import h5py + +ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data" +all_names = [ + "chemical_fixation", + "tem", + "stem", + "stem_cropped", + "endbulb_of_held", + "endbulb_of_held_cropped", +] + + +parser = argparse.ArgumentParser() +parser.add_argument("-n", "--names", nargs="+", default=all_names) +args = parser.parse_args() +names = args.names + + +for ds in names: + paths = glob(os.path.join(ROOT, ds, "*.h5")) + for p in paths: + with h5py.File(p, "r") as f: + raw = f["raw"][:] + az = f["labels/az"][:] + az_thin = f["labels/az_thin"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(az) + v.add_labels(az_thin) + v.title = os.path.basename(p) + napari.run() diff --git a/scripts/cooper/revision/common.py b/scripts/cooper/revision/common.py new file mode 100644 index 0000000..603a73d --- /dev/null +++ b/scripts/cooper/revision/common.py @@ -0,0 +1,71 @@ +import json +import os + + +# The root folder which contains the new AZ training data. +INPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data" +# The output folder for AZ predictions. +OUTPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/AZ_prediction_new" + +# The names of all datasets for which to run prediction / evaluation. +# This excludes 'endbulb_of_held_cropped', which is a duplicate of 'endbulb_of_held', +# which we don't evaluate on because of this. +ALL_NAMES = [ + "chemical_fixation", "endbulb_of_held", "stem", "tem" +] + +# The translation of new dataset names to old dataset names. +NAME_TRANSLATION = { + "chemical_fixation": ["12_chemical_fix_cryopreparation_minusSVseg_corrected"], + "endbulb_of_held": ["wichmann_withAZ_rescaled_tomograms"], + "stem": ["04_hoi_stem_examples_fidi_and_sarah_corrected_rescaled_tomograms"], + "stem_cropped": ["04_hoi_stem_examples_minusSVseg_cropped_corrected_rescaled_tomograms", + "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected_rescaled_tomograms"], + "tem": ["01data_withoutInvertedFiles_minusSVseg_corrected"], +} + + +# Get the paths to the files with raw data / ground-truth and the segmentation. 
+def get_paths(name, file_names, skip_seg=False): + seg_paths, gt_paths = [], [] + for fname in file_names: + if not skip_seg: + seg_path = os.path.join(OUTPUT_ROOT, name, fname) + assert os.path.exists(seg_path), seg_path + seg_paths.append(seg_path) + + gt_path = os.path.join(INPUT_ROOT, name, fname) + assert os.path.exists(gt_path), gt_path + gt_paths.append(gt_path) + + return seg_paths, gt_paths + + +def get_file_names(name, split_folder, split_names): + split_path = os.path.join(split_folder, f"split-{name}.json") + if os.path.exists(split_path): + with open(split_path) as f: + splits = json.load(f) + file_names = [fname for split in split_names for fname in splits[split]] + + else: + old_names = NAME_TRANSLATION[name] + file_names = [] + for old_name in old_names: + split_path = os.path.join(split_folder, f"split-{old_name}.json") + with open(split_path) as f: + splits = json.load(f) + this_file_names = [fname for split in split_names for fname in splits[split]] + file_names.extend(this_file_names) + return file_names + + +def get_split_folder(version): + assert version in (3, 4, 5, 6, 7) + if version == 3: + split_folder = "splits" + elif version == 6: + split_folder= "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/splits" + else: + split_folder = "models_az_thin" + return split_folder diff --git a/scripts/cooper/revision/evaluate_result.py b/scripts/cooper/revision/evaluate_result.py new file mode 100644 index 0000000..a7627ba --- /dev/null +++ b/scripts/cooper/revision/evaluate_result.py @@ -0,0 +1,49 @@ +import argparse +import pandas as pd + +parser = argparse.ArgumentParser() +parser.add_argument("result_path") +args = parser.parse_args() + +results = pd.read_excel(args.result_path) +print(results) + + +def summarize_results(res): + print("Dice-Score:", res["dice"].mean(), "+-", res["dice"].std()) + tp, fp, fn = float(res["tp"].sum()), float(res["fp"].sum()), float(res["fn"].sum()) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1_score = 2 * tp / (2 * tp + fn + fp) + print("Precision:", precision) + print("Recall:", recall) + print("F1-Score:", f1_score) + + +# # Compute the results for Chemical Fixation. +results_chem_fix = results[results.dataset == "chemical_fixation"] +if results_chem_fix.size > 0: + print("Chemical Fixation Results:") + summarize_results(results_chem_fix) +# +# # Compute the results for STEM (=04). +results_stem = results[results.dataset.str.startswith("stem")] +if results_stem.size > 0: + print() + print("STEM Results:") + summarize_results(results_stem) +# +# # Compute the results for TEM (=01). +results_tem = results[results.dataset == "tem"] +if results_tem.size > 0: + print() + print("TEM Results:") + summarize_results(results_tem) + +# +# Compute the results for Wichmann / endbulb of held. 
+results_wichmann = results[results.dataset.str.startswith("endbulb")] +if results_wichmann.size > 0: + print() + print("Endbulb of Held Results:") + summarize_results(results_wichmann) diff --git a/scripts/cooper/revision/evaluation_results/v7.xlsx b/scripts/cooper/revision/evaluation_results/v7.xlsx new file mode 100644 index 0000000..db083d1 Binary files /dev/null and b/scripts/cooper/revision/evaluation_results/v7.xlsx differ diff --git a/scripts/cooper/revision/fix_az.py b/scripts/cooper/revision/fix_az.py new file mode 100644 index 0000000..72ca768 --- /dev/null +++ b/scripts/cooper/revision/fix_az.py @@ -0,0 +1,17 @@ +import os +from glob import glob +import h5py +from tqdm import tqdm + + +INPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data" + +files = glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True) + +key = "labels/az_merged" +for ff in tqdm(files): + with h5py.File(ff, "a") as f: + az = f[key][:] + az = az.squeeze() + del f[key] + f.create_dataset(key, data=az, compression="lzf") diff --git a/scripts/cooper/revision/merge_az.py b/scripts/cooper/revision/merge_az.py new file mode 100644 index 0000000..cb1a776 --- /dev/null +++ b/scripts/cooper/revision/merge_az.py @@ -0,0 +1,127 @@ +import argparse +import os +from glob import glob + +import h5py +import napari +import numpy as np +from scipy.ndimage import binary_closing +from common import ALL_NAMES, get_file_names, get_split_folder, get_paths + + +SKIP_MERGE = [ + "36859_J1_66K_TS_CA3_PS_26_rec_2Kb1dawbp_crop.h5", + "36859_J1_66K_TS_CA3_PS_23_rec_2Kb1dawbp_crop.h5", + "36859_J1_66K_TS_CA3_PS_23_rec_2Kb1dawbp_crop.h5", + "36859_J1_STEM750_66K_SP_17_rec_2kb1dawbp_crop.h5", +] + + +# STEM CROPPED IS OFTEN TOO SMALL! +def merge_az(name, version, check, in_path): + split_folder = get_split_folder(version) + + if name == "stem_cropped": + file_paths = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped", "*.h5")) + file_names = [os.path.basename(path) for path in file_paths] + else: + if in_path: + file_paths = glob(os.path.join(in_path, name, "*.h5")) + file_names = [os.path.basename(path) for path in file_paths] + else: + file_names = get_file_names(name, split_folder, split_names=["train", "val", "test"]) + seg_paths, gt_paths = get_paths(name, file_names) + + for seg_path, gt_path in zip(seg_paths, gt_paths): + + with h5py.File(gt_path, "r") as f: + #if not check and ("labels/az_merged" in f): + if f"labels/az_merged_v{version}" in f : + continue + raw = f["raw"][:] + gt = f["labels/az"][:] + gt_thin = f["labels/az_thin"][:] + + with h5py.File(seg_path) as f: + seg_key = f"predictions/az/v{version}" + pred = f[seg_key][:] + + fname = os.path.basename(seg_path) + if fname in SKIP_MERGE: + az_merged = gt + else: + threshold = 0.4 + gt_ = np.logical_or(binary_closing(gt, iterations=4), gt) + seg = pred > threshold + az_merged = np.logical_and(seg, gt_) + az_merged = np.logical_or(az_merged, gt_thin) + az_merged = np.logical_or(binary_closing(az_merged, iterations=2), az_merged) + + if check: + v = napari.Viewer() + v.add_image(raw) + v.add_image(pred, blending="additive", visible=False) + v.add_labels(seg, colormap={1: "blue"}) + v.add_labels(gt, colormap={1: "yellow"}) + v.add_labels(az_merged) + v.title = f"{name}/{fname}" + napari.run() + + print(f"gt_path {gt_path}") + with h5py.File(gt_path, "a") as f: + f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf") + + else: + print(f"gt_path {gt_path}") + with h5py.File(gt_path, "a") as f: + 
f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf") + '''with h5py.File(seg_path, "a") as f: + f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf")''' + + +def visualize_merge(args): + for name in args.names: + if "endbulb" in name: + continue + merge_az(name, args.version, check=True, in_path=args.in_path) + + +def copy_az(name, version): + split_folder = get_split_folder(version) + file_names = get_file_names(name, split_folder, split_names=["train", "val", "test"]) + _, gt_paths = get_paths(name, file_names, skip_seg=True) + + for gt_path in gt_paths: + with h5py.File(gt_path, "a") as f: + if "labels/az_merged" in f: + continue + az = f["labels/az"][:] + f.create_dataset("labels/az_merged", data=az, compression="lzf") + + +def run_merge(args): + for name in args.names: + print("Merging", name) + if "endbulb" in name: + copy_az(name, args.version) + else: + merge_az(name, args.version, check=False, in_path= args.in_path) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("--visualize", action="store_true") + parser.add_argument("--names", nargs="+", default=ALL_NAMES + ["endbulb_of_held_cropped"]) + parser.add_argument("--version", "-v", type=int, default=4) + parser.add_argument("--in_path", "-i", default=None) + + args = parser.parse_args() + if args.visualize: + visualize_merge(args) + else: + run_merge(args) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/prediction.py b/scripts/cooper/revision/prediction.py new file mode 100644 index 0000000..85bc09b --- /dev/null +++ b/scripts/cooper/revision/prediction.py @@ -0,0 +1,121 @@ +import os +from glob import glob +import argparse + +import h5py +from synapse_net.inference.inference import get_model, compute_scale_from_voxel_size +from synapse_net.inference.compartments import segment_compartments +from synapse_net.inference.vesicles import segment_vesicles +from tqdm import tqdm + +ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/training_data" +OUTPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_predictions" +RESOLUTIONS = { + "01data_withoutInvertedFiles_minusSVseg_corrected": {"x": 1.554, "y": 1.554, "z": 1.554}, + "04_hoi_stem_examples_fidi_and_sarah_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681}, + "04_hoi_stem_examples_fidi_and_sarah_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554}, + "04_hoi_stem_examples_minusSVseg_cropped_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681}, + "04_hoi_stem_examples_minusSVseg_cropped_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554}, + "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681}, + "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554}, + "12_chemical_fix_cryopreparation_minusSVseg_corrected": {"x": 1.554, "y": 1.554, "z": 1.554}, + "wichmann_withAZ": {"x": 1.748, "y": 1.748, "z": 1.748}, + "wichmann_withAZ_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554}, + "stem_cropped2_rescaled": {"x": 1.554, "y": 1.554, "z": 1.554}, +} + + +def predict_boundaries(model, path, output_path, visualize=False): + output_key = "predictions/boundaries" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if output_key in f: + return + + dataset = os.path.basename(os.path.split(path)[0]) + + with h5py.File(path, "r") as f: + data = f["raw"][:] + scale = compute_scale_from_voxel_size(RESOLUTIONS[dataset], 
"compartments") + _, pred = segment_compartments(data, model=model, scale=scale, verbose=False, return_predictions=True) + + if visualize: + import napari + v = napari.Viewer() + v.add_image(data) + v.add_labels(pred) + napari.run() + + with h5py.File(output_path, "a") as f: + f.create_dataset(output_key, data=pred, compression="lzf") + + +def predict_all_boundaries(folder=ROOT, out_path=OUTPUT_ROOT, visualize=False): + model = get_model("compartments") + files = sorted(glob(os.path.join(folder, "**/*.h5"), recursive=True)) + for path in tqdm(files): + folder_name = os.path.basename(os.path.split(path)[0]) + output_folder = os.path.join(out_path, folder_name) + os.makedirs(output_folder, exist_ok=True) + output_path = os.path.join(output_folder, os.path.basename(path)) + predict_boundaries(model, path, output_path, visualize) + + +def predict_vesicles(model, path, output_path, visualize=False): + output_key = "predictions/vesicle_seg" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if output_key in f: + return + + dataset = os.path.basename(os.path.split(path)[0]) + #if "rescaled" in dataset: + # return + + with h5py.File(path, "r") as f: + data = f["raw"][:] + scale = compute_scale_from_voxel_size(RESOLUTIONS[dataset], "vesicles_3d") + seg = segment_vesicles(data, model=model, scale=scale, verbose=False) + + if visualize: + import napari + v = napari.Viewer() + v.add_image(data) + v.add_labels(seg) + napari.run() + + with h5py.File(output_path, "a") as f: + f.create_dataset(output_key, data=seg, compression="lzf") + + +def predict_all_vesicles(folder=ROOT, out_path=OUTPUT_ROOT, visualize=False): + model = get_model("vesicles_3d") + files = sorted(glob(os.path.join(folder, "**/*.h5"), recursive=True)) + for path in tqdm(files): + folder_name = os.path.basename(os.path.split(path)[0]) + output_folder = os.path.join(out_path, folder_name) + os.makedirs(output_folder, exist_ok=True) + output_path = os.path.join(output_folder, os.path.basename(path)) + predict_vesicles(model, path, output_path, visualize) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument("-i","--input_folder", type=str) + parser.add_argument("-o","--out_path", type=str) + parser.add_argument("--vesicles", action="store_true") + parser.add_argument("--boundaries", action="store_true") + parser.add_argument("--visualize", action="store_true") + + args = parser.parse_args() + if args.boundaries: + predict_all_boundaries(args.input_folder, args.out_path, args.visualize) + elif args.vesicles: + predict_all_vesicles(args.input_folder, args.out_path, args.visualize) + else: + print("Choose which structure to predict: --vesicles or --boundaries") + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/remove_az_thin.py b/scripts/cooper/revision/remove_az_thin.py new file mode 100644 index 0000000..d383049 --- /dev/null +++ b/scripts/cooper/revision/remove_az_thin.py @@ -0,0 +1,72 @@ +'''import h5py + +files = [ + "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped2_rescaled/36859_H2_SP_02_rec_2Kb1dawbp_crop_crop1.h5", + "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped2_rescaled/36859_H2_SP_07_rec_2Kb1dawbp_crop_crop1.h5" +] + +for file in files: + with h5py.File(file, "r+") as f: + # Load the replacement data + gt = f["labels/az"][:] + + # Delete the existing dataset if it exists + if "labels/az_thin" in f: + del f["labels/az_thin"] + + # Recreate the dataset with the new data + f.create_dataset("labels/az_thin", data=gt) +''' +'''import 
os +import h5py +from glob import glob +import numpy as np + +# Collect all file paths +file_paths1 = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/chemical_fixation", "*.h5")) +file_paths2 = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem", "*.h5")) +file_paths3 = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped", "*.h5")) +file_paths4 = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/tem", "*.h5")) + +all_file_paths = file_paths1 + file_paths2 + file_paths3 + file_paths4 + +for fname in all_file_paths: + with h5py.File(fname, "a") as f: + if "/labels/az_merged_v6" in f: + az_merged = f["/labels/az_merged_v6"][:] # shape (1, 46, 446, 446) + az_merged = np.squeeze(az_merged) # shape (46, 446, 446) + + del f["/labels/az_merged_v6"] # delete old dataset + + f.create_dataset("/labels/az_merged_v6", data=az_merged, compression="lzf") + print(f"Updated file: {fname}") + else: + print(f"Dataset not found in: {fname}")''' + +import os +import h5py + +# List of target folders +base_path = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/04Dataset_for_vesicle_eval" +folders = [ + "20241019_Tomo-eval_MF_Synapse", + "20241019_Tomo-eval_PS_Synapse", + "20241019_Tomo-eval_SC_Synapse" +] + +# Keys to delete +keys_to_delete = ["/predictions/az/seg_v7", "/predictions/az/v7", "/predictions/az", "/predictions"] + +for folder in folders: + folder_path = os.path.join(base_path, folder) + for filename in os.listdir(folder_path): + if filename.endswith(".h5"): + file_path = os.path.join(folder_path, filename) + print(f"Processing: {file_path}") + with h5py.File(file_path, 'a') as h5file: + for key in keys_to_delete: + if key in h5file: + print(f" Deleting key: {key}") + del h5file[key] + else: + print(f" Key not found: {key}") diff --git a/scripts/cooper/revision/run_az_evaluation.py b/scripts/cooper/revision/run_az_evaluation.py new file mode 100644 index 0000000..8e91833 --- /dev/null +++ b/scripts/cooper/revision/run_az_evaluation.py @@ -0,0 +1,100 @@ +import argparse +import os +from glob import glob + +import pandas as pd +from common import get_paths, get_file_names, ALL_NAMES + + +def run_az_evaluation(args): + from synapse_net.ground_truth.az_evaluation import az_evaluation + + seg_key = f"predictions/az/v{args.version}" + + split_folder = "./models_az_thin" + results = [] + for dataset in args.datasets: + print(dataset, ":") + if args.in_path: + file_paths = glob(os.path.join(args.in_path, dataset, "*.h5")) + file_names = [os.path.basename(path) for path in file_paths] + else: + file_names = get_file_names(dataset, split_folder, split_names=["test"]) + seg_paths, gt_paths = get_paths(dataset, file_names) + result = az_evaluation( + seg_paths, gt_paths, seg_key=seg_key, gt_key="/labels/az_merged_v6", + criterion=args.criterion, dataset=[dataset] * len(seg_paths), threshold=args.threshold, + ) + results.append(result) + + results = pd.concat(results) + output_path = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{args.version}.xlsx" + + if os.path.exists(output_path): + # Read existing data + existing = pd.read_excel(output_path) + + # Ensure consistent column naming and types + if "tomo_name" in results.columns and "tomo_name" in existing.columns: + # Drop existing entries with matching "tomo_name" + existing = existing[~existing["tomo_name"].isin(results["tomo_name"])] + + # Combine: old (filtered) + new + combined = pd.concat([existing, results], 
ignore_index=True) + else: + combined = results + + # Save back to Excel + combined.to_excel(output_path, index=False) + + +def visualize_az_evaluation(args): + from elf.visualisation.metric_visualization import run_metric_visualization + from synapse_net.ground_truth.az_evaluation import _postprocess, _crop + from elf.io import open_file + + seg_key = f"predictions/az/v{args.version}" + + split_folder = "./models_az_thin" + for dataset in args.datasets: + file_names = get_file_names(dataset, split_folder, split_names=["test"]) + seg_paths, gt_paths = get_paths(dataset, file_names) + + for seg_path, gt_path in zip(seg_paths, gt_paths): + + with open_file(seg_path, "r") as f: + seg = f[seg_key][:].squeeze() + with open_file(gt_path, "r") as f: + gt = f["/labels/az_merged_v6"][:] + + seg = seg > args.threshold + + seg, gt, bb = _crop(seg, gt, return_bb=True) + with open_file(gt_path, "r") as f: + image = f["raw"][bb] + + seg = _postprocess(seg, apply_cc=True, min_component_size=5000, iterations=3) + gt = _postprocess(gt, apply_cc=True, min_component_size=500) + + run_metric_visualization(image, seg, gt, title=os.path.basename(seg_path), criterion=args.criterion) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--version", "-v", type=int, required=True) + parser.add_argument("-c", "--criterion", default="iou") + parser.add_argument("--visualize", action="store_true") + parser.add_argument("--datasets", nargs="+", default=ALL_NAMES) + # Set the threshold to None if the AZ prediction already a segmentation. + parser.add_argument("--threshold", type=float, default=0.5) + parser.add_argument("--in_path", "-i", default=None) + args = parser.parse_args() + + if args.visualize: + visualize_az_evaluation(args) + else: + run_az_evaluation(args) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/surface_dice.py b/scripts/cooper/revision/surface_dice.py new file mode 100644 index 0000000..1b9553c --- /dev/null +++ b/scripts/cooper/revision/surface_dice.py @@ -0,0 +1,212 @@ +import sys +import os + +# Add membrain-seg to Python path +#Delete before last commit +MEMBRAIN_SEG_PATH = "/user/muth9/u12095/membrain-seg/src" +if MEMBRAIN_SEG_PATH not in sys.path: + sys.path.insert(0, MEMBRAIN_SEG_PATH) + +import argparse +import h5py +import pandas as pd +from tqdm import tqdm +import numpy as np +from scipy.ndimage import label +from skimage.measure import regionprops + +try: + from membrain_seg.segmentation.skeletonize import skeletonization + from membrain_seg.benchmark.metrics import masked_surface_dice +except ImportError: + skeletonization=None + masked_surface_dice=None + + +def load_segmentation(file_path, key): + with h5py.File(file_path, "r") as f: + data = f[key][:] + return data + + +def evaluate_surface_dice(pred, gt, raw, check): + if skeletonization is None: + print("Error! Install membrain_seg. 
For more information check out https://teamtomo.org/membrain-seg/installation/ ") + raise RuntimeError + + gt_skeleton = skeletonization(gt == 1, batch_size=100000) + pred_skeleton = skeletonization(pred, batch_size=100000) + mask = gt != 2 + + if check: + import napari + v = napari.Viewer() + v.add_image(raw) + v.add_labels(gt, name="gt") + v.add_labels(gt_skeleton.astype(np.uint16), name="gt_skeleton") + v.add_labels(pred, name="pred") + v.add_labels(pred_skeleton.astype(np.uint16), name="pred_skeleton") + napari.run() + + surf_dice, confusion_dict = masked_surface_dice( + pred_skeleton, gt_skeleton, pred, gt, mask + ) + return surf_dice, confusion_dict + + +def process_file(pred_path, gt_path, seg_key, gt_key, check, + min_bb_shape=(32, 384, 384), min_thinning_size=2500, + global_eval=False): + try: + pred = load_segmentation(pred_path, seg_key) + gt = load_segmentation(gt_path, gt_key) + raw = load_segmentation(gt_path, "raw") + + if global_eval: + gt_bin = (gt == 1).astype(np.uint8) + pred_bin = pred.astype(np.uint8) + + dice, confusion = evaluate_surface_dice(pred_bin, gt_bin, raw, check) + return [{ + "tomo_name": os.path.basename(pred_path), + "gt_component_id": -1, # -1 indicates global eval + "surface_dice": dice, + **confusion + }] + + labeled_gt, _ = label(gt == 1) + props = regionprops(labeled_gt) + results = [] + + for prop in props: + if prop.area < min_thinning_size: + continue + + comp_id = prop.label + bbox_start = prop.bbox[:3] + bbox_end = prop.bbox[3:] + bbox = tuple(slice(start, stop) for start, stop in zip(bbox_start, bbox_end)) + + pad_width = [ + max(min_shape - (sl.stop - sl.start), 0) // 2 + for sl, min_shape in zip(bbox, min_bb_shape) + ] + + expanded_bbox = tuple( + slice( + max(sl.start - pw, 0), + min(sl.stop + pw, dim) + ) + for sl, pw, dim in zip(bbox, pad_width, gt.shape) + ) + + gt_crop = (labeled_gt[expanded_bbox] == comp_id).astype(np.uint8) + pred_crop = pred[expanded_bbox].astype(np.uint8) + raw_crop = raw[expanded_bbox] + + try: + dice, confusion = evaluate_surface_dice(pred_crop, gt_crop, raw_crop, check) + except Exception as e: + print(f"Error computing Dice for GT component {comp_id} in {pred_path}: {e}") + continue + + result = { + "tomo_name": os.path.basename(pred_path), + "gt_component_id": comp_id, + "surface_dice": dice, + **confusion + } + results.append(result) + + return results + + except Exception as e: + print(f"Error processing {pred_path}: {e}") + return [] + + +def collect_results(input_folder, gt_folder, version, check=False, + min_bb_shape=(32, 384, 384), min_thinning_size=2500, + global_eval=False): + results = [] + seg_key = f"predictions/az/seg_v{version}" + gt_key = "/labels/az_merged_v6" + input_folder_name = os.path.basename(os.path.normpath(input_folder)) + + for fname in tqdm(os.listdir(input_folder), desc="Processing segmentations"): + if not fname.endswith(".h5"): + continue + + pred_path = os.path.join(input_folder, fname) + print(pred_path) + gt_path = os.path.join(gt_folder, fname) + + if not os.path.exists(gt_path): + print(f"Warning: Ground truth file not found for {fname}") + continue + + file_results = process_file( + pred_path, gt_path, seg_key, gt_key, check, + min_bb_shape=min_bb_shape, + min_thinning_size=min_thinning_size, + global_eval=global_eval + ) + + for res in file_results: + res["input_folder"] = input_folder_name + results.append(res) + + return results + + +def save_results(results, output_file): + new_df = pd.DataFrame(results) + + if os.path.exists(output_file): + existing_df = 
pd.read_excel(output_file) + + combined_df = existing_df[ + ~existing_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index.isin( + new_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index + ) + ] + + final_df = pd.concat([combined_df, new_df], ignore_index=True) + else: + final_df = new_df + + final_df.to_excel(output_file, index=False) + print(f"Results saved to {output_file}") + + +def main(): + parser = argparse.ArgumentParser(description="Compute surface dice per GT component or globally for AZ segmentations.") + parser.add_argument("--input_folder", "-i", required=True, help="Folder with predicted segmentations (.h5)") + parser.add_argument("--gt_folder", "-gt", required=True, help="Folder with ground truth segmentations (.h5)") + parser.add_argument("--version", "-v", required=True, help="Version string used in prediction key") + parser.add_argument("--check", action="store_true", help="Visualize intermediate outputs in Napari") + parser.add_argument("--global_eval", action="store_true", help="If set, compute global surface dice instead of per-component") + + args = parser.parse_args() + + min_bb_shape = (32, 384, 384) + min_thinning_size = 2500 + + suffix = "global" if args.global_eval else "per_gt_component" + output_file = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{args.version}_surface_dice_{suffix}.xlsx" + + results = collect_results( + args.input_folder, + args.gt_folder, + args.version, + args.check, + min_bb_shape=min_bb_shape, + min_thinning_size=min_thinning_size, + global_eval=args.global_eval + ) + + save_results(results, output_file) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/thin_az_gt.py b/scripts/cooper/revision/thin_az_gt.py new file mode 100644 index 0000000..1008588 --- /dev/null +++ b/scripts/cooper/revision/thin_az_gt.py @@ -0,0 +1,79 @@ +import argparse +import os +from glob import glob +from tqdm import tqdm + +import h5py +import napari +from synapse_net.ground_truth.az_evaluation import thin_az + +ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/training_data" +OUTPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_predictions" + + +def run_az_thinning(folder=ROOT, out_path=OUTPUT_ROOT): + files = sorted(glob(os.path.join(folder, "**/*.h5"), recursive=True)) + for ff in tqdm(files): + ds_name = os.path.basename(os.path.split(ff)[0]) + '''if not ds_name.startswith(("04", "06")): + continue + if "rescaled" in ds_name: + continue''' + + print(f"ff {ff}") + ff_out = os.path.join(out_path, os.path.relpath(ff, folder)) + print(f"ff_out {ff_out}") + with h5py.File(ff_out, "r") as f_out, h5py.File(ff, "r") as f_in: + # if "labels/az_thin2" in f_out: + # continue + + boundary_pred = f_out["predictions/boundaries"] + vesicles = f_out["predictions/vesicle_seg"] + + tomo = f_in["raw"] + az = f_in["labels/az"][:] + + az_thin = thin_az( + az, boundary_map=boundary_pred, vesicles=vesicles, tomo=tomo, presyn_dist=8, check=False, + min_thinning_size=2500, + ) + + with h5py.File(ff_out, "a") as f: + ds = f.require_dataset("labels/az_thin", shape=az_thin.shape, dtype=az_thin.dtype, compression="gzip") + ds[:] = az_thin + + +def check_az_thinning(folder=ROOT, out_path=OUTPUT_ROOT): + files = sorted(glob(os.path.join(folder, "**/*.h5"), recursive=True)) + for ff in files: + + f_out = os.path.join(out_path, os.path.relpath(ff, folder)) + with h5py.File(f_out, "r") as f: + if "labels/az_thin" not in f: + continue + az_thin = f["labels/az_thin"][:] + + with 
h5py.File(ff, "r") as f: + tomo = f["raw"][:] + + v = napari.Viewer() + v.add_image(tomo) + v.add_labels(az_thin) + napari.run() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-i","--input_folder", type=str) + parser.add_argument("-o","--out_path", type=str) + parser.add_argument("--check", action="store_true") + args = parser.parse_args() + + if args.check: + check_az_thinning(args.input_folder, args.out_path) + else: + run_az_thinning(args.input_folder, args.out_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/train_az.py b/scripts/cooper/revision/train_az.py new file mode 100644 index 0000000..d5be4ae --- /dev/null +++ b/scripts/cooper/revision/train_az.py @@ -0,0 +1,152 @@ +import argparse +import os +import json +from glob import glob + +import torch_em + +from sklearn.model_selection import train_test_split + +from synapse_net.training import supervised_training, AZDistanceLabelTransform + +TRAIN_ROOT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data" +OUTPUT_ROOT = "./models_az_thin" + + +def _require_train_val_test_split(datasets): + train_ratio, val_ratio, test_ratio = 0.60, 0.2, 0.2 + + def _train_val_test_split(names): + train, test = train_test_split(names, test_size=1 - train_ratio, shuffle=True) + _ratio = test_ratio / (test_ratio + val_ratio) + if len(test) == 2: + val, test = test[:1], test[1:] + else: + val, test = train_test_split(test, test_size=_ratio) + return train, val, test + + for ds in datasets: + print(ds) + split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json") + if os.path.exists(split_path): + continue + + ds_root = os.path.join(TRAIN_ROOT, ds) + assert os.path.exists(ds_root), ds_root + file_paths = sorted(glob(os.path.join(ds_root, "*.h5"))) + file_names = [os.path.basename(path) for path in file_paths] + + train, val, test = _train_val_test_split(file_names) + + with open(split_path, "w") as f: + json.dump({"train": train, "val": val, "test": test}, f) + + +def _require_train_val_split(datasets): + train_ratio = 0.8 + + def _train_val_split(names): + train, val = train_test_split(names, test_size=1 - train_ratio, shuffle=True) + return train, val + + for ds in datasets: + print(ds) + split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json") + if os.path.exists(split_path): + continue + + file_paths = sorted(glob(os.path.join(TRAIN_ROOT, ds, "*.h5"))) + file_names = [os.path.basename(path) for path in file_paths] + + train, val = _train_val_split(file_names) + + with open(split_path, "w") as f: + json.dump({"train": train, "val": val}, f) + + +def get_paths(split, datasets, testset=True): + if testset: + _require_train_val_test_split(datasets) + else: + _require_train_val_split(datasets) + + paths = [] + for ds in datasets: + split_path = os.path.join(OUTPUT_ROOT, f"split-{ds}.json") + with open(split_path) as f: + names = json.load(f)[split] + ds_paths = [os.path.join(TRAIN_ROOT, ds, name) for name in names] + assert len(ds_paths) > 0 + assert all(os.path.exists(path) for path in ds_paths) + paths.extend(ds_paths) + + return paths + + +def train(key, ignore_label=None, use_distances=False, training_2D=False, testset=True, check=False): + + os.makedirs(OUTPUT_ROOT, exist_ok=True) + + datasets_with_testset_true = ["tem", "chemical_fixation", "stem", "endbulb_of_held"] + datasets_with_testset_false = ["stem_cropped", "endbulb_of_held_cropped"] + + train_paths = get_paths("train", datasets=datasets_with_testset_true, testset=True) + val_paths = get_paths("val", 
datasets=datasets_with_testset_true, testset=True) + + train_paths += get_paths("train", datasets=datasets_with_testset_false, testset=False) + val_paths += get_paths("val", datasets=datasets_with_testset_false, testset=False) + + print("Start training with:") + print(len(train_paths), "tomograms for training") + print(len(val_paths), "tomograms for validation") + + # patch_shape = [48, 256, 256] + patch_shape = [48, 384, 384] + model_name = "v7" + + # checking for 2D training + if training_2D: + patch_shape = [1, 256, 256] + model_name = "2D-AZ-model-v1" + + if use_distances: + out_channels = 2 + label_transform = AZDistanceLabelTransform() + else: + out_channels = 1 + label_transform = torch_em.transform.label.labels_to_binary + + batch_size = 2 + supervised_training( + name=model_name, + train_paths=train_paths, + val_paths=val_paths, + label_key=f"/labels/{key}", + patch_shape=patch_shape, batch_size=batch_size, + sampler=torch_em.data.sampler.MinInstanceSampler(min_num_instances=1, p_reject=0.85), + n_samples_train=None, n_samples_val=100, + check=check, + save_root="/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ", + n_iterations=int(2e5), + ignore_label=ignore_label, + label_transform=label_transform, + out_channels=out_channels, + # BCE_loss=False, + # sigmoid_layer=True, + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-k", "--key", help="Key ID that will be used by model in training", default="az_merged") + parser.add_argument("-m", "--mask", type=int, default=None, + help="Mask ID that will be ignored by model in training") + parser.add_argument("-2D", "--training_2D", action='store_true', help="Set to True for 2D training") + parser.add_argument("-t", "--testset", action='store_false', help="Set to False if no testset should be created") + parser.add_argument("-c", "--check", action="store_true") + args = parser.parse_args() + train(args.key, ignore_label=args.mask, training_2D=args.training_2D, testset=args.testset, check=args.check) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/revision/updated_data_analysis/analysis_segmentations.py b/scripts/cooper/revision/updated_data_analysis/analysis_segmentations.py new file mode 100644 index 0000000..2247862 --- /dev/null +++ b/scripts/cooper/revision/updated_data_analysis/analysis_segmentations.py @@ -0,0 +1,291 @@ +import os +import numpy as np +import h5py + +from skimage.measure import regionprops +from skimage.morphology import remove_small_holes +from skimage.segmentation import relabel_sequential + +from synapse_net.inference.vesicles import segment_vesicles +from synapse_net.inference.compartments import segment_compartments +from synapse_net.inference.active_zone import segment_active_zone +from synapse_net.inference.inference import get_model_path +from synapse_net.ground_truth.az_evaluation import _get_presynaptic_mask + + +def fill_and_filter_vesicles(vesicles: np.ndarray) -> np.ndarray: + """ + Apply a size filter and fill small holes in vesicle segments. + + Args: + vesicles (np.ndarray): 3D volume with vesicle segment labels. + + Returns: + np.ndarray: Processed vesicle segmentation volume. 
+ """ + ids, sizes = np.unique(vesicles, return_counts=True) + ids, sizes = ids[1:], sizes[1:] # remove background + + min_size = 2500 + vesicles_pp = vesicles.copy() + filter_ids = ids[sizes < min_size] + vesicles_pp[np.isin(vesicles, filter_ids)] = 0 + + props = regionprops(vesicles_pp) + for prop in props: + bb = prop.bbox + bb = np.s_[ + bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5] + ] + mask = vesicles_pp[bb] == prop.label + mask = remove_small_holes(mask, area_threshold=1000) + vesicles_pp[bb][mask] = prop.label + + return vesicles_pp + + +def SV_pred(raw: np.ndarray, SV_model: str, output_path: str = None, store: bool = False) -> np.ndarray: + """ + Run synaptic vesicle segmentation and optionally store the output. + + Args: + raw (np.ndarray): Raw EM image volume. + SV_model (str): Path to vesicle model. + output_path (str): HDF5 file to store predictions. + store (bool): Whether to store predictions. + + Returns: + np.ndarray: Segmentation result. + """ + pred_key = f"predictions/SV/pred" + seg_key = f"predictions/SV/seg" + + use_existing_seg = False + #checking if segmentation is already in output path and if so, use it + if output_path and os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if seg_key in f: + seg = f[seg_key][:] + use_existing_seg = True + print(f"Using existing SV seg in {output_path}") + + if not use_existing_seg: + #Excluding boundary SV, because they would also not be used in the manual annotation + seg, pred = segment_vesicles(input_volume=raw, model_path=SV_model, exclude_boundary=True, verbose=False, return_predictions=True) + + if store and output_path: + with h5py.File(output_path, "a") as f: + if pred_key in f: + print(f"{pred_key} already saved") + else: + f.create_dataset(pred_key, data=pred, compression="lzf") + + f.create_dataset(seg_key, data=seg, compression="lzf") + elif store and not output_path: + print("Output path is missing, not storing SV predictions") + else: + print("Not storing SV predictions") + + return seg + + +def compartment_pred(raw: np.ndarray, compartment_model: str, output_path: str = None, store: bool = False) -> np.ndarray: + """ + Run compartment segmentation and optionally store the output. + + Args: + raw (np.ndarray): Raw EM image volume. + compartment_model (str): Path to compartment model. + output_path (str): HDF5 file to store predictions. + store (bool): Whether to store predictions. + + Returns: + np.ndarray: Segmentation result. 
+ """ + + pred_key = f"predictions/compartment/pred" + seg_key = f"predictions/compartment/seg" + + use_existing_seg = False + #checking if segmentation is already in output path and if so, use it + if output_path and os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if seg_key in f and pred_key in f: + seg = f[seg_key][:] + pred = f[pred_key][:] + use_existing_seg = True + print(f"Using existing compartment seg in {output_path}") + + if not use_existing_seg: + seg, pred = segment_compartments(input_volume=raw, model_path=compartment_model, verbose=False, return_predictions=True, boundary_threshold=0.9) + + if store and output_path: + with h5py.File(output_path, "a") as f: + if pred_key in f: + print(f"{pred_key} already saved") + else: + f.create_dataset(pred_key, data=pred, compression="lzf") + + f.create_dataset(seg_key, data=seg, compression="lzf") + elif store and not output_path: + print("Output path is missing, not storing compartment predictions") + else: + print("Not storing compartment predictions") + + return seg, pred + + +def AZ_pred(raw: np.ndarray, AZ_model: str, output_path: str = None, store: bool = False) -> np.ndarray: + """ + Run active zone segmentation and optionally store the output. + + Args: + raw (np.ndarray): Raw EM image volume. + AZ_model (str): Path to AZ model. + output_path (str): HDF5 file to store predictions. + store (bool): Whether to store predictions. + + Returns: + np.ndarray: Segmentation result. + """ + pred_key = f"predictions/az/pred" + seg_key = f"predictions/az/seg" + + use_existing_seg = False + #checking if segmentation is already in output path and if so, use it + if output_path and os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if seg_key in f: + seg = f[seg_key][:] + use_existing_seg = True + print(f"Using existing AZ seg in {output_path}") + + if not use_existing_seg: + + seg, pred = segment_active_zone(raw, model_path=AZ_model, verbose=False, return_predictions=True) + + if store and output_path: + + with h5py.File(output_path, "a") as f: + if pred_key in f: + print(f"{pred_key} already saved") + else: + f.create_dataset(pred_key, data=pred, compression="lzf") + + f.create_dataset(seg_key, data=seg, compression="lzf") + elif store and not output_path: + print("Output path is missing, not storing AZ predictions") + else: + print("Not storing AZ predictions") + + return seg + + +def filter_presynaptic_SV(sv_seg: np.ndarray, compartment_seg: np.ndarray, compartment_pred: np.ndarray, output_path: str = None, + store: bool = False, input_path: str = None) -> np.ndarray: + """ + Filters synaptic vesicle segmentation to retain only vesicles in the presynaptic region. + + Args: + sv_seg (np.ndarray): Vesicle segmentation. + compartment_seg (np.ndarray): Compartment segmentation. + output_path (str): Optional HDF5 file to store outputs. + store (bool): Whether to store outputs. + input_path (str): Path to input file (for filename-based filtering). + + Returns: + np.ndarray: Filtered presynaptic vesicle segmentation. + """ + # Fill out small holes in vesicles and then apply a size filter. + vesicles_pp = fill_and_filter_vesicles(sv_seg) + + def n_vesicles(mask, ves): + return len(np.unique(ves[mask])) - 1 + + '''# Find the segment with most vesicles. 
+ props = regionprops(compartment_seg, intensity_image=vesicles_pp, extra_properties=[n_vesicles]) + compartment_ids = [prop.label for prop in props] + vesicle_counts = [prop.n_vesicles for prop in props] + if len(compartment_ids) == 0: + mask = np.ones(compartment_seg.shape, dtype="bool") + else: + mask = (compartment_seg == compartment_ids[np.argmax(vesicle_counts)]).astype("uint8")''' + + mask = _get_presynaptic_mask(compartment_pred, vesicles_pp) + + # Filter all vesicles that are not in the mask. + props = regionprops(vesicles_pp, mask) + filter_ids = [prop.label for prop in props if prop.max_intensity == 0] + + name = os.path.basename(input_path) if input_path else "unknown" + print(name) + + no_filter = ["C_M13DKO_080212_CTRL6.7B_crop.h5", "E_M13DKO_080212_DKO1.2_crop.h5", + "G_M13DKO_080212_CTRL6.7B_crop.h5", "A_SNAP25_120812_CTRL2.3_14_crop.h5", + "A_SNAP25_12082_KO2.1_6_crop.h5", "B_SNAP25_120812_CTRL2.3_14_crop.h5", + "B_SNAP25_12082_CTRL2.3_5_crop.h5", "D_SNAP25_120812_CTRL2.3_14_crop.h5", + "G_SNAP25_12.08.12_KO1.1_3_crop.h5"] + # Don't filter for wrong masks (visual inspection) + if name not in no_filter: + vesicles_pp[np.isin(vesicles_pp, filter_ids)] = 0 + + if store and output_path: + seg_presynapse = f"predictions/compartment/presynapse" + seg_presynaptic_SV = f"predictions/SV/presynaptic" + + with h5py.File(output_path, "a") as f: + if seg_presynapse in f: + print(f"{seg_presynapse} already saved") + else: + f.create_dataset(seg_presynapse, data=mask, compression="lzf") + if seg_presynaptic_SV in f: + print(f"{seg_presynaptic_SV} already saved") + else: + f.create_dataset(seg_presynaptic_SV, data=vesicles_pp, compression="lzf") + elif store and not output_path: + print("Output path is missing, not storing presynapse seg and presynaptic SV seg") + else: + print("Not storing presynapse seg and presynaptic SV seg") + + #All non-zero labels are relabeled starting from 1.Labels are sequential (1, 2, 3, ..., n). + #We do this to make the analysis part easier -> can match distances and diameters better + vesicles_pp, _, _ = relabel_sequential(vesicles_pp) + + return vesicles_pp + + +def run_predictions(input_path: str, output_path: str = None, store: bool = False): + """ + Run full inference pipeline: vesicles, compartments, active zone, and presynaptic SV filtering. + + Args: + input_path (str): Path to input HDF5 file with 'raw' dataset. + output_path (str): Path to output HDF5 file to store predictions. + store (bool): Whether to store intermediate and final results. 
+ + Returns: + Tuple[np.ndarray, np.ndarray]: (Filtered vesicle segmentation, AZ segmentation) + """ + with h5py.File(input_path, "r") as f: + raw = f["raw"][:] + + SV_model = get_model_path("vesicles_3d") + compartment_model = get_model_path("compartments") + # TODO upload better AZ model + AZ_model = "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ/checkpoints/v7/" + + print("Running SV prediction") + sv_seg = SV_pred(raw, SV_model, output_path, store) + + print("Running compartment prediction") + comp_seg, comp_pred = compartment_pred(raw, compartment_model, output_path, store) + + print("Running AZ prediction") + az_seg = AZ_pred(raw, AZ_model, output_path, store) + + print("Filtering the presynaptic SV") + presyn_SV_seg = filter_presynaptic_SV(sv_seg, comp_seg, comp_pred, output_path, store, input_path) + + print("Done with predictions") + + return presyn_SV_seg, az_seg diff --git a/scripts/cooper/revision/updated_data_analysis/data_analysis.py b/scripts/cooper/revision/updated_data_analysis/data_analysis.py new file mode 100644 index 0000000..32238fc --- /dev/null +++ b/scripts/cooper/revision/updated_data_analysis/data_analysis.py @@ -0,0 +1,102 @@ +from synapse_net.distance_measurements import measure_segmentation_to_object_distances +from synapse_net.imod.to_imod import convert_segmentation_to_spheres + + +def calc_AZ_SV_distance(vesicles, az, resolution): + """ + Calculate the distance between synaptic vesicles (SVs) and the active zone (AZ). + + Args: + vesicles (np.ndarray): Segmentation of synaptic vesicles. + az (np.ndarray): Segmentation of the active zone. + resolution (tuple): Voxel resolution in nanometers (z, y, x). + + Returns: + list of dict: Each dict contains 'seg_id' and 'distance', sorted by seg_id. + """ + distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, az, resolution=resolution) + + # Exclude seg_id == 0 + dist_list = [ + {"seg_id": sid, "distance": dist} + for sid, dist in zip(seg_ids, distances) + if sid != 0 + ] + dist_list.sort(key=lambda x: x["seg_id"]) + + return dist_list + + +def sort_by_distances(input_list): + """ + Sort a list of dictionaries by the 'distance' key from smallest to largest. + + Args: + input_list (list of dict): List containing 'distance' as a key in each dictionary. + + Returns: + list of dict: Sorted list by ascending distance. + """ + sorted_list = sorted(input_list, key=lambda x: x["distance"]) + return sorted_list + + +def combine_lists(list1, list2): + """ + Combine two lists of dictionaries based on the shared 'seg_id' key. + + Args: + list1 (list of dict): First list with 'seg_id' key. + list2 (list of dict): Second list with 'seg_id' key. + + Returns: + list of dict: Combined dictionaries matching by 'seg_id'. Overlapping keys are merged. + """ + combined_dict = {} + + for item in list1: + seg_id = item["seg_id"] + combined_dict[seg_id] = item.copy() + + for item in list2: + seg_id = item["seg_id"] + if seg_id in combined_dict: + for key, value in item.items(): + if key != "seg_id": + combined_dict[seg_id][key] = value + else: + combined_dict[seg_id] = item.copy() + + combined_list = list(combined_dict.values()) + return combined_list + + +def calc_SV_diameters(vesicles, resolution): + """ + Calculate diameters of synaptic vesicles from segmentation data. + + Args: + vesicles (np.ndarray): Segmentation of synaptic vesicles. + resolution (tuple): Voxel resolution in nanometers (z, y, x). + + Returns: + list of dict: Each dict contains 'seg_id' and 'diameter', sorted by seg_id. 
+ """ + coordinates, radii = convert_segmentation_to_spheres( + vesicles, resolution=resolution, radius_factor=0.7, estimate_radius_2d=True + ) + + # Assuming the segment ID is the index of the vesicle (same order as radii) + seg_ids = list(range(len(radii))) + radii_nm = radii * resolution[0] + diameters = radii_nm * 2 + + # Exclude seg_id == 0 + diam_list = [ + {"seg_id": sid, "diameter": diam} + for sid, diam in zip(seg_ids, diameters) + if sid != 0 + ] + diam_list.sort(key=lambda x: x["seg_id"]) + + return diam_list \ No newline at end of file diff --git a/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py b/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py new file mode 100644 index 0000000..cee93b3 --- /dev/null +++ b/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py @@ -0,0 +1,96 @@ +import argparse +import os +from tqdm import tqdm + +from analysis_segmentations import run_predictions +from data_analysis import calc_AZ_SV_distance, calc_SV_diameters, combine_lists, sort_by_distances +from store_results import run_store_results + +def run_data_analysis(input_path, output_path, store, resolution, analysis_output): + print("Starting SV, compartment, and AZ predictions") + SV_seg, az_seg = run_predictions(input_path, output_path, store) + + print("Performing automatic data analysis") + print("Calculating per SV distance to AZ") + dist_list = calc_AZ_SV_distance(SV_seg, az_seg, resolution) + + print("Calculating per SV diameters") + diam_list = calc_SV_diameters(SV_seg, resolution) + + print("Combining lists") + combined_list = combine_lists(dist_list, diam_list) + print(combined_list) + + print("Sorting the combined list by distances") + sorted_list = sort_by_distances(combined_list) + + print(f"Storing lists under {analysis_output}") + run_store_results(input_path, analysis_output, sorted_list) + + +def main(): + parser = argparse.ArgumentParser(description="Run data analysis on HDF5 data.") + parser.add_argument( + "--input_path", "-i", type=str, required=True, + help="Path to an HDF5 file or directory of files." + ) + parser.add_argument( + "--analysis_output", "-s", type=str, default = "./analysis_results/", + help="Path to the folder where the analysis results get saved." + ) + parser.add_argument( + "--output_folder", "-o", type=str, default=None, + help="Optional output folder for storing results." + ) + parser.add_argument( + "--store", action="store_true", + help="Store predictions in output files." + ) + parser.add_argument( + "--resolution", type=float, nargs=3, default=(1.554, 1.554, 1.554), + help="Resolution of input image." 
+ ) + + args = parser.parse_args() + + input_path = args.input_path + # Get the last directory name of the input_path + if os.path.isdir(input_path): + input_name = os.path.basename(os.path.normpath(input_path)) + else: + input_name = os.path.basename(os.path.dirname(input_path)) + + #get complete output_folder if there was an input + output_folder = args.output_folder + if output_folder: + output_folder = os.path.join(output_folder, input_name) + os.makedirs(output_folder, exist_ok=True) + + store = args.store + resolution = args.resolution + + #get complete output path for the analysis + analysis_output = args.analysis_output + analysis_output = os.path.join(analysis_output, input_name) + os.makedirs(analysis_output, exist_ok=True) + + if os.path.isfile(input_path): + filename = os.path.basename(input_path) + output_path = os.path.join(output_folder, filename) if output_folder else None + run_data_analysis(input_path, output_path, store, resolution, analysis_output) + + elif os.path.isdir(input_path): + h5_files = sorted([file for file in os.listdir(input_path) if file.endswith(".h5")]) + for file in tqdm(h5_files, desc="Processing files"): + full_input_path = os.path.join(input_path, file) + output_path = os.path.join(output_folder, file) if output_folder else None + run_data_analysis(full_input_path, output_path, store, resolution, analysis_output) + + else: + raise ValueError(f"Invalid input path: {input_path}") + + print("Finished!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/cooper/revision/updated_data_analysis/store_results.py b/scripts/cooper/revision/updated_data_analysis/store_results.py new file mode 100644 index 0000000..d044278 --- /dev/null +++ b/scripts/cooper/revision/updated_data_analysis/store_results.py @@ -0,0 +1,163 @@ +import os +import pandas as pd + +def get_group(input_path): + """ + Determines whether a tomogram belongs to 'CTRL' or 'KO' group. + + Parameters: + input_path (str): Path to the input .h5 file. + + Returns: + str: 'CTRL' if input_path contains 'CTRL', else 'KO'. + """ + return 'CTRL' if 'CTRL' in input_path else 'KO' + + +def get_tomogram_name(input_path): + """ + Extracts the tomogram name from the input path (without extension). + + Parameters: + input_path (str): Path to the input .h5 file. + + Returns: + str: Tomogram base name without extension. + """ + return os.path.splitext(os.path.basename(input_path))[0] + + +def prepare_output_directory(base_output, group): + """ + Ensures that the group-specific output directory exists. + + Parameters: + base_output (str): Base output directory. + group (str): Group name ('CTRL' or 'KO'). + + Returns: + str: Full path to the group-specific directory. + """ + group_dir = os.path.join(base_output, group) + os.makedirs(group_dir, exist_ok=True) + return group_dir + +def write_or_append_excel(file_path, new_data): + """ + Writes a new DataFrame to Excel, or appends a new column(s) to an existing one. + + Parameters: + file_path (str): Path to the target Excel file. + new_data (pd.DataFrame): DataFrame to write or append. + """ + print(f"saving {file_path}") + if os.path.exists(file_path): + existing = pd.read_excel(file_path, index_col=0) + combined = pd.concat([existing, new_data], axis=1) + else: + combined = new_data + + combined.to_excel(file_path, index=True) + +def save_filtered_dataframes(output_dir, tomogram_name, df): + """ + Saves the sorted segment data into multiple filtered Excel files. 
+ + Parameters: + output_dir (str): Directory where Excel files will be saved. + tomogram_name (str): Name of the tomogram (used as column header). + df (pd.DataFrame): DataFrame containing 'seg_id', 'distance', and 'diameter'. + """ + thresholds = { + 'AZ_distances': None, + 'AZ_distances_within_200': 200, + 'AZ_distances_within_100': 100, + 'AZ_distances_within_40': 40, + 'AZ_distances_within_100_with_diameters': 100, + 'AZ_distances_within_100_only_diameters': 100, + } + + for filename, max_dist in thresholds.items(): + file_path = os.path.join(output_dir, f"{filename}.xlsx") + filtered_df = df if max_dist is None else df[df['distance'] <= max_dist] + + if filename == 'AZ_distances_within_100_with_diameters': + data = pd.DataFrame({ + f"{tomogram_name}_distance": filtered_df['distance'].values, + f"{tomogram_name}_diameter": filtered_df['diameter'].values + }) + elif filename == 'AZ_distances_within_100_only_diameters': + data = pd.DataFrame({ + f"{tomogram_name}_diameter": filtered_df['diameter'].values + }) + else: + data = pd.DataFrame({tomogram_name: filtered_df['distance'].values}) + + write_or_append_excel(file_path, data) + + +def save_filtered_dataframes_with_seg_id(output_dir, tomogram_name, df): + """ + Saves segment data including seg_id into separate Excel files. + + Parameters: + output_dir (str): Directory to save files. + tomogram_name (str): Base name of the tomogram. + df (pd.DataFrame): DataFrame with 'seg_id', 'distance', 'diameter'. + """ + thresholds = { + 'AZ_distances_with_seg_id': None, + 'AZ_distances_within_200_with_seg_id': 200, + 'AZ_distances_within_100_with_seg_id': 100, + 'AZ_distances_within_40_with_seg_id': 40, + 'AZ_distances_within_100_with_diameters_and_seg_id': 100, + 'AZ_distances_within_100_only_diameters_and_seg_id': 100, + } + + with_segID_dir = os.path.join(output_dir, "with_segID") + os.makedirs(with_segID_dir, exist_ok=True) + + for filename, max_dist in thresholds.items(): + file_path = os.path.join(with_segID_dir, f"{filename}.xlsx") + filtered_df = df if max_dist is None else df[df['distance'] <= max_dist] + + if filename == 'AZ_distances_within_100_with_diameters_and_seg_id': + data = pd.DataFrame({ + f"{tomogram_name}_seg_id": filtered_df['seg_id'].values, + f"{tomogram_name}_distance": filtered_df['distance'].values, + f"{tomogram_name}_diameter": filtered_df['diameter'].values + }) + elif filename == 'AZ_distances_within_100_only_diameters_and_seg_id': + data = pd.DataFrame({ + f"{tomogram_name}_seg_id": filtered_df['seg_id'].values, + f"{tomogram_name}_diameter": filtered_df['diameter'].values + }) + else: + data = pd.DataFrame({ + f"{tomogram_name}_seg_id": filtered_df['seg_id'].values, + f"{tomogram_name}_distance": filtered_df['distance'].values + }) + + write_or_append_excel(file_path, data) + + +def run_store_results(input_path, analysis_output, sorted_list): + """ + Processes a single tomogram's sorted segment data and stores results into categorized Excel files. + + Parameters: + input_path (str): Path to the input .h5 file. + analysis_output (str): Directory where results should be saved. + sorted_list (list of dict): List of dicts with 'seg_id', 'distance', and 'diameter', + sorted by distance ascendingly. 
+ """ + group = get_group(input_path) + tomogram_name = get_tomogram_name(input_path) + group_dir = prepare_output_directory(analysis_output, group) + df = pd.DataFrame(sorted_list) + + # First run: distances only + save_filtered_dataframes(group_dir, tomogram_name, df) + + # Second run: include seg_id in the filenames and output + save_filtered_dataframes_with_seg_id(group_dir, tomogram_name, df) diff --git a/scripts/cooper/run_compartment_segmentation_h5.py b/scripts/cooper/run_compartment_segmentation_h5.py new file mode 100644 index 0000000..4376493 --- /dev/null +++ b/scripts/cooper/run_compartment_segmentation_h5.py @@ -0,0 +1,105 @@ +import argparse +from functools import partial + +from synapse_net.inference.compartments import segment_compartments +from synapse_net.inference.inference import get_model_path +from synapse_net.inference.util import inference_helper, parse_tiling + +import h5py +import numpy as np +from elf.io import open_file + +def get_volume(input_path): + ''' + with h5py.File(input_path) as seg_file: + input_volume = seg_file["raw"][:] + ''' + with open_file(input_path, "r") as f: + + # Try to automatically derive the key with the raw data. + keys = list(f.keys()) + if len(keys) == 1: + key = keys[0] + elif "data" in keys: + key = "data" + elif "raw" in keys: + key = "raw" + + input_volume = f[key][:] + return input_volume + +def run_compartment_segmentation(args): + tiling = parse_tiling(args.tile_shape, args.halo) + + if args.model is None: + model_path = get_model_path("compartments") + else: + model_path = args.model + + # Call segment_compartments directly, since we need its outputs + segmentation, predictions = segment_compartments( + get_volume(args.input_path), + model_path=model_path, + verbose=True, + tiling=tiling, + scale=None, + boundary_threshold=args.boundary_threshold, + return_predictions=True + ) + + # Save outputs into input HDF5 file + with h5py.File(args.input_path, "a") as f: + pred_grp = f.require_group("predictions") + + if "comp_seg" in pred_grp: + if args.force: + del pred_grp["comp_seg"] + else: + raise RuntimeError("comp_seg already exists. Use --force to overwrite.") + pred_grp.create_dataset("comp_seg", data=segmentation.astype(np.uint8), compression="gzip") + + if "boundaries" in pred_grp: + if args.force: + del pred_grp["boundaries"] + else: + raise RuntimeError("boundaries already exist. Use --force to overwrite.") + pred_grp.create_dataset("boundaries", data=predictions.astype(np.float32), compression="gzip") + + print(f"Saved segmentation to: predictions/comp_seg") + print(f"Saved boundaries to: predictions/boundaries") + + +def main(): + parser = argparse.ArgumentParser(description="Segment synaptic compartments in EM tomograms.") + parser.add_argument( + "--input_path", "-i", required=True, + help="The filepath to mrc file or directory containing the tomogram data." + ) + parser.add_argument( + "--model", "-m", help="The filepath to the compartment model." + ) + parser.add_argument( + "--force", action="store_true", + help="Whether to over-write already present segmentation results." + ) + parser.add_argument( + "--tile_shape", type=int, nargs=3, + help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient." + ) + parser.add_argument( + "--halo", type=int, nargs=3, + help="The halo for prediction. Increase the halo to minimize boundary artifacts." + ) + parser.add_argument( + "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc." 
+ ) + parser.add_argument( + "--boundary_threshold", type=float, default=0.4, help="Threshold that determines when the prediction of the network is foreground for the segmentation. Need higher threshold than default for TEM." + ) + + args = parser.parse_args() + run_compartment_segmentation(args) + + +if __name__ == "__main__": + main() diff --git a/scripts/cooper/vesicle_segmentation_h5.py b/scripts/cooper/vesicle_segmentation_h5.py index 0237973..3bda54b 100644 --- a/scripts/cooper/vesicle_segmentation_h5.py +++ b/scripts/cooper/vesicle_segmentation_h5.py @@ -8,6 +8,7 @@ from synapse_net.inference.vesicles import segment_vesicles from synapse_net.inference.util import parse_tiling +from synapse_net.inference.inference import get_model_path def _require_output_folders(output_folder): #seg_output = os.path.join(output_folder, "segmentations") @@ -34,7 +35,7 @@ def get_volume(input_path): input_volume = f[key][:] return input_volume -def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mask_key,tile_shape, halo, include_boundary, key_label): +def run_vesicle_segmentation(input_path, output_path, mask_path, mask_key,tile_shape, halo, include_boundary, key_label, model_path=None, save_pred=False): tiling = parse_tiling(tile_shape, halo) print(f"using tiling {tiling}") input = get_volume(input_path) @@ -45,7 +46,10 @@ def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mas mask = f[mask_key][:] else: mask = None - + + if model_path is None: + model_path = get_model_path("vesicles_3d") + segmentation, prediction = segment_vesicles(input_volume=input, model_path=model_path, verbose=False, tiling=tiling, return_predictions=True, exclude_boundary=not include_boundary, mask = mask) foreground, boundaries = prediction[:2] @@ -63,14 +67,15 @@ def run_vesicle_segmentation(input_path, output_path, model_path, mask_path, mas else: f.create_dataset("raw", data=input, compression="gzip") - key=f"vesicles/segment_from_{key_label}" + key=f"predictions/{key_label}" if key in f: print("Skipping", input_path, "because", key, "exists") else: f.create_dataset(key, data=segmentation, compression="gzip") - f.create_dataset(f"prediction_{key_label}/foreground", data = foreground, compression="gzip") - f.create_dataset(f"prediction_{key_label}/boundaries", data = boundaries, compression="gzip") - + if save_pred: + f.create_dataset(f"prediction_{key_label}/foreground", data = foreground, compression="gzip") + f.create_dataset(f"prediction_{key_label}/boundaries", data = boundaries, compression="gzip") + if mask is not None: if mask_key in f: print("mask image already saved") @@ -97,7 +102,7 @@ def segment_folder(args): print(f"Mask file not found for {input_path}") mask_path = None - run_vesicle_segmentation(input_path, args.output_path, args.model_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label) + run_vesicle_segmentation(input_path, args.output_path, mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label, args.model_path, args.save_pred) def main(): parser = argparse.ArgumentParser(description="Segment vesicles in EM tomograms.") @@ -110,7 +115,7 @@ def main(): help="The filepath to directory where the segmentations will be saved." ) parser.add_argument( - "--model_path", "-m", required=True, help="The filepath to the vesicle model." + "--model_path", "-m", help="The filepath to the vesicle model." 
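+ # When no model path is given, the default 3D vesicle model is resolved via
+ # get_model_path("vesicles_3d") inside run_vesicle_segmentation.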
) parser.add_argument( "--mask_path", help="The filepath to a h5 file with a mask that will be used to restrict the segmentation. Needs to be in combination with mask_key." @@ -131,9 +136,13 @@ def main(): help="Include vesicles that touch the top / bottom of the tomogram. By default these are excluded." ) parser.add_argument( - "--key_label", "-k", default = "combined_vesicles", + "--key_label", "-k", default = "vesicle_seg", help="Give the key name for saving the segmentation in h5." ) + parser.add_argument( + "--save_pred", action="store_true", + help="If set to true the prediction is also saved." + ) args = parser.parse_args() input_ = args.input_path @@ -141,7 +150,7 @@ def main(): if os.path.isdir(input_): segment_folder(args) else: - run_vesicle_segmentation(input_, args.output_path, args.model_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label) + run_vesicle_segmentation(input_, args.output_path, args.mask_path, args.mask_key, args.tile_shape, args.halo, args.include_boundary, args.key_label, args.model_path, args.save_pred) print("Finished segmenting!") diff --git a/synapse_net/ground_truth/az_evaluation.py b/synapse_net/ground_truth/az_evaluation.py new file mode 100644 index 0000000..9c95505 --- /dev/null +++ b/synapse_net/ground_truth/az_evaluation.py @@ -0,0 +1,243 @@ +import os +from typing import List, Optional + +import h5py +import pandas as pd +import numpy as np +import vigra + +from elf.evaluation.matching import _compute_scores, _compute_tps +from elf.evaluation import dice_score +from elf.segmentation.workflows import simple_multicut_workflow +from scipy.ndimage import binary_dilation, binary_closing, distance_transform_edt, binary_opening +from skimage.measure import label, regionprops, regionprops_table +from skimage.segmentation import relabel_sequential, watershed +from tqdm import tqdm + + +def _expand_seg(az, iterations): + return binary_closing(binary_dilation(az, iterations=iterations), iterations=iterations) + + +def _crop(seg, gt, return_bb=False): + bb_seg, bb_gt = np.where(seg), np.where(gt) + + # Handle empty segmentations. 
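+ # If the segmentation is empty, the bounding box is derived from the ground truth alone;
+ # otherwise the union of the two bounding boxes is used (stop index is exclusive, hence the +1).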
+ if bb_seg[0].size == 0: + bb = tuple(slice(bgt.min(), bgt.max() + 1) for bseg, bgt in zip(bb_seg, bb_gt)) + else: + bb = tuple(slice( + min(bseg.min(), bgt.min()), max(bseg.max(), bgt.max()) + 1 + ) for bseg, bgt in zip(bb_seg, bb_gt)) + + if return_bb: + return seg[bb], gt[bb], bb + else: + return seg[bb], gt[bb] + + +def _postprocess(data, apply_cc, min_component_size, iterations=0): + if iterations > 0: + data = _expand_seg(data, iterations) + if apply_cc: + data = label(data) + ids, sizes = np.unique(data, return_counts=True) + filter_ids = ids[sizes < min_component_size] + data[np.isin(data, filter_ids)] = 0 + data, _, _ = relabel_sequential(data) + return data + + +def _single_az_evaluation(seg, gt, apply_cc, min_component_size, iterations, criterion): + assert seg.shape == gt.shape, f"{seg.shape}, {gt.shape}" + dice = dice_score(seg > 0, gt > 0) + + seg = _postprocess(seg, apply_cc, min_component_size, iterations=iterations) + gt = _postprocess(gt, apply_cc, min_component_size=500) + + n_true, n_matched, n_pred, scores = _compute_scores(seg, gt, criterion=criterion, ignore_label=0) + tp = _compute_tps(scores, n_matched, threshold=0.5) + fp = n_pred - tp + fn = n_true - tp + + return {"tp": tp, "fp": fp, "fn": fn, "dice": dice} + + +def az_evaluation( + seg_paths: List[str], + gt_paths: List[str], + seg_key: str, + gt_key: str, + crop: bool = True, + apply_cc: bool = True, + min_component_size: int = 5000, + iterations: int = 3, + criterion: str = "iou", + threshold: Optional[float] = None, + **extra_cols +) -> pd.DataFrame: + """Evaluate active zone segmentations against ground-truth annotations. + + This computes the dice score as well as false positives, false negatives and true positives + for each segmented tomogram. + + Args: + seg_paths: The filepaths to the segmentations, stored as hd5 files. + gt_paths: The filepaths to the ground-truth annotatons, stored as hdf5 files. + seg_key: The internal path to the data in the segmentation hdf5 file. + gt_key: The internal path to the data in the ground-truth hdf5 file. + crop: Whether to crop the segmentation and ground-truth to the bounding box. + apply_cc: Whether to apply connected components before evaluation. + min_component_size: Minimum component size for filtering the segmentation before evaluation. + iterations: Post-processing iterations for expanding the AZ annotations. + criterion: The criterion for matching annotations and segmentations + threshold: Threshold applied to the segmentation. This is required if the segmentation is passed as + probability prediction instead of a binary segmentation. Possible values: 'iou', 'iop', 'iot'. + extra_cols: Additional columns for the result table. + + Returns: + A data frame with the evaluation results per tomogram. 
+ """ + assert len(seg_paths) == len(gt_paths) + + results = {key: [] for key in extra_cols.keys()} + results.update({ + "tomo_name": [], + "tp": [], + "fp": [], + "fn": [], + "dice": [], + }) + + i = 0 + for seg_path, gt_path in tqdm(zip(seg_paths, gt_paths), total=len(seg_paths), desc="Run AZ Eval"): + with h5py.File(seg_path, "r") as f: + if seg_key not in f: + print("Segmentation", seg_key, "could not be found in", seg_path) + i += 1 + continue + seg = f[seg_key][:].squeeze() + + with h5py.File(gt_path, "r") as f: + gt = f[gt_key][:] + + if threshold is not None: + seg = seg > threshold + + if crop: + seg, gt = _crop(seg, gt) + + result = _single_az_evaluation(seg, gt, apply_cc, min_component_size, iterations, criterion=criterion) + results["tomo_name"].append(os.path.basename(seg_path)) + for res in ("tp", "fp", "fn", "dice"): + results[res].append(result[res]) + for name, val in extra_cols.items(): + results[name].append(val[i]) + i += 1 + + return pd.DataFrame(results) + + +def _get_presynaptic_mask(boundary_map, vesicles): + mask = np.zeros(vesicles.shape, dtype="bool") + + def _compute_mask_2d(z): + distances = distance_transform_edt(boundary_map[z] < 0.25).astype("float32") + seeds = vigra.analysis.localMaxima(distances, marker=np.nan, allowAtBorder=True, allowPlateaus=True) + seeds = label(np.isnan(seeds)) + overseg = watershed(boundary_map[z], markers=seeds) + seg = simple_multicut_workflow( + boundary_map[z], use_2dws=False, watershed=overseg, n_threads=1, beta=0.6 + ) + + def n_vesicles(mask, seg): + return len(np.unique(seg[mask])) - 1 + + props = pd.DataFrame(regionprops_table(seg, vesicles[z], properties=["label"], extra_properties=[n_vesicles])) + ids, n_ves = props.label.values, props.n_vesicles.values + presyn_id = ids[np.argmax(n_ves)] + + mask[z] = seg == presyn_id + + for z in range(mask.shape[0]): + _compute_mask_2d(z) + + mask = binary_opening(mask, iterations=5) + + return mask + + +def thin_az( + az_segmentation: np.ndarray, + boundary_map: np.typing.ArrayLike, + vesicles: np.typing.ArrayLike, + tomo: Optional[np.typing.ArrayLike] = None, + presyn_dist: int = 6, + min_thinning_size: int = 2500, + post_closing: int = 2, + check: bool = False, +) -> np.ndarray: + """Thin the active zone annotations by restricting them to a certain distance from the presynaptic mask. + + Args: + az_segmentation: The active zone annotations. + boundary_map: The boundary / membrane predictions. + vesicles: The vesicle segmentation. + tomo: The tomogram data. Optional, will only be used for evaluation. + presyn_dist: The maximal distance to the presynaptic compartment, which is used for thinning. + min_thinning_size: The minimal size for a label component. + post_closing: Closing iterations to apply to the AZ annotations after thinning. + check: Whether to visually check the results. + + Returns: + The thinned AZ annotations. + """ + az_segmentation = label(az_segmentation) + thinned_az = np.zeros(az_segmentation.shape, dtype="uint8") + props = regionprops(az_segmentation) + + min_bb_shape = (32, 384, 384) + + for prop in props: + az_id = prop.label + + bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:])) + pad_width = [max(sh - (b.stop - b.start), 0) // 2 for b, sh in zip(bb, min_bb_shape)] + bb = tuple( + slice(max(b.start - pw, 0), min(b.stop + pw, sh)) for b, pw, sh in zip(bb, pad_width, az_segmentation.shape) + ) + + # If this is a small component then we discard it. This is likely some artifact in the ground-truth. 
+        if prop.area < min_thinning_size:
+            continue
+
+        # First, get everything for this bounding box.
+        az_bb = (az_segmentation[bb] == az_id)
+        vesicles_bb = vesicles[bb]
+        # Skip if we don't have a vesicle in the bounding box.
+        if vesicles_bb.max() == 0:
+            continue
+
+        mask_bb = _get_presynaptic_mask(boundary_map[bb], vesicles_bb)
+
+        # Apply post-processing to keep only the parts of the AZ close to the presynaptic mask.
+        distances = np.stack([distance_transform_edt(mask_bb[z] == 0) for z in range(mask_bb.shape[0])])
+        az_bb[distances > presyn_dist] = 0
+        az_bb = np.logical_or(binary_closing(az_bb, iterations=post_closing), az_bb)
+
+        if check:
+            import napari
+            tomo_bb = tomo[bb]
+
+            v = napari.Viewer()
+            v.add_image(tomo_bb)
+            v.add_labels(az_bb.astype("uint8"), name="az-thinned")
+            v.add_labels(az_segmentation[bb], name="az", visible=False)
+            v.add_labels(mask_bb, visible=False)
+            v.title = f"{prop.label}: {prop.area}"
+
+            napari.run()
+
+        thinned_az[bb][az_bb] = 1
+
+    return thinned_az
diff --git a/synapse_net/inference/active_zone.py b/synapse_net/inference/active_zone.py
index e9040fe..2be654f 100644
--- a/synapse_net/inference/active_zone.py
+++ b/synapse_net/inference/active_zone.py
@@ -106,7 +106,6 @@ def segment_active_zone(

     # Run segmentation and rescale the result if necessary.
     foreground = pred[0]
-    print(f"shape {foreground.shape}")
     segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
     segmentation = scaler.rescale_output(segmentation, is_segmentation=True)

diff --git a/synapse_net/inference/compartments.py b/synapse_net/inference/compartments.py
index 113230f..3c5d6b1 100644
--- a/synapse_net/inference/compartments.py
+++ b/synapse_net/inference/compartments.py
@@ -128,7 +128,7 @@ def _segment_compartments_3d(
             continue
        seg_z = _segment_compartments_2d(prediction[z], distances=distances[z])
        seg_z[seg_z != 0] += offset
-        offset = int(seg_z.max())
+        offset = max(int(seg_z.max()), offset)
        seg_2d[z] = seg_z

    seg = _merge_segmentation_3d(seg_2d, min_z_extent)
diff --git a/synapse_net/training/__init__.py b/synapse_net/training/__init__.py
index 7e32f94..84204c9 100644
--- a/synapse_net/training/__init__.py
+++ b/synapse_net/training/__init__.py
@@ -3,3 +3,4 @@
 from .supervised_training import supervised_training
 from .semisupervised_training import semisupervised_training
 from .domain_adaptation import mean_teacher_adaptation
+from .transform import AZDistanceLabelTransform
diff --git a/synapse_net/training/transform.py b/synapse_net/training/transform.py
new file mode 100644
index 0000000..04bad88
--- /dev/null
+++ b/synapse_net/training/transform.py
@@ -0,0 +1,18 @@
+import numpy as np
+from torch_em.transform.label import labels_to_binary
+from scipy.ndimage import distance_transform_edt
+
+
+class AZDistanceLabelTransform:
+    def __init__(self, max_distance: float = 50.0):
+        self.max_distance = max_distance
+
+    def __call__(self, input_):
+        binary_target = labels_to_binary(input_).astype("float32")
+        if binary_target.sum() == 0:
+            # Without any foreground the distance transform is not meaningful;
+            # assign the maximal distance everywhere instead.
+            distances = np.full(binary_target.shape, self.max_distance, dtype="float32")
+        else:
+            distances = distance_transform_edt(binary_target == 0)
+        distances = np.clip(distances, 0.0, self.max_distance)
+        distances /= self.max_distance
+        return np.stack([binary_target, distances])
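
Usage sketch for AZDistanceLabelTransform: the snippet below is illustrative only (the label volume and its values are made up) and shows the two-channel target the transform produces, i.e. a binary AZ mask plus a distance map that is clipped at max_distance and normalized to [0, 1].

    import numpy as np
    from synapse_net.training import AZDistanceLabelTransform

    # A small dummy AZ label volume: a single 8x8 patch of foreground in one z-slice.
    labels = np.zeros((8, 64, 64), dtype="uint8")
    labels[4, 28:36, 28:36] = 1

    transform = AZDistanceLabelTransform(max_distance=50.0)
    target = transform(labels)

    # Channel 0 is the binary AZ mask, channel 1 the normalized distance to the AZ.
    assert target.shape == (2, 8, 64, 64)
    assert target[1].min() == 0.0 and target[1].max() <= 1.0

The two channels are presumably meant to be used as complementary training targets, e.g. a binary segmentation loss on channel 0 combined with distance regression on channel 1.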