1
import json
import warnings
from abc import ABC, abstractmethod

import geopandas as gpd
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from shapely.geometry.polygon import Point, Polygon
from tqdm import tqdm
11
+
12
+
13
+ def _process (x , extra_props , index_key , class_name ):
14
+ from shapely .geometry .polygon import Point , Polygon
15
+
16
+ geom_type = x ['geometry' ]['type' ]
17
+ if geom_type == 'MultiPoint' :
18
+ coords = [Point (x ['geometry' ]['coordinates' ][i ]) for i in range (len (x ['geometry' ]['coordinates' ]))]
19
+ elif geom_type == 'MultiPolygon' :
20
+ coords = [Polygon (x ['geometry' ]['coordinates' ][i ][0 ]) for i in range (len (x ['geometry' ]['coordinates' ]))]
21
+ else :
22
+ raise ValueError ("Doesn't recognize type {geom_type}, must be either MultiPoint or MultiPolygon" )
23
+
24
+ name = x ['properties' ]['classification' ]['name' ]
25
+
26
+ gdf = gpd .GeoDataFrame (geometry = coords )
27
+
28
+ class_index = 'class' if not class_name else class_name
29
+ gdf [class_index ] = [name for _ in range (len (gdf ))]
30
+
31
+ if index_key is not None :
32
+ indices = x ['properties' ][index_key ]
33
+ values = np .zeros (len (x ['geometry' ]['coordinates' ]), dtype = bool )
34
+ values [indices ] = True
35
+ gdf [index_key ] = values
36
+
37
+ if extra_props :
38
+ extra_props = [k for k in x ['properties' ].keys () if k not in ['objectType' , 'classification' ]]
39
+ for prop in extra_props :
40
+ val = x ['properties' ][prop ]
41
+ gdf [prop ] = [val for _ in range (len (gdf ))]
42
+
43
+ return gdf
44
+
45
+
46
def _read_geojson(path, class_name=None, extra_props=False, index_key=None) -> gpd.GeoDataFrame:
    """Read a GeoJSON file into a single GeoDataFrame.

    Each feature is converted with ``_process`` and the per-feature frames are
    concatenated with a fresh integer index.

    Args:
        path: Location of the JSON file. Its content may be a list of
            features, a ``FeatureCollection`` object, or a single feature
            object.
        class_name: Forwarded to ``_process`` (name of the classification column).
        extra_props: Forwarded to ``_process`` (copy remaining properties to columns).
        index_key: Forwarded to ``_process`` (property expanded to a boolean column).

    Returns:
        The concatenated ``gpd.GeoDataFrame``.

    Raises:
        ValueError: If the file contains no features.
    """
    # GeoJSON is UTF-8 encoded; do not depend on the platform's locale encoding.
    with open(path, encoding='utf-8') as f:
        features = json.load(f)

    if isinstance(features, dict):
        # Iterating a dict would yield its keys, so unwrap object-style documents.
        if features.get('type') == 'FeatureCollection':
            features = features['features']
        else:
            features = [features]

    if not features:
        raise ValueError(f"No features found in {path}")

    sub_gdfs = [
        _process(feature, extra_props, index_key, class_name)
        for feature in tqdm(features)
    ]

    return gpd.GeoDataFrame(pd.concat(sub_gdfs, ignore_index=True))
57
+
58
+
59
class GDFReader(ABC):
    """Interface for objects that load a file into a ``gpd.GeoDataFrame``.

    Deriving from ``ABC`` is what makes ``@abstractmethod`` effective: the
    class cannot be instantiated until ``read_gdf`` is overridden.
    """

    @abstractmethod
    def read_gdf(self, path) -> gpd.GeoDataFrame:
        """Read the file at ``path`` and return it as a ``gpd.GeoDataFrame``."""
63
+
64
+
65
class XeniumParquetCellReader(GDFReader):
    """Reader for parquet files that list polygon vertices row by row.

    The file is expected to provide ``cell_id``, ``vertex_x`` and ``vertex_y``
    columns; those are the only columns this reader accesses.
    """

    def read_gdf(self, path) -> gpd.GeoDataFrame:
        """Build one polygon per ``cell_id`` from the vertex rows in ``path``."""
        vertices = pd.read_parquet(path)

        # Pair up the coordinates so each row carries a single (x, y) tuple.
        vertices['xy'] = list(zip(vertices['vertex_x'], vertices['vertex_y']))
        vertices = vertices.drop(columns=['vertex_x', 'vertex_y'])

        # Collapse the vertex tuples of each cell into one Polygon.
        cells = vertices.groupby('cell_id').agg({'xy': Polygon}).reset_index()

        # The polygons become the geometry; the helper column is no longer needed.
        return gpd.GeoDataFrame(cells, geometry=cells['xy']).drop(columns=['xy'])
81
+
82
+
83
class GDFParquetCellReader(GDFReader):
    """Reader for parquet files that geopandas can load directly."""

    def read_gdf(self, path) -> gpd.GeoDataFrame:
        """Load ``path`` with ``gpd.read_parquet`` and return the result."""
        loaded = gpd.read_parquet(path)
        return loaded
87
+
88
+
89
class GeojsonCellReader(GDFReader):
    """Reader for cell segmentations stored as GeoJSON."""

    def read_gdf(self, path) -> gpd.GeoDataFrame:
        """Read ``path`` with ``_read_geojson`` and number the rows in ``cell_id``."""
        cells = _read_geojson(path)
        # Sequential ids 0..n-1, one per row of the frame.
        n_cells = len(cells)
        cells['cell_id'] = np.arange(n_cells)
        return cells
96
+
97
+
98
class TissueContourReader(GDFReader):
    """Reader for tissue contours stored as GeoJSON."""

    def read_gdf(self, path) -> gpd.GeoDataFrame:
        """Read ``path``, naming the class column ``tissue_id`` and expanding ``hole``.

        The ``hole`` property of each feature is turned into a boolean column
        by ``_read_geojson``.
        """
        return _read_geojson(path, class_name='tissue_id', index_key='hole')
105
+
106
+
107
def write_geojson(gdf: gpd.GeoDataFrame, path: str, category_key: str, extra_prop=False, uniform_prop=True, index_key: str = None) -> None:
    """Write ``gdf`` to ``path`` as a JSON list of features, one feature per category.

    Rows are grouped by the distinct values of ``gdf[category_key]``. Each
    group becomes one ``MultiPoint`` or ``MultiPolygon`` feature (decided by
    the type of the first geometry) classified with the group name and a color
    from ``generate_colors``.

    Args:
        gdf: Geometries to export; must contain at least one row.
        path: Destination file.
        category_key: Column whose distinct values define the features.
        extra_prop: When truthy, every column other than ``category_key``,
            ``'geometry'`` and ``index_key`` is written as a feature property,
            using the value found on the first row of the group.
        uniform_prop: When truthy, warn if such a column holds more than one
            distinct value within a group.
        index_key: Optional column; the positions (within the group) of rows
            where it equals ``True`` are written as a list property under the
            same name.

    Raises:
        ValueError: If ``gdf`` is empty or its first geometry is neither a
            ``Point`` nor a ``Polygon``.
    """
    if len(gdf) == 0:
        raise ValueError("gdf must contain at least one geometry to infer the geometry type")

    first_geom = gdf.geometry.iloc[0]
    if isinstance(first_geom, Point):
        geometry = 'MultiPoint'
    elif isinstance(first_geom, Polygon):
        geometry = 'MultiPolygon'
    else:
        raise ValueError(f"gdf.geometry[0] must be of type Point or Polygon, got {type(first_geom)}")

    groups = np.unique(gdf[category_key])
    colors = generate_colors(groups)

    # The set of extra columns does not depend on the group, so compute it once.
    extra_cols = []
    if extra_prop:
        excluded = {category_key, 'geometry'}
        if index_key is not None:
            excluded.add(index_key)
        extra_cols = [c for c in gdf.columns if c not in excluded]

    cells = []
    for group in tqdm(groups):
        group_rows = gdf[gdf[category_key] == group]

        properties = {
            "objectType": "annotation",
            "classification": {
                "name": str(group),
                "color": colors[group],
            },
        }

        for col in extra_cols:
            if uniform_prop:
                unique = np.unique(group_rows[col])
                if len(unique) != 1:
                    warnings.warn(f"extra property {col} is not uniform for group {group}, found {unique}")
            # NOTE(review): the first value is written whether or not
            # uniform_prop is set - confirm this matches the intended behavior.
            value = group_rows[col].iloc[0]
            # pandas returns NumPy scalars (e.g. np.int64) that json cannot
            # serialize; convert them to the equivalent Python scalar.
            properties[col] = value.item() if isinstance(value, np.generic) else value

        if index_key is not None:
            mask = (group_rows[index_key] == True).values  # noqa: E712 - element-wise comparison
            properties[index_key] = np.arange(len(mask))[mask].tolist()

        if geometry == 'MultiPoint':
            shapes = [[point.x, point.y] for point in group_rows.geometry]
        else:
            # One single-ring polygon per geometry: only the exterior is exported.
            shapes = [[[[x, y] for x, y in polygon.exterior.coords]] for polygon in group_rows.geometry]

        cells.append({
            'type': 'Feature',
            # NOTE(review): id(path) is the identity of the path object, so this
            # id differs between runs - confirm that is acceptable.
            'id': (str(id(path)) + '-id-' + str(group)).replace('.', '-'),
            'geometry': {
                'type': geometry,
                'coordinates': shapes,
            },
            "properties": properties,
        })

    with open(path, 'w', encoding='utf-8') as f:
        json.dump(cells, f, indent=4)
170
+
171
+
172
+
173
def generate_colors(names):
    """Assign each entry of ``names`` a color sampled from the ``'hsv'`` colormap.

    The colormap is requested with ``len(names)`` entries and entry ``i`` is
    given to ``names[i]``.

    Returns:
        A dict mapping each name to a list of three ints obtained by scaling
        the first three channels of the colormap entry by 255.
    """
    palette = plt.get_cmap('hsv', len(names))
    return {
        names[idx]: [int(255 * channel) for channel in palette(idx)[:3]]
        for idx in range(len(names))
    }
181
+
182
+
183
def read_parquet_schema_df(path: str) -> pd.DataFrame:
    """Describe the schema of a local parquet file as a Pandas dataframe.

    Only the schema is read. The result has one row per field and two
    columns: ``column`` (field name) and ``pa_dtype`` (string form of the
    pyarrow type).
    """
    import pyarrow.parquet

    # Ref: https://stackoverflow.com/a/64288036/
    pa_schema = pyarrow.parquet.read_schema(path, memory_map=True)
    records = [
        {"column": field_name, "pa_dtype": str(field_type)}
        for field_name, field_type in zip(pa_schema.names, pa_schema.types)
    ]
    # Reindexing guarantees both columns exist even when the schema has no fields.
    return pd.DataFrame(records).reindex(columns=["column", "pa_dtype"], fill_value=pd.NA)
195
+
196
+
197
def cell_reader_factory(path) -> GDFReader:
    """Pick the reader that matches a cell segmentation file.

    Args:
        path: File location as a string or path-like object. The extension is
            matched case-insensitively.

    Returns:
        ``GeojsonCellReader`` for ``.geojson`` files. For ``.parquet`` files,
        ``GDFParquetCellReader`` when the parquet schema has a ``geometry``
        column and ``XeniumParquetCellReader`` otherwise.

    Raises:
        ValueError: If the extension is neither ``.geojson`` nor ``.parquet``.
    """
    # Normalize once so pathlib.Path inputs work and '.GeoJSON' is recognized.
    path_str = str(path)
    lowered = path_str.lower()

    if lowered.endswith('.geojson'):
        return GeojsonCellReader()

    if lowered.endswith('.parquet'):
        schema = read_parquet_schema_df(path_str)
        if 'geometry' in schema['column'].values:
            return GDFParquetCellReader()
        return XeniumParquetCellReader()

    ext = path_str.split('.')[-1]
    raise ValueError(f'Unknown file extension {ext} for a cell segmentation file, needs to be .geojson or .parquet')
209
+
210
+
211
def read_gdf(path) -> gpd.GeoDataFrame:
    """Read ``path`` with the reader that ``cell_reader_factory`` selects for it."""
    reader = cell_reader_factory(path)
    return reader.read_gdf(path)
0 commit comments