8
8
from shapely .geometry import LineString , Polygon , box
9
9
from tqdm import tqdm
10
10
11
- from cellseg_models_pytorch .utils .spatial_ops import get_objs
12
-
13
11
__all__ = ["InstMerger" ]
14
12
15
13
@@ -30,12 +28,17 @@ def __init__(
30
28
self .grid = gpd .GeoDataFrame ({"geometry" : polygons })
31
29
self .gdf = gdf
32
30
33
- def merge (self , dst : str = None ) -> Union [gpd .GeoDataFrame , None ]:
31
+ def merge (
32
+ self , dst : str = None , simplify_level : int = 1
33
+ ) -> Union [gpd .GeoDataFrame , None ]:
34
34
"""Merge the instances at the image boundaries.
35
35
36
36
Parameters:
37
37
dst (str):
38
38
The destination directory to save the merged instances.
39
+ If None, the merged GeoDataFrame is returned.
40
+ simplify_level (int, default=1):
41
+ The level of simplification to apply to the merged instances.
39
42
40
43
Returns:
41
44
Union[gpd.GeoDataFrame, None]:
@@ -63,6 +66,7 @@ def merge(self, dst: str = None) -> Union[gpd.GeoDataFrame, None]:
63
66
merged = pd .concat ([merge_obj_x , merge_obj_y , non_boundary_objs ]).reset_index (
64
67
drop = True
65
68
)
69
+ merged .geometry = merged .geometry .simplify (simplify_level )
66
70
67
71
if dst is not None :
68
72
if suff == ".parquet" :
@@ -117,6 +121,21 @@ def _merge_boundary_objs(
117
121
118
122
return merged .reset_index (drop = True )
119
123
124
+ def _get_objs (
125
+ self ,
126
+ objects : gpd .GeoDataFrame ,
127
+ area : gpd .GeoDataFrame ,
128
+ predicate : str ,
129
+ ** kwargs ,
130
+ ) -> gpd .GeoDataFrame :
131
+ """Get the objects that intersect with the midline."""
132
+ inds = objects .geometry .sindex .query (
133
+ area .geometry , predicate = predicate , ** kwargs
134
+ )
135
+ objs : gpd .GeoDataFrame = objects .iloc [np .unique (inds )[1 :]]
136
+
137
+ return objs .drop_duplicates ("geometry" )
138
+
120
139
def _merge_objs_axis (
121
140
self ,
122
141
grid : gpd .GeoDataFrame ,
@@ -148,7 +167,7 @@ def _merge_objs_axis(
148
167
desc = "Merging objects (x-axis)" if axis == "x" else "Merging objects (y-axis)"
149
168
for start , next_col in tqdm (grouped_list [1 :], desc = desc ):
150
169
grid_union = pd .concat ([last_col , next_col ]).union_all ()
151
- objs = get_objs ( self ._union_to_gdf (grid_union ), gdf , predicate )
170
+ objs = self . _get_objs ( gdf , self ._union_to_gdf (grid_union ), predicate )
152
171
153
172
minx , miny , maxx , maxy = next_col .total_bounds
154
173
@@ -164,12 +183,12 @@ def _merge_objs_axis(
164
183
midline_gdf = gpd .GeoDataFrame (geometry = [midline ])
165
184
166
185
# get the cells hitting the midline
167
- boundary_objs = get_objs ( midline_gdf , objs , predicate )
186
+ boundary_objs = self . _get_objs ( objs , midline_gdf , predicate )
168
187
169
188
non_boundary_objs_left = None
170
189
if get_non_boundary_objs :
171
- non_boundary_objs_left = get_objs (
172
- last_col .buffer (- midline_buffer ), objs , "contains"
190
+ non_boundary_objs_left = self . _get_objs (
191
+ objs , last_col .buffer (- midline_buffer ), "contains"
173
192
)
174
193
175
194
# merge the boundary objects
@@ -211,15 +230,11 @@ def _get_classes(
211
230
class_names = []
212
231
for ix , row in merged .iterrows ():
213
232
area = gpd .GeoDataFrame (geometry = [row .geometry ])
214
- objs = get_objs ( area , non_merged , predicate = "intersects" )
233
+ objs = self . _get_objs ( non_merged , area , predicate = "intersects" )
215
234
216
235
if objs .empty :
217
236
continue
218
237
219
- # HACK: for some reason the first object is always the same
220
- if len (objs ) > 1 and ix != 0 :
221
- objs = objs .drop (index = 0 )
222
-
223
238
class_names .append (objs .loc [objs .area .idxmax ()]["class_name" ])
224
239
225
240
return class_names
0 commit comments