Skip to content

Commit 7069d4e

Browse files
committed
Merge branch 'master' into init_weights
2 parents 8b0c1e2 + 298492f commit 7069d4e

File tree

9 files changed

+124
-107
lines changed

9 files changed

+124
-107
lines changed

examples/mineral-extract-sites-detection/config_trne.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ assess_detections.py:
9292
tst: tst_detections_at_0dot05_threshold.gpkg
9393
output_folder: .
9494
iou_threshold: 0.1
95-
area_threshold: 50 # area under which the polygons are discarded from assessment
95+
area_threshold: 50 # area under which the polygons are discarded from assessment
96+
metrics_method: micro-average # how per-class metrics are aggregated - 1: macro-average ; 2: micro-average ; 3: macro-weighted-average

examples/mineral-extract-sites-detection/prepare_data.py

Lines changed: 41 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,14 @@ def bbox(bounds):
171171
labels_4326_gdf = labels_gdf.to_crs(epsg=4326).drop_duplicates(subset=['geometry', 'year'])
172172
else:
173173
labels_4326_gdf = labels_gdf.to_crs(epsg=4326).drop_duplicates(subset=['geometry'])
174+
nb_labels = len(labels_gdf)
175+
logger.info(f'There are {nb_labels} polygons in {SHPFILE}')
176+
174177
labels_4326_gdf['CATEGORY'] = 'quarry'
175178
labels_4326_gdf['SUPERCATEGORY'] = 'land usage'
179+
176180
gt_labels_4326_gdf = labels_4326_gdf.copy()
177181

178-
nb_labels = len(labels_gdf)
179-
logger.info(f'There are {nb_labels} polygons in {SHPFILE}')
180-
181182
label_filename = 'labels.geojson'
182183
label_filepath = os.path.join(OUTPUT_DIR, label_filename)
183184
labels_4326_gdf.to_file(label_filepath, driver='GeoJSON')
@@ -206,14 +207,14 @@ def bbox(bounds):
206207
logger.success(f"{DONE_MSG} A file was written: {filepath}")
207208
labels_4326_gdf = pd.concat([labels_4326_gdf, fp_labels_4326_gdf], ignore_index=True)
208209

209-
# Get the label boundaries (minx, miny, maxx, maxy)
210+
# Tiling of the AoI
210211
logger.info("- Get the label boundaries")
211212
boundaries_df = labels_4326_gdf.bounds
213+
logger.info("- Tiling of the AoI")
214+
tiles_4326_aoi_gdf = aoi_tiling(boundaries_df)
215+
tiles_4326_labels_gdf = gpd.sjoin(tiles_4326_aoi_gdf, labels_4326_gdf, how='inner', predicate='intersects')
212216

213-
# Get the global boundaries for all the labels (minx, miny, maxx, maxy)
214-
labels_bbox = bbox(labels_4326_gdf.iloc[0].geometry.bounds)
215-
216-
# Get tiles for a given AoI from which empty tiles will be selected
217+
# Tiling of the AoI from which empty tiles will be selected
217218
if EPT_SHPFILE:
218219
EPT_aoi_gdf = gpd.read_file(EPT_SHPFILE)
219220
EPT_aoi_4326_gdf = EPT_aoi_gdf.to_crs(epsg=4326)
@@ -223,80 +224,55 @@ def bbox(bounds):
223224
logger.info("- Get AoI boundaries")
224225
EPT_aoi_boundaries_df = EPT_aoi_4326_gdf.bounds
225226

226-
# Get the boundaries for all the AoI (minx, miny, maxx, maxy)
227-
aoi_bbox = bbox(EPT_aoi_4326_gdf.iloc[0].geometry.bounds)
228-
aoi_bbox_contains = aoi_bbox.contains(labels_bbox)
229-
230-
if aoi_bbox_contains:
231-
logger.info("- The surface area occupied by the bbox of the AoI used to find empty tiles is bigger than the label's one. The AoI boundaries will be used for tiling")
232-
boundaries_df = EPT_aoi_boundaries_df.copy()
233-
else:
234-
logger.info("- The surface area occupied by the bbox of the AoI used to find empty tiles is smaller than the label's one. Both the AoI and labels area will be used for tiling")
235-
# Get tiles coordinates and shapes
236-
empty_tiles_4326_all_gdf = aoi_tiling(EPT_aoi_boundaries_df)
237-
# Delete tiles outside of the AoI limits
238-
empty_tiles_4326_aoi_gdf = gpd.sjoin(empty_tiles_4326_all_gdf, EPT_aoi_4326_gdf, how='inner', lsuffix='ept_tiles', rsuffix='ept_aoi')
239-
# Attribute a year to empty tiles if necessary
240-
if 'year' in labels_4326_gdf.keys():
241-
if isinstance(EPT_YEAR, int):
242-
empty_tiles_4326_aoi_gdf['year'] = int(EPT_YEAR)
243-
else:
244-
empty_tiles_4326_aoi_gdf['year'] = np.random.randint(low=EPT_YEAR[0], high=EPT_YEAR[1], size=(len(empty_tiles_4326_aoi_gdf)))
245-
elif EPT_SHPFILE and EPT_YEAR:
246-
logger.warning("No year column in the label shapefile. The provided empty tile year will be ignored.")
227+
# Get tile coordinates and shapes
228+
logger.info("- Tiling of the empty tiles AoI")
229+
empty_tiles_4326_all_gdf = aoi_tiling(EPT_aoi_boundaries_df)
230+
# Delete tiles outside of the AoI limits
231+
empty_tiles_4326_aoi_gdf = gpd.sjoin(empty_tiles_4326_all_gdf, EPT_aoi_4326_gdf, how='inner', lsuffix='ept_tiles', rsuffix='ept_aoi')
232+
# Attribute a year to empty tiles if necessary
233+
if 'year' in labels_4326_gdf.keys():
234+
if isinstance(EPT_YEAR, int):
235+
empty_tiles_4326_aoi_gdf['year'] = int(EPT_YEAR)
236+
else:
237+
empty_tiles_4326_aoi_gdf['year'] = np.random.randint(low=EPT_YEAR[0], high=EPT_YEAR[1], size=(len(empty_tiles_4326_aoi_gdf)))
238+
elif EPT_SHPFILE and EPT_YEAR:
239+
logger.warning("No year column in the label shapefile. The provided empty tile year will be ignored.")
247240
elif EPT == 'shp':
248241
if EPT_YEAR:
249242
logger.warning("A shapefile of selected empty tiles are provided. The year set for the empty tiles in the configuration file will be ignored")
250243
EPT_YEAR = None
251244
empty_tiles_4326_aoi_gdf = EPT_aoi_4326_gdf.copy()
252-
aoi_bbox = None
253-
aoi_bbox_contains = False
254-
255-
logger.info("Creating tiles for the Area of Interest (AoI)...")
256-
257-
# Get tiles coordinates and shapes
258-
tiles_4326_aoi_gdf = aoi_tiling(boundaries_df)
259-
260-
# Compute labels intersecting tiles
261-
tiles_4326_gt_gdf = gpd.sjoin(tiles_4326_aoi_gdf, gt_labels_4326_gdf, how='inner', predicate='intersects')
262-
tiles_4326_gt_gdf.drop_duplicates('title', inplace=True)
263-
logger.info(f"- Number of tiles intersecting GT labels = {len(tiles_4326_gt_gdf)}")
264-
265-
if FP_SHPFILE:
266-
tiles_4326_fp_gdf = gpd.sjoin(tiles_4326_aoi_gdf, fp_labels_4326_gdf, how='inner', predicate='intersects')
267-
tiles_4326_fp_gdf.drop_duplicates('title', inplace=True)
268-
logger.info(f"- Number of tiles intersecting FP labels = {len(tiles_4326_fp_gdf)}")
269-
270-
if not EPT_SHPFILE or EPT_SHPFILE and aoi_bbox_contains == False:
271-
# Keep only tiles intersecting labels
272-
if FP_SHPFILE:
273-
tiles_4326_aoi_gdf = pd.concat([tiles_4326_gt_gdf, tiles_4326_fp_gdf])
274-
else:
275-
tiles_4326_aoi_gdf = tiles_4326_gt_gdf.copy()
276245

277246
# Get all the tiles in one gdf
278-
if EPT_SHPFILE and aoi_bbox_contains == False:
247+
if EPT_SHPFILE:
279248
logger.info("- Concatenate label tiles and empty AoI tiles")
280-
tiles_4326_all_gdf = pd.concat([tiles_4326_aoi_gdf, empty_tiles_4326_aoi_gdf])
249+
tiles_4326_all_gdf = pd.concat([tiles_4326_labels_gdf, empty_tiles_4326_aoi_gdf])
281250
else:
282-
tiles_4326_all_gdf = tiles_4326_aoi_gdf.copy()
283-
284-
# - Remove duplicated tiles
285-
if nb_labels > 1:
286-
if 'year' in tiles_4326_all_gdf.keys():
287-
tiles_4326_all_gdf['year'] = tiles_4326_all_gdf.year.astype(int)
288-
tiles_4326_all_gdf.drop_duplicates(['title', 'year'], inplace=True)
289-
else:
290-
tiles_4326_all_gdf.drop_duplicates(['title'], inplace=True)
251+
tiles_4326_all_gdf = tiles_4326_labels_gdf.copy()
291252

292253
# - Remove useless columns, reset feature id and redefine it according to xyz format
293-
logger.info('- Add tile IDs and reorganise data set')
254+
logger.info('- Add tile IDs and reorganise the data set')
294255
tiles_4326_all_gdf = tiles_4326_all_gdf[['geometry', 'title', 'year'] if 'year' in tiles_4326_all_gdf.keys() else ['geometry', 'title']].copy()
295256
tiles_4326_all_gdf.reset_index(drop=True, inplace=True)
296257
tiles_4326_all_gdf = tiles_4326_all_gdf.apply(add_tile_id, axis=1)
258+
259+
# - Remove duplicated tiles
260+
if nb_labels > 1:
261+
tiles_4326_all_gdf.drop_duplicates(['id'], inplace=True)
262+
297263
nb_tiles = len(tiles_4326_all_gdf)
298264
logger.info(f"There were {nb_tiles} tiles created")
299265

266+
# Get the number of tiles intersecting labels
267+
tiles_4326_gt_gdf = gpd.sjoin(tiles_4326_all_gdf, gt_labels_4326_gdf, how='inner', predicate='intersects')
268+
tiles_4326_gt_gdf.drop_duplicates(['id'], inplace=True)
269+
logger.info(f"- Number of tiles intersecting GT labels = {len(tiles_4326_gt_gdf)}")
270+
271+
if FP_SHPFILE:
272+
tiles_4326_fp_gdf = gpd.sjoin(tiles_4326_all_gdf, fp_labels_4326_gdf, how='inner', predicate='intersects')
273+
tiles_4326_fp_gdf.drop_duplicates(['id'], inplace=True)
274+
logger.info(f"- Number of tiles intersecting FP labels = {len(tiles_4326_fp_gdf)}")
275+
300276
# Save tile shapefile
301277
logger.info("Export tiles to GeoJSON (EPSG:4326)...")
302278
tile_filename = 'tiles.geojson'

examples/road-surface-classification/config_rs.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ assess_detections.py:
9292
tst: tst_detections_at_0dot05_threshold.gpkg
9393
oth: oth_detections_at_0dot05_threshold.gpkg
9494
output_folder: .
95-
iou_threshold: 0.1
95+
iou_threshold: 0.1
96+
metrics_method: macro-average # how per-class metrics are aggregated - 1: macro-average ; 2: micro-average ; 3: macro-weighted-average

examples/swimming-pool-detection/GE/config_GE.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,4 @@ assess_detections.py:
7979
tst: tst_detections_at_0dot05_threshold.gpkg
8080
oth: oth_detections_at_0dot05_threshold.gpkg
8181
output_folder: .
82+
metrics_method: micro-average # how per-class metrics are aggregated - 1: macro-average ; 2: micro-average ; 3: macro-weighted-average

examples/swimming-pool-detection/NE/config_NE.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,4 @@ assess_detections.py:
8484
tst: tst_detections_at_0dot05_threshold.gpkg
8585
oth: oth_detections_at_0dot05_threshold.gpkg
8686
output_folder: .
87+
metrics_method: micro-average # how per-class metrics are aggregated - 1: macro-average ; 2: micro-average ; 3: macro-weighted-average

helpers/metrics.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_fractional_sets(dets_gdf, labels_gdf, iou_threshold=0.25, area_threshold
120120
return tp_gdf, fp_gdf, fn_gdf, mismatched_classes_gdf, small_poly_gdf
121121

122122

123-
def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0):
123+
def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macro-average'):
124124
"""Determine the metrics based on the TP, FP and FN
125125
126126
Args:
@@ -129,6 +129,7 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0):
129129
fn_gdf (geodataframe): false negative labels
130130
mismatch_gdf (geodataframe): labels and detections intersecting with a mismatched class id
131131
id_classes (list): list of the possible class ids. Defaults to 0.
132+
method (str): method used to compute multi-class metrics, one of 'macro-average', 'macro-weighted-average' or 'micro-average'. Defaults to 'macro-average'.
132133
133134
Returns:
134135
tuple:
@@ -139,16 +140,28 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0):
139140
- float: f1 score.
140141
"""
141142

143+
by_class_dict = {key: 0 for key in id_classes}
144+
tp_k = by_class_dict.copy()
145+
fp_k = by_class_dict.copy()
146+
fn_k = by_class_dict.copy()
147+
p_k = by_class_dict.copy()
148+
r_k = by_class_dict.copy()
149+
count_k = by_class_dict.copy()
150+
pw_k = by_class_dict.copy()
151+
rw_k = by_class_dict.copy()
152+
142153
by_class_dict = {key: None for key in id_classes}
143154
tp_k = by_class_dict.copy()
144155
fp_k = by_class_dict.copy()
145156
fn_k = by_class_dict.copy()
146157
p_k = by_class_dict.copy()
147158
r_k = by_class_dict.copy()
159+
count_k = by_class_dict.copy()
160+
pw_k = by_class_dict.copy()
161+
rw_k = by_class_dict.copy()
148162

149163
for id_cl in id_classes:
150164

151-
tp_count = 0 if tp_gdf.empty else len(tp_gdf[tp_gdf.det_class==id_cl])
152165
pure_fp_count = 0 if fp_gdf.empty else len(fp_gdf[fp_gdf.det_class==id_cl])
153166
pure_fn_count = 0 if fn_gdf.empty else len(fn_gdf[fn_gdf.label_class==id_cl+1]) # label class starting at 1 and id class at 0
154167

@@ -161,21 +174,33 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0):
161174

162175
fp_count = pure_fp_count + mismatched_fp_count
163176
fn_count = pure_fn_count + mismatched_fn_count
177+
tp_count = 0 if tp_gdf.empty else len(tp_gdf[tp_gdf.det_class==id_cl])
164178

165179
tp_k[id_cl] = tp_count
166180
fp_k[id_cl] = fp_count
167181
fn_k[id_cl] = fn_count
168182

169-
if tp_count == 0:
170-
p_k[id_cl] = 0
171-
r_k[id_cl] = 0
172-
else:
173-
p_k[id_cl] = tp_count / (tp_count + fp_count)
174-
r_k[id_cl] = tp_count / (tp_count + fn_count)
175-
176-
precision = sum(p_k.values()) / len(id_classes)
177-
recall = sum(r_k.values()) / len(id_classes)
178-
183+
p_k[id_cl] = 0 if tp_count == 0 else tp_count / (tp_count + fp_count)
184+
r_k[id_cl] = 0 if tp_count == 0 else tp_count / (tp_count + fn_count)
185+
count_k[id_cl] = 0 if tp_count == 0 else tp_count + fn_count
186+
187+
if method == 'macro-average':
188+
precision = sum(p_k.values()) / len(id_classes)
189+
recall = sum(r_k.values()) / len(id_classes)
190+
elif method == 'macro-weighted-average':
191+
for id_cl in id_classes:
192+
pw_k[id_cl] = 0 if sum(count_k.values()) == 0 else (count_k[id_cl] / sum(count_k.values())) * p_k[id_cl]
193+
rw_k[id_cl] = 0 if sum(count_k.values()) == 0 else (count_k[id_cl] / sum(count_k.values())) * r_k[id_cl]
194+
precision = sum(pw_k.values()) / len(id_classes)
195+
recall = sum(rw_k.values()) / len(id_classes)
196+
elif method == 'micro-average':
197+
if sum(tp_k.values()) == 0 and sum(fp_k.values()) == 0:
198+
precision = 0
199+
recall = 0
200+
else:
201+
precision = sum(tp_k.values()) / (sum(tp_k.values()) + sum(fp_k.values()))
202+
recall = sum(tp_k.values()) / (sum(tp_k.values()) + sum(fn_k.values()))
203+
179204
if precision==0 and recall==0:
180205
return tp_k, fp_k, fn_k, p_k, r_k, 0, 0, 0
181206

requirements.in

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
GDAL==3.0.4
66
certifi>=2022.12.07
77
future>=0.18.3
8-
geopandas==0.11.1
8+
fiona==1.9.6
9+
geopandas
910
joblib
1011
loguru
1112
morecantile
12-
network
13+
networkx
1314
numpy==1.23.3
1415
oauthlib>=3.2.2
1516
opencv-python
1617
pillow==9.5.0
1718
plotly
19+
protobuf==4.25
1820
pygeohash
1921
pyyaml
2022
rasterio

0 commit comments

Comments
 (0)