Skip to content

Commit c8b11fb

Browse files
authored
HEST-lib v1.1.0 (#25)
* start implementing lazy shapes * refactoring of hest data shape mechanism * refactor tests and add seg_readers * cleanup * minor changes 2 * speed up segmentation * speed up tissue segmentation * fix segmentation and refactor wsi handling * refactoring of tissue segmentation * fix spatial data cellvit conversion * add spatialdata test * fix geojson contours and update tutorial 2 * cleanup * read .parquet cellvit and refactor seg readers * optimize imports * remove kwimage dependency * only warn once for cucim import * correct cucim circular import * warn_cucim not defined * use singleton warn cucim * update tutorials
1 parent 3b244af commit c8b11fb

15 files changed

+2029
-524
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ dependencies = [
1313
"ultralytics >= 8.2.4",
1414
"pyvips >= 2.2.3",
1515
"scanpy >= 1.10.1",
16-
"kwimage >= 0.9.25",
1716
"imagecodecs >= 2024.1.1",
1817
"loguru >= 0.7.2",
1918
"timm >= 0.9.16",
@@ -27,7 +26,8 @@ dependencies = [
2726
"spatialdata >= 0.1.2",
2827
"dask >= 2024.2.1",
2928
"spatial_image >= 0.3.0",
30-
"datasets"
29+
"datasets",
30+
"mygene"
3131
]
3232

3333
requires-python = ">=3.9"

src/hest/HESTData.py

Lines changed: 342 additions & 180 deletions
Large diffs are not rendered by default.

src/hest/LazyShapes.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import geopandas as gpd
2+
import pandas as pd
3+
from shapely import Polygon
4+
5+
from hest.io.seg_readers import read_gdf
6+
from hest.utils import verify_paths
7+
8+
9+
class LazyShapes:
10+
11+
path: str = None
12+
13+
def __init__(self, path: str, name: str, coordinate_system: str):
14+
verify_paths([path])
15+
self.path = path
16+
self.name = name
17+
self.coordinate_system = coordinate_system
18+
self._shapes = None
19+
20+
def compute(self) -> None:
21+
if self._shapes is None:
22+
self._shapes = read_gdf(self.path)
23+
24+
@property
25+
def shapes(self) -> gpd.GeoDataFrame:
26+
if self._shapes is None:
27+
self.compute()
28+
29+
return self._shapes
30+
31+
def __repr__(self) -> str:
32+
sup_rep = super().__repr__()
33+
34+
loaded_rep = 'loaded' if self._shapes is not None else 'not loaded'
35+
36+
rep = f"""name: {self.name}, coord-system: {self.coordinate_system}, <{loaded_rep}>"""
37+
return rep
38+
39+
40+
def convert_old_to_gpd(contours_holes, contours_tissue) -> gpd.GeoDataFrame:
41+
assert len(contours_holes) == len(contours_tissue)
42+
43+
shapes = []
44+
tissue_ids = []
45+
types = []
46+
for i in range(len(contours_holes)):
47+
tissue = contours_tissue[i]
48+
shapes.append(Polygon(tissue[:, 0, :]))
49+
tissue_ids.append(i)
50+
types.append('tissue')
51+
holes = contours_holes[i]
52+
if len(holes) > 0:
53+
for hole in holes:
54+
shapes.append(Polygon(hole[:, 0, :]))
55+
tissue_ids.append(i)
56+
types.append('hole')
57+
58+
df = pd.DataFrame(tissue_ids, columns=['tissue_id'])
59+
df['hole'] = types
60+
df['hole'] = df['hole'] == 'hole'
61+
62+
return gpd.GeoDataFrame(df, geometry=shapes)
63+

src/hest/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .autoalign import autoalign_visium
55
from .readers import *
66
from .HESTData import HESTData, read_HESTData, load_hest
7+
from .segmentation.cell_segmenters import segment_cellvit
78

89
__all__ = [
910
'tiff_save',
@@ -18,5 +19,6 @@
1819
'STReader',
1920
'autoalign_visium',
2021
'write_10X_h5',
21-
'HESTData'
22+
'HESTData',
23+
'segment_cellvit'
2224
]

src/hest/autoalign.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
import cv2
66
import matplotlib.collections as mc
77
import matplotlib.patches as patches
8-
import matplotlib.pyplot as plt
98
import numpy as np
10-
from kwimage.im_cv2 import imresize
11-
from ultralytics import YOLO
129

1310
from hest.utils import get_path_relative
1411

@@ -89,7 +86,7 @@ def _spots_to_file(path, dict):
8986
def _resize_to_target(img):
9087
TARGET_PIXEL_EDGE = 1000
9188
downscale_factor = TARGET_PIXEL_EDGE / np.max(img.shape)
92-
downscaled_fullres = imresize(img, downscale_factor)
89+
downscaled_fullres = cv2.resize(img, (round(img.shape[1] * downscale_factor), round(img.shape[0] * downscale_factor)))
9390
return downscaled_fullres, downscale_factor
9491

9592

@@ -101,6 +98,9 @@ def _alignment_plot_to_file(boxes_to_match,
10198
aligned_fiducials,
10299
img,
103100
save_path):
101+
102+
import matplotlib.pyplot as plt
103+
104104
fig, ax = plt.subplots()
105105

106106
i = 0
@@ -198,6 +198,8 @@ def autoalign_visium(fullres_img: np.ndarray, save_dir: str=None, name='') -> Di
198198
Returns:
199199
Dict: spot alignment as a dictionary
200200
"""
201+
from ultralytics import YOLO
202+
201203
path_model = get_path_relative(__file__, '../../models/visium_yolov8_v1.pt')
202204
model = YOLO(path_model)
203205

src/hest/io/seg_readers.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import json
2+
import warnings
3+
from abc import abstractmethod
4+
5+
import geopandas as gpd
6+
import numpy as np
7+
import pandas as pd
8+
from matplotlib import pyplot as plt
9+
from shapely.geometry.polygon import Point, Polygon
10+
from tqdm import tqdm
11+
12+
13+
def _process(x, extra_props, index_key, class_name):
14+
from shapely.geometry.polygon import Point, Polygon
15+
16+
geom_type = x['geometry']['type']
17+
if geom_type == 'MultiPoint':
18+
coords = [Point(x['geometry']['coordinates'][i]) for i in range(len(x['geometry']['coordinates']))]
19+
elif geom_type == 'MultiPolygon':
20+
coords = [Polygon(x['geometry']['coordinates'][i][0]) for i in range(len(x['geometry']['coordinates']))]
21+
else:
22+
raise ValueError("Doesn't recognize type {geom_type}, must be either MultiPoint or MultiPolygon")
23+
24+
name = x['properties']['classification']['name']
25+
26+
gdf = gpd.GeoDataFrame(geometry=coords)
27+
28+
class_index = 'class' if not class_name else class_name
29+
gdf[class_index] = [name for _ in range(len(gdf))]
30+
31+
if index_key is not None:
32+
indices = x['properties'][index_key]
33+
values = np.zeros(len(x['geometry']['coordinates']), dtype=bool)
34+
values[indices] = True
35+
gdf[index_key] = values
36+
37+
if extra_props:
38+
extra_props = [k for k in x['properties'].keys() if k not in ['objectType', 'classification']]
39+
for prop in extra_props:
40+
val = x['properties'][prop]
41+
gdf[prop] = [val for _ in range(len(gdf))]
42+
43+
return gdf
44+
45+
46+
def _read_geojson(path, class_name=None, extra_props=False, index_key=None) -> gpd.GeoDataFrame:
47+
with open(path) as f:
48+
ls = json.load(f)
49+
50+
sub_gdfs = []
51+
for x in tqdm(ls):
52+
sub_gdfs.append(_process(x, extra_props, index_key, class_name))
53+
54+
gdf = gpd.GeoDataFrame(pd.concat(sub_gdfs, ignore_index=True))
55+
56+
return gdf
57+
58+
59+
class GDFReader:
60+
@abstractmethod
61+
def read_gdf(self, path) -> gpd.GeoDataFrame:
62+
pass
63+
64+
65+
class XeniumParquetCellReader(GDFReader):
66+
67+
def read_gdf(self, path) -> gpd.GeoDataFrame:
68+
69+
df = pd.read_parquet(path)
70+
71+
df['xy'] = list(zip(df['vertex_x'], df['vertex_y']))
72+
df = df.drop(['vertex_x', 'vertex_y'], axis=1)
73+
74+
df = df.groupby('cell_id').agg({
75+
'xy': Polygon
76+
}).reset_index()
77+
78+
gdf = gpd.GeoDataFrame(df, geometry=df['xy'])
79+
gdf = gdf.drop(['xy'], axis=1)
80+
return gdf
81+
82+
83+
class GDFParquetCellReader(GDFReader):
84+
85+
def read_gdf(self, path) -> gpd.GeoDataFrame:
86+
return gpd.read_parquet(path)
87+
88+
89+
class GeojsonCellReader(GDFReader):
90+
91+
def read_gdf(self, path) -> gpd.GeoDataFrame:
92+
gdf = _read_geojson(path)
93+
gdf['cell_id'] = np.arange(len(gdf))
94+
95+
return gdf
96+
97+
98+
class TissueContourReader(GDFReader):
99+
100+
def read_gdf(self, path) -> gpd.GeoDataFrame:
101+
102+
gdf = _read_geojson(path, class_name='tissue_id', index_key='hole')
103+
104+
return gdf
105+
106+
107+
def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_prop=False, uniform_prop=True, index_key: str=None) -> None:
108+
109+
if isinstance(gdf.geometry.iloc[0], Point):
110+
geometry = 'MultiPoint'
111+
elif isinstance(gdf.geometry.iloc[0], Polygon):
112+
geometry = 'MultiPolygon'
113+
else:
114+
raise ValueError(f"gdf.geometry[0] must be of type Point or Polygon, got {type(gdf.geometry.iloc[0])}")
115+
116+
groups = np.unique(gdf[category_key])
117+
colors = generate_colors(groups)
118+
cells = []
119+
for group in tqdm(groups):
120+
121+
slice = gdf[gdf[category_key] == group]
122+
shapes = slice.geometry
123+
124+
properties = {
125+
"objectType": "annotation",
126+
"classification": {
127+
"name": str(group),
128+
"color": colors[group]
129+
}
130+
}
131+
132+
if extra_prop:
133+
props = {}
134+
col_exclude = [category_key, 'geometry']
135+
if index_key is not None:
136+
col_exclude.append(index_key)
137+
for col in [c for c in gdf.columns if c not in col_exclude]:
138+
if uniform_prop:
139+
unique = np.unique(slice[col])
140+
if len(unique) != 1:
141+
warnings.warn(f"extra property {col} is not uniform for group {group}, found {unique}")
142+
props[col] = slice[col].iloc[0]
143+
144+
properties = {**properties, **props}
145+
146+
if index_key is not None:
147+
key = index_key
148+
props = {}
149+
mask = (slice[key] == True).values
150+
props = {key: np.arange(len(mask))[mask].tolist()}
151+
properties = {**properties, **props}
152+
153+
if isinstance(gdf.geometry.iloc[0], Point):
154+
shapes = [[point.x, point.y] for point in shapes]
155+
elif isinstance(gdf.geometry.iloc[0], Polygon):
156+
shapes = [[[[x, y] for x, y in polygon.exterior.coords]] for polygon in shapes]
157+
cell = {
158+
'type': 'Feature',
159+
'id': (str(id(path)) + '-id-' + str(group)).replace('.', '-'),
160+
'geometry': {
161+
'type': geometry,
162+
'coordinates': shapes
163+
},
164+
"properties": properties
165+
}
166+
cells.append(cell)
167+
168+
with open(path, 'w') as f:
169+
json.dump(cells, f, indent=4)
170+
171+
172+
173+
def generate_colors(names):
174+
colors = plt.get_cmap('hsv', len(names))
175+
color_dict = {}
176+
for i in range(len(names)):
177+
rgb = colors(i)[:3]
178+
rgb = [int(255 * c) for c in rgb]
179+
color_dict[names[i]] = rgb
180+
return color_dict
181+
182+
183+
def read_parquet_schema_df(path: str) -> pd.DataFrame:
184+
"""Return a Pandas dataframe corresponding to the schema of a local URI of a parquet file.
185+
186+
The returned dataframe has the columns: column, pa_dtype
187+
"""
188+
import pyarrow.parquet
189+
190+
# Ref: https://stackoverflow.com/a/64288036/
191+
schema = pyarrow.parquet.read_schema(path, memory_map=True)
192+
schema = pd.DataFrame(({"column": name, "pa_dtype": str(pa_dtype)} for name, pa_dtype in zip(schema.names, schema.types)))
193+
schema = schema.reindex(columns=["column", "pa_dtype"], fill_value=pd.NA) # Ensures columns in case the parquet file has an empty dataframe.
194+
return schema
195+
196+
197+
def cell_reader_factory(path) -> GDFReader:
198+
if path.endswith('.geojson'):
199+
return GeojsonCellReader()
200+
elif path.endswith('.parquet'):
201+
schema = read_parquet_schema_df(path)
202+
if 'geometry' in schema['column'].values:
203+
return GDFParquetCellReader()
204+
else:
205+
return XeniumParquetCellReader()
206+
else:
207+
ext = path.split('.')[-1]
208+
raise ValueError(f'Unknown file extension {ext} for a cell segmentation file, needs to be .geojson or .parquet')
209+
210+
211+
def read_gdf(path) -> gpd.GeoDataFrame:
212+
return cell_reader_factory(path).read_gdf(path)

0 commit comments

Comments
 (0)