Skip to content

Commit ac0d174

Browse files
committed
feat: add simplify param to merging
1 parent 344010e commit ac0d174

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

cellseg_models_pytorch/wsi/inst_merger.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from shapely.geometry import LineString, Polygon, box
99
from tqdm import tqdm
1010

11-
from cellseg_models_pytorch.utils.spatial_ops import get_objs
12-
1311
__all__ = ["InstMerger"]
1412

1513

@@ -30,12 +28,17 @@ def __init__(
3028
self.grid = gpd.GeoDataFrame({"geometry": polygons})
3129
self.gdf = gdf
3230

33-
def merge(self, dst: str = None) -> Union[gpd.GeoDataFrame, None]:
31+
def merge(
32+
self, dst: str = None, simplify_level: int = 1
33+
) -> Union[gpd.GeoDataFrame, None]:
3434
"""Merge the instances at the image boundaries.
3535
3636
Parameters:
3737
dst (str):
3838
The destination directory to save the merged instances.
39+
If None, the merged GeoDataFrame is returned.
40+
simplify_level (int, default=1):
41+
The tolerance passed to `GeoSeries.simplify` for the merged geometries. Larger values simplify the geometries more.
3942
4043
Returns:
4144
Union[gpd.GeoDataFrame, None]:
@@ -63,6 +66,7 @@ def merge(self, dst: str = None) -> Union[gpd.GeoDataFrame, None]:
6366
merged = pd.concat([merge_obj_x, merge_obj_y, non_boundary_objs]).reset_index(
6467
drop=True
6568
)
69+
merged.geometry = merged.geometry.simplify(simplify_level)
6670

6771
if dst is not None:
6872
if suff == ".parquet":
@@ -117,6 +121,21 @@ def _merge_boundary_objs(
117121

118122
return merged.reset_index(drop=True)
119123

124+
def _get_objs(
    self,
    objects: gpd.GeoDataFrame,
    area: gpd.GeoDataFrame,
    predicate: str,
    **kwargs,
) -> gpd.GeoDataFrame:
    """Get the objects that satisfy a spatial predicate against an area.

    Parameters:
        objects (gpd.GeoDataFrame):
            The objects to select from. Their spatial index is queried.
        area (gpd.GeoDataFrame):
            The query geometries. Only `area.geometry` is used.
        predicate (str):
            The spatial predicate forwarded to `sindex.query`.
        **kwargs:
            Extra keyword arguments forwarded to `sindex.query`.

    Returns:
        gpd.GeoDataFrame:
            The matching rows of `objects` with duplicate geometries dropped.
    """
    inds = np.asarray(
        objects.geometry.sindex.query(
            area.geometry, predicate=predicate, **kwargs
        )
    )

    # For an array-like query, `sindex.query` returns a (2, n) array: row 0
    # holds positions in `area`, row 1 holds positions in `objects`. Only row 1
    # may be used to index `objects`; mixing in row 0 selects unrelated rows
    # and dropping the smallest value loses object 0 when it really matches.
    obj_inds = inds[1] if inds.ndim == 2 else inds
    objs: gpd.GeoDataFrame = objects.iloc[np.unique(obj_inds)]

    return objs.drop_duplicates("geometry")
138+
120139
def _merge_objs_axis(
121140
self,
122141
grid: gpd.GeoDataFrame,
@@ -148,7 +167,7 @@ def _merge_objs_axis(
148167
desc = "Merging objects (x-axis)" if axis == "x" else "Merging objects (y-axis)"
149168
for start, next_col in tqdm(grouped_list[1:], desc=desc):
150169
grid_union = pd.concat([last_col, next_col]).union_all()
151-
objs = get_objs(self._union_to_gdf(grid_union), gdf, predicate)
170+
objs = self._get_objs(gdf, self._union_to_gdf(grid_union), predicate)
152171

153172
minx, miny, maxx, maxy = next_col.total_bounds
154173

@@ -164,12 +183,12 @@ def _merge_objs_axis(
164183
midline_gdf = gpd.GeoDataFrame(geometry=[midline])
165184

166185
# get the cells hitting the midline
167-
boundary_objs = get_objs(midline_gdf, objs, predicate)
186+
boundary_objs = self._get_objs(objs, midline_gdf, predicate)
168187

169188
non_boundary_objs_left = None
170189
if get_non_boundary_objs:
171-
non_boundary_objs_left = get_objs(
172-
last_col.buffer(-midline_buffer), objs, "contains"
190+
non_boundary_objs_left = self._get_objs(
191+
objs, last_col.buffer(-midline_buffer), "contains"
173192
)
174193

175194
# merge the boundary objects
@@ -211,15 +230,11 @@ def _get_classes(
211230
class_names = []
212231
for ix, row in merged.iterrows():
213232
area = gpd.GeoDataFrame(geometry=[row.geometry])
214-
objs = get_objs(area, non_merged, predicate="intersects")
233+
objs = self._get_objs(non_merged, area, predicate="intersects")
215234

216235
if objs.empty:
217236
continue
218237

219-
# HACK: for some reason the first object is always the same
220-
if len(objs) > 1 and ix != 0:
221-
objs = objs.drop(index=0)
222-
223238
class_names.append(objs.loc[objs.area.idxmax()]["class_name"])
224239

225240
return class_names

0 commit comments

Comments
 (0)