
Commit b87046f

add dask support and valis alignment (#69)

- add dask support for loading larger-than-RAM transcript dataframes (a minimal sketch of the idea follows the changed-files summary below)
- refactor alignment methods (decouple alignment matrix loading from the alignment itself)
- add Valis registration support to the processing pipeline
- clean up unused functions/imports
- bump hestcore 1.0.3 -> 1.0.4 (better documentation)
1 parent 239b506 · commit b87046f

19 files changed (+3,018 −688 lines)
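The dask-backed transcript loading itself lives in reader code that is not part of the excerpt below. As a rough sketch of the idea only — the parquet path, the column selection, and the preview/row-count calls are illustrative, not the pipeline's actual logic; he_x/he_y follow the transcript_df convention documented in HESTData.py:

import dask.dataframe as dd

# Open the transcripts parquet lazily; nothing is loaded into memory yet.
transcripts = dd.read_parquet('transcripts.parquet')

# Column selection stays lazy, so the full dataframe never has to fit in RAM.
he_coords = transcripts[['he_x', 'he_y']]

# Materialize only the small results that are actually needed.
preview = he_coords.head(1000)                   # eager, returns a pandas DataFrame
n_transcripts = transcripts.shape[0].compute()   # row count computed out-of-core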

.gitignore

Lines changed: 23 additions & 58 deletions
@@ -1,58 +1,23 @@
-__pycache__
-data
-align_coord_img.ipynb
-src/hiSTloader/yolov8n.pt
-.vscode/launch.json
-dist
-src/hest.egg-info
-make.bat
-Makefile
-*.parquet
-bench_data
-cell_seg
-filtered
-src/hest/bench/timm_ctp
-my_notebooks
-.gitattributes
-cells_xenium.geojson
-nuclei_xenium.geojson
-nuclei.tif
-cells.tif
-src_slides
-data64.h5
-tests/assets
-config
-tests/output_tests
-tissue_seg
-
-results
-atlas
-figures/test_paul
-bench_data.zip
-old_bench_data
-tutorials/downloads
-tutorials/processed
-bench_config/my_bench_config.yaml
-src/hest/bench/private
-str.csv
-bench_data_old
-ST_data_emb/
-ST_pred_results/
-hest_data
-fm_v1
-cufile.log
-int.csv
-docs/build
-docs/source/generated
-local
-hest_vis
-hest_vis2
-hest_vis
-vis
-vis2
-models/deeplabv3*
-htmlcov
-models/CellViT-SAM-H-x40.pth
-debug_seg
-replace_seg
-test_vis
+data
+.vscode/launch.json
+dist
+src/hest.egg-info
+bench_data
+.gitattributes
+tests/assets
+config
+tests/output_tests
+HEST/
+
+results
+atlas
+ST_data_emb/
+ST_pred_results/
+hest_data
+fm_v1
+docs/build
+docs/source/generated
+local
+models/deeplabv3*
+htmlcov
+models/CellViT-SAM-H-x40.pth

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ dependencies = [
     "spatial_image >= 0.3.0",
     "datasets",
     "mygene",
-    "hestcore == 1.0.3"
+    "hestcore == 1.0.4"
 ]

 requires-python = ">=3.9"

src/hest/HESTData.py

Lines changed: 69 additions & 32 deletions
@@ -4,18 +4,17 @@
 import os
 import shutil
 import warnings
-from typing import Dict, Iterator, List, Union
+from typing import Dict, List, Union

 import cv2
 import geopandas as gpd
 import numpy as np
+from loguru import logger
 from hestcore.wsi import (WSI, CucimWarningSingleton, NumpyWSI,
                           contours_to_img, wsi_factory)
 from loguru import logger

-
-
-from hest.io.seg_readers import TissueContourReader
+from hest.io.seg_readers import TissueContourReader, write_geojson
 from hest.LazyShapes import LazyShapes, convert_old_to_gpd, old_geojson_to_new
 from hest.segmentation.TissueMask import TissueMask, load_tissue_mask

@@ -31,7 +30,7 @@
 from tqdm import tqdm

 from .utils import (ALIGNED_HE_FILENAME, check_arg, deprecated,
-                    find_first_file_endswith, get_path_from_meta_row,
+                    find_first_file_endswith, get_k_genes_from_df, get_path_from_meta_row,
                     plot_verify_pixel_size, tiff_save, verify_paths)


@@ -100,7 +99,7 @@ class representing a single ST profile + its associated WSI image
         else:
             self._tissue_contours = tissue_contours

-        if 'total_counts' not in self.adata.var_names:
+        if 'total_counts' not in self.adata.var_names and len(self.adata) > 0:
             sc.pp.calculate_qc_metrics(self.adata, inplace=True)


@@ -133,7 +132,7 @@ def load_wsi(self) -> None:
         self.wsi = NumpyWSI(self.wsi.numpy())


-    def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
+    def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False, **kwargs):
         """Save a HESTData object to `path` as follows:
             - aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
             - metrics.json (contains useful metrics)
@@ -155,6 +154,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
             self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))
         except:
             # workaround from https://github.com/theislab/scvelo/issues/255
+            import traceback
+            traceback.print_exc()
             self.adata.__dict__['_raw'].__dict__['_var'] = self.adata.__dict__['_raw'].__dict__['_var'].rename(columns={'_index': 'features'})
             self.adata.write(os.path.join(path, 'aligned_adata.h5ad'))

@@ -172,7 +173,8 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
             downscaled_img = self.adata.uns['spatial']['ST']['images']['downscaled_fullres']
             down_fact = self.adata.uns['spatial']['ST']['scalefactors']['tissue_downscaled_fullres_scalef']
             down_img = Image.fromarray(downscaled_img)
-            down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))
+            if len(downscaled_img) > 0:
+                down_img.save(os.path.join(path, 'downscaled_fullres.jpeg'))


         if plot_pxl_size:
@@ -748,7 +750,9 @@ def __init__(
         xenium_nuc_seg: pd.DataFrame=None,
         xenium_cell_seg: pd.DataFrame=None,
         cell_adata: sc.AnnData=None, # type: ignore
-        transcript_df: pd.DataFrame=None
+        transcript_df: pd.DataFrame=None,
+        dapi_path: str=None,
+        alignment_file_path: str=None
     ):
         """
         class representing a single ST profile + its associated WSI image
@@ -765,16 +769,31 @@ class representing a single ST profile + its associated WSI image
             xenium_cell_seg (pd.DataFrame): content of a xenium cell contour file as a dataframe (cell_boundaries.parquet)
             cell_adata (sc.AnnData): ST cell data, each row in adata.obs is a cell, each row in obsm is the cell location on the H&E image in pixels
             transcript_df (pd.DataFrame): dataframe of transcripts, each row is a transcript, he_x and he_y is the transcript location on the H&E image in pixels
+            dapi_path (str): path to a dapi focus image
+            alignment_file_path (np.ndarray): path to xenium alignment path
         """
         super().__init__(adata=adata, img=img, pixel_size=pixel_size, meta=meta, tissue_seg=tissue_seg, tissue_contours=tissue_contours, shapes=shapes)

         self.xenium_nuc_seg = xenium_nuc_seg
         self.xenium_cell_seg = xenium_cell_seg
         self.cell_adata = cell_adata
         self.transcript_df = transcript_df
-
-
-    def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl_size=False):
+        self.dapi_path = dapi_path
+        self.alignment_file_path = alignment_file_path
+
+
+    def save(
+        self,
+        path: str,
+        save_img=True,
+        pyramidal=True,
+        bigtiff=False,
+        plot_pxl_size=False,
+        save_transcripts=False,
+        save_cell_seg=False,
+        save_nuclei_seg=False,
+        **kwargs
+    ):
         """Save a HESTData object to `path` as follows:
             - aligned_adata.h5ad (contains expressions for each spots + their location on the fullres image + a downscaled version of the fullres image)
             - metrics.json (contains useful metrics)
@@ -795,21 +814,18 @@ def save(self, path: str, save_img=True, pyramidal=True, bigtiff=False, plot_pxl
         if self.cell_adata is not None:
             self.cell_adata.write_h5ad(os.path.join(path, 'aligned_cells.h5ad'))

-        if self.transcript_df is not None:
+        if save_transcripts and self.transcript_df is not None:
             self.transcript_df.to_parquet(os.path.join(path, 'aligned_transcripts.parquet'))
+
+        if save_cell_seg:
+            he_cells = self.get_shapes('tenx_cell', 'he').shapes
+            he_cells.to_parquet(os.path.join(path, 'he_cell_seg.parquet'))
+            write_geojson(he_cells, os.path.join(path, f'he_cell_seg.geojson'), '', chunk=True)

-        if self.xenium_nuc_seg is not None:
-            print('Saving Xenium nucleus boundaries... (can be slow)')
-            with open(os.path.join(path, 'nuclei_xenium.geojson'), 'w') as f:
-                json.dump(self.xenium_nuc_seg, f, indent=4)
-
-        if self.xenium_cell_seg is not None:
-            print('Saving Xenium cells boundaries... (can be slow)')
-            with open(os.path.join(path, 'cells_xenium.geojson'), 'w') as f:
-                json.dump(self.xenium_cell_seg, f, indent=4)
-
-
-        # TODO save segmentation
+        if save_nuclei_seg:
+            he_nuclei = self.get_shapes('tenx_nucleus', 'he').shapes
+            he_nuclei.to_parquet(os.path.join(path, 'he_nucleus_seg.parquet'))
+            write_geojson(he_nuclei, os.path.join(path, f'he_nucleus_seg.geojson'), '', chunk=True)


 def read_HESTData(
@@ -936,19 +952,33 @@ def mask_and_patchify_bench(meta_df: pd.DataFrame, save_dir: str, use_mask=True,
         i += 1


-def create_benchmark_data(meta_df, save_dir:str, K, adata_folder, use_mask, keep_largest=None):
+def create_benchmark_data(meta_df, save_dir:str, K):
     os.makedirs(save_dir, exist_ok=True)
-    if K is not None:
-        splits = meta_df.groupby('patient')['id'].agg(list).to_dict()
-        create_splits(os.path.join(save_dir, 'splits'), splits, K=K)
+
+    meta_df['patient'] = meta_df['patient'].fillna('Patient 1')
+
+    get_k_genes_from_df(meta_df, 50, 'var', os.path.join(save_dir, 'var_50genes.json'))
+
+    splits = meta_df.groupby(['dataset_title', 'patient'])['id'].agg(list).to_dict()
+    create_splits(os.path.join(save_dir, 'splits'), splits, K=K)

     os.makedirs(os.path.join(save_dir, 'patches'), exist_ok=True)
-    mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)
+    #mask_and_patchify_bench(meta_df, os.path.join(save_dir, 'patches'), use_mask=use_mask, keep_largest=keep_largest)

+    os.makedirs(os.path.join(save_dir, 'patches_vis'), exist_ok=True)
     os.makedirs(os.path.join(save_dir, 'adata'), exist_ok=True)
-    for index, row in meta_df.iterrows():
+    for _, row in meta_df.iterrows():
         id = row['id']
-        src_adata = os.path.join(adata_folder, id + '.h5ad')
+        path = os.path.join(get_path_from_meta_row(row), 'processed')
+        src_patch = os.path.join(path, 'patches.h5')
+        dst_patch = os.path.join(save_dir, 'patches', id + '.h5')
+        shutil.copy(src_patch, dst_patch)
+
+        src_vis = os.path.join(path, 'patches_patch_vis.png')
+        dst_vis = os.path.join(save_dir, 'patches_vis', id + '.png')
+        shutil.copy(src_vis, dst_vis)
+
+        src_adata = os.path.join(path, 'aligned_adata.h5ad')
         dst_adata = os.path.join(save_dir, 'adata', id + '.h5ad')
         shutil.copy(src_adata, dst_adata)

@@ -1200,6 +1230,13 @@ def unify_gene_names(adata: sc.AnnData, species="human", drop=False) -> sc.AnnDa
     mask = ~adata.var_names.duplicated(keep='first')
     adata = adata[:, mask]

+    duplicated_genes_after = adata.var_names[adata.var_names.duplicated()]
+    if len(duplicated_genes_after) > len(duplicated_genes_before):
+        logger.warning(f"duplicated genes increased from {len(duplicated_genes_before)} to {len(duplicated_genes_after)} after resolving aliases")
+        logger.info('deduplicating...')
+        mask = ~adata.var_names.duplicated(keep='first')
+        adata = adata[:, mask]
+
     if drop:
         adata = adata[:, ~remaining]
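A hedged usage sketch of the reworked save flags shown above; `st` is assumed to be a XeniumHESTData object already built by the HEST readers, and the output directory is a placeholder:

# st: XeniumHESTData instance produced by the HEST readers (assumed)
st.save(
    'processed/',
    save_img=True,            # pyramidal WSI, as before
    pyramidal=True,
    save_transcripts=True,    # aligned_transcripts.parquet (only written if transcript_df is set)
    save_cell_seg=True,       # he_cell_seg.parquet + he_cell_seg.geojson
    save_nuclei_seg=True,     # he_nucleus_seg.parquet + he_nucleus_seg.geojson
)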

src/hest/LazyShapes.py

Lines changed: 9 additions & 3 deletions
@@ -2,24 +2,30 @@
 import pandas as pd
 from shapely import Polygon

-from hest.io.seg_readers import read_gdf
+from hest.io.seg_readers import GDFReader, read_gdf
 from hest.utils import verify_paths


 class LazyShapes:

     path: str = None

-    def __init__(self, path: str, name: str, coordinate_system: str):
+    def __init__(self, path: str, name: str, coordinate_system: str, reader: GDFReader=None, reader_kwargs = {}):
         verify_paths([path])
         self.path = path
         self.name = name
         self.coordinate_system = coordinate_system
         self._shapes = None
+        self.reader_kwargs = reader_kwargs
+        self.reader = reader

     def compute(self) -> None:
         if self._shapes is None:
-            self._shapes = read_gdf(self.path)
+            if self.reader is None:
+                self._shapes = read_gdf(self.path, self.reader_kwargs)
+            else:
+                self._shapes = self.reader(**self.reader_kwargs).read_gdf(self.path)
+

     @property
     def shapes(self) -> gpd.GeoDataFrame:
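A minimal sketch of the new reader injection, assuming a hypothetical GDFReader subclass (MyParquetReader) and a placeholder path; the built-in readers in hest.io.seg_readers would normally fill this role, and the shapes property is assumed to call compute() on first access as in the existing class:

import geopandas as gpd

from hest.io.seg_readers import GDFReader
from hest.LazyShapes import LazyShapes


class MyParquetReader(GDFReader):  # hypothetical reader, for illustration only
    def read_gdf(self, path) -> gpd.GeoDataFrame:
        return gpd.read_parquet(path)


# Without `reader`, compute() falls back to read_gdf(path, reader_kwargs) as before.
cells = LazyShapes('he_cell_seg.parquet', name='tenx_cell', coordinate_system='he',
                   reader=MyParquetReader)
gdf = cells.shapes  # first access instantiates MyParquetReader and reads the file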

src/hest/SlideReaderAdapter.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Slide Adapter class for Valis compatibility
+import os
+
+import numpy as np
+from valis import slide_tools
+from valis.slide_io import PIXEL_UNIT, MetaData, SlideReader
+
+from hestcore.wsi import wsi_factory
+
+
+class SlideReaderAdapter(SlideReader):
+    def __init__(self, src_f, *args, **kwargs):
+        super().__init__(src_f, *args, **kwargs)
+        self.wsi = wsi_factory(src_f)
+        self.metadata = self.create_metadata()
+
+    def create_metadata(self):
+        meta_name = f"{os.path.split(self.src_f)[1]}_Series(0)".strip("_")
+        slide_meta = MetaData(meta_name, 'SlideReaderAdapter')
+
+        slide_meta.is_rgb = True
+        slide_meta.channel_names = self._get_channel_names('NO_NAME')
+        slide_meta.n_channels = 1
+        slide_meta.pixel_physical_size_xyu = [0.25, 0.25, PIXEL_UNIT]
+        level_dim = self.wsi.level_dimensions() #self._get_slide_dimensions()
+        slide_meta.slide_dimensions = np.array([list(item) for item in level_dim])
+
+        return slide_meta
+
+    def slide2vips(self, level, xywh=None, *args, **kwargs):
+        img = self.slide2image(level, xywh=xywh, *args, **kwargs)
+        vips_img = slide_tools.numpy2vips(img)
+
+        return vips_img
+
+    def slide2image(self, level, xywh=None, *args, **kwargs):
+        level_dim = self.wsi.level_dimensions()[level]
+        img = self.wsi.get_thumbnail(level_dim[0], level_dim[1])
+
+        if xywh is not None:
+            xywh = np.array(xywh)
+            start_c, start_r = xywh[0:2]
+            end_c, end_r = xywh[0:2] + xywh[2:]
+            img = img[start_r:end_r, start_c:end_c]
+
+        return img
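A short sketch of exercising the adapter on its own; the slide path is a placeholder, and valis together with its pyvips dependency is assumed to be installed. In the processing pipeline this adapter is what lets Valis registration read slides through hestcore's WSI backends; the exact wiring into the Valis API lives in files not shown in this diff:

from hest.SlideReaderAdapter import SlideReaderAdapter

reader = SlideReaderAdapter('slide.tif')                        # wraps the file via hestcore's wsi_factory
print(reader.metadata.slide_dimensions)                         # per-level (width, height) array
thumb = reader.slide2image(level=-1)                            # smallest pyramid level as a numpy array
crop = reader.slide2image(level=0, xywh=(0, 0, 512, 512))       # (x, y, w, h) crop of level 0
vips_crop = reader.slide2vips(level=0, xywh=(0, 0, 512, 512))   # same crop as a pyvips image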

src/hest/bench/st_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -42,4 +42,4 @@ def load_adata(expr_path, genes = None, barcodes = None, normalize=False):
     adata = adata[:, genes]
     if normalize:
         adata = normalize_adata(adata)
-    return adata.to_df()
\ No newline at end of file
+    return adata.to_df()
