4
4
import os
5
5
import shutil
6
6
import warnings
7
- from typing import Dict , Iterator , List , Union
7
+ from typing import Dict , List , Union
8
8
9
9
import cv2
10
10
import geopandas as gpd
11
11
import numpy as np
12
+ from loguru import logger
12
13
from hestcore .wsi import (WSI , CucimWarningSingleton , NumpyWSI ,
13
14
contours_to_img , wsi_factory )
14
15
from loguru import logger
15
16
16
-
17
-
18
- from hest .io .seg_readers import TissueContourReader
17
+ from hest .io .seg_readers import TissueContourReader , write_geojson
19
18
from hest .LazyShapes import LazyShapes , convert_old_to_gpd , old_geojson_to_new
20
19
from hest .segmentation .TissueMask import TissueMask , load_tissue_mask
21
20
31
30
from tqdm import tqdm
32
31
33
32
from .utils import (ALIGNED_HE_FILENAME , check_arg , deprecated ,
34
- find_first_file_endswith , get_path_from_meta_row ,
33
+ find_first_file_endswith , get_k_genes_from_df , get_path_from_meta_row ,
35
34
plot_verify_pixel_size , tiff_save , verify_paths )
36
35
37
36
@@ -100,7 +99,7 @@ class representing a single ST profile + its associated WSI image
100
99
else :
101
100
self ._tissue_contours = tissue_contours
102
101
103
- if 'total_counts' not in self .adata .var_names :
102
+ if 'total_counts' not in self .adata .var_names and len ( self . adata ) > 0 :
104
103
sc .pp .calculate_qc_metrics (self .adata , inplace = True )
105
104
106
105
@@ -133,7 +132,7 @@ def load_wsi(self) -> None:
133
132
self .wsi = NumpyWSI (self .wsi .numpy ())
134
133
135
134
136
- def save (self , path : str , save_img = True , pyramidal = True , bigtiff = False , plot_pxl_size = False ):
135
+ def save (self , path : str , save_img = True , pyramidal = True , bigtiff = False , plot_pxl_size = False , ** kwargs ):
137
136
"""Save a HESTData object to `path` as follows:
138
137
- aligned_adata.h5ad (contains expressions for each spot + its location on the fullres image + a downscaled version of the fullres image)
139
138
- metrics.json (contains useful metrics)
@@ -155,6 +154,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
155
154
self .adata .write (os .path .join (path , 'aligned_adata.h5ad' ))
156
155
except :
157
156
# workaround from https://github.com/theislab/scvelo/issues/255
157
+ import traceback
158
+ traceback .print_exc ()
158
159
self .adata .__dict__ ['_raw' ].__dict__ ['_var' ] = self .adata .__dict__ ['_raw' ].__dict__ ['_var' ].rename (columns = {'_index' : 'features' })
159
160
self .adata .write (os .path .join (path , 'aligned_adata.h5ad' ))
160
161
@@ -172,7 +173,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
172
173
downscaled_img = self .adata .uns ['spatial' ]['ST' ]['images' ]['downscaled_fullres' ]
173
174
down_fact = self .adata .uns ['spatial' ]['ST' ]['scalefactors' ]['tissue_downscaled_fullres_scalef' ]
174
175
down_img = Image .fromarray (downscaled_img )
175
- down_img .save (os .path .join (path , 'downscaled_fullres.jpeg' ))
176
+ if len (downscaled_img ) > 0 :
177
+ down_img .save (os .path .join (path , 'downscaled_fullres.jpeg' ))
176
178
177
179
178
180
if plot_pxl_size :
@@ -748,7 +750,9 @@ def __init__(
748
750
xenium_nuc_seg : pd .DataFrame = None ,
749
751
xenium_cell_seg : pd .DataFrame = None ,
750
752
cell_adata : sc .AnnData = None , # type: ignore
751
- transcript_df : pd .DataFrame = None
753
+ transcript_df : pd .DataFrame = None ,
754
+ dapi_path : str = None ,
755
+ alignment_file_path : str = None
752
756
):
753
757
"""
754
758
class representing a single ST profile + its associated WSI image
@@ -765,16 +769,31 @@ class representing a single ST profile + its associated WSI image
765
769
xenium_cell_seg (pd.DataFrame): content of a xenium cell contour file as a dataframe (cell_boundaries.parquet)
766
770
cell_adata (sc.AnnData): ST cell data, each row in adata.obs is a cell, each row in obsm is the cell location on the H&E image in pixels
767
771
transcript_df (pd.DataFrame): dataframe of transcripts, each row is a transcript, he_x and he_y is the transcript location on the H&E image in pixels
772
+ dapi_path (str): path to a dapi focus image
773
+ alignment_file_path (str): path to the xenium alignment file
768
774
"""
769
775
super ().__init__ (adata = adata , img = img , pixel_size = pixel_size , meta = meta , tissue_seg = tissue_seg , tissue_contours = tissue_contours , shapes = shapes )
770
776
771
777
self .xenium_nuc_seg = xenium_nuc_seg
772
778
self .xenium_cell_seg = xenium_cell_seg
773
779
self .cell_adata = cell_adata
774
780
self .transcript_df = transcript_df
775
-
776
-
777
- def save (self , path : str , save_img = True , pyramidal = True , bigtiff = False , plot_pxl_size = False ):
781
+ self .dapi_path = dapi_path
782
+ self .alignment_file_path = alignment_file_path
783
+
784
+
785
+ def save (
786
+ self ,
787
+ path : str ,
788
+ save_img = True ,
789
+ pyramidal = True ,
790
+ bigtiff = False ,
791
+ plot_pxl_size = False ,
792
+ save_transcripts = False ,
793
+ save_cell_seg = False ,
794
+ save_nuclei_seg = False ,
795
+ ** kwargs
796
+ ):
778
797
"""Save a HESTData object to `path` as follows:
779
798
- aligned_adata.h5ad (contains expressions for each spot + its location on the fullres image + a downscaled version of the fullres image)
780
799
- metrics.json (contains useful metrics)
@@ -795,21 +814,18 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
795
814
if self .cell_adata is not None :
796
815
self .cell_adata .write_h5ad (os .path .join (path , 'aligned_cells.h5ad' ))
797
816
798
- if self .transcript_df is not None :
817
+ if save_transcripts and self .transcript_df is not None :
799
818
self .transcript_df .to_parquet (os .path .join (path , 'aligned_transcripts.parquet' ))
819
+
820
+ if save_cell_seg :
821
+ he_cells = self .get_shapes ('tenx_cell' , 'he' ).shapes
822
+ he_cells .to_parquet (os .path .join (path , 'he_cell_seg.parquet' ))
823
+ write_geojson (he_cells , os .path .join (path , f'he_cell_seg.geojson' ), '' , chunk = True )
800
824
801
- if self .xenium_nuc_seg is not None :
802
- print ('Saving Xenium nucleus boundaries... (can be slow)' )
803
- with open (os .path .join (path , 'nuclei_xenium.geojson' ), 'w' ) as f :
804
- json .dump (self .xenium_nuc_seg , f , indent = 4 )
805
-
806
- if self .xenium_cell_seg is not None :
807
- print ('Saving Xenium cells boundaries... (can be slow)' )
808
- with open (os .path .join (path , 'cells_xenium.geojson' ), 'w' ) as f :
809
- json .dump (self .xenium_cell_seg , f , indent = 4 )
810
-
811
-
812
- # TODO save segmentation
825
+ if save_nuclei_seg :
826
+ he_nuclei = self .get_shapes ('tenx_nucleus' , 'he' ).shapes
827
+ he_nuclei .to_parquet (os .path .join (path , 'he_nucleus_seg.parquet' ))
828
+ write_geojson (he_nuclei , os .path .join (path , f'he_nucleus_seg.geojson' ), '' , chunk = True )
813
829
814
830
815
831
def read_HESTData (
@@ -936,19 +952,33 @@ def mask_and_patchify_bench(meta_df: pd.DataFrame, save_dir: str, use_mask=True,
936
952
i += 1
937
953
938
954
939
- def create_benchmark_data (meta_df , save_dir :str , K , adata_folder , use_mask , keep_largest = None ):
955
+ def create_benchmark_data (meta_df , save_dir :str , K ):
940
956
os .makedirs (save_dir , exist_ok = True )
941
- if K is not None :
942
- splits = meta_df .groupby ('patient' )['id' ].agg (list ).to_dict ()
943
- create_splits (os .path .join (save_dir , 'splits' ), splits , K = K )
957
+
958
+ meta_df ['patient' ] = meta_df ['patient' ].fillna ('Patient 1' )
959
+
960
+ get_k_genes_from_df (meta_df , 50 , 'var' , os .path .join (save_dir , 'var_50genes.json' ))
961
+
962
+ splits = meta_df .groupby (['dataset_title' , 'patient' ])['id' ].agg (list ).to_dict ()
963
+ create_splits (os .path .join (save_dir , 'splits' ), splits , K = K )
944
964
945
965
os .makedirs (os .path .join (save_dir , 'patches' ), exist_ok = True )
946
- mask_and_patchify_bench (meta_df , os .path .join (save_dir , 'patches' ), use_mask = use_mask , keep_largest = keep_largest )
966
+ # mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)
947
967
968
+ os .makedirs (os .path .join (save_dir , 'patches_vis' ), exist_ok = True )
948
969
os .makedirs (os .path .join (save_dir , 'adata' ), exist_ok = True )
949
- for index , row in meta_df .iterrows ():
970
+ for _ , row in meta_df .iterrows ():
950
971
id = row ['id' ]
951
- src_adata = os .path .join (adata_folder , id + '.h5ad' )
972
+ path = os .path .join (get_path_from_meta_row (row ), 'processed' )
973
+ src_patch = os .path .join (path , 'patches.h5' )
974
+ dst_patch = os .path .join (save_dir , 'patches' , id + '.h5' )
975
+ shutil .copy (src_patch , dst_patch )
976
+
977
+ src_vis = os .path .join (path , 'patches_patch_vis.png' )
978
+ dst_vis = os .path .join (save_dir , 'patches_vis' , id + '.png' )
979
+ shutil .copy (src_vis , dst_vis )
980
+
981
+ src_adata = os .path .join (path , 'aligned_adata.h5ad' )
952
982
dst_adata = os .path .join (save_dir , 'adata' , id + '.h5ad' )
953
983
shutil .copy (src_adata , dst_adata )
954
984
@@ -1200,6 +1230,13 @@ def unify_gene_names(adata: sc.AnnData, species="human", drop=False) -> sc.AnnDa
1200
1230
mask = ~ adata .var_names .duplicated (keep = 'first' )
1201
1231
adata = adata [:, mask ]
1202
1232
1233
+ duplicated_genes_after = adata .var_names [adata .var_names .duplicated ()]
1234
+ if len (duplicated_genes_after ) > len (duplicated_genes_before ):
1235
+ logger .warning (f"duplicated genes increased from { len (duplicated_genes_before )} to { len (duplicated_genes_after )} after resolving aliases" )
1236
+ logger .info ('deduplicating...' )
1237
+ mask = ~ adata .var_names .duplicated (keep = 'first' )
1238
+ adata = adata [:, mask ]
1239
+
1203
1240
if drop :
1204
1241
adata = adata [:, ~ remaining ]
1205
1242
0 commit comments