Skip to content

Commit 9da9a3a

Browse files
authored
Merge pull request #26 from swiss-territorial-data-lab/gs/improve_assessment
- Improve the assessment - Make the validity check more versatile
2 parents 292b1a2 + 8f96a85 commit 9da9a3a

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

helpers/metrics.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ def get_fractional_sets(dets_gdf, labels_gdf, iou_threshold=0.25):
5454
# IoU computation between labels and detections
5555
geom1 = candidates_tp_gdf['geometry'].to_numpy().tolist()
5656
geom2 = candidates_tp_gdf['label_geom'].to_numpy().tolist()
57-
iou = []
58-
for (i, ii) in zip(geom1, geom2):
59-
iou.append(intersection_over_union(i, ii))
60-
candidates_tp_gdf['IOU'] = iou
57+
candidates_tp_gdf['IOU'] = [intersection_over_union(i, ii) for (i, ii) in zip(geom1, geom2)]
6158

6259
# Filter detections based on IoU value
6360
best_matches_gdf = candidates_tp_gdf.groupby(['det_id'], group_keys=False).apply(lambda g:g[g.IOU==g.IOU.max()])
@@ -144,7 +141,7 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0):
144141
mismatched_fp_count = 0
145142
mismatched_fn_count = 0
146143
else:
147-
mismatched_fp_count = len(tp_gdf[tp_gdf.det_class==id_cl])
144+
mismatched_fp_count = len(mismatch_gdf[mismatch_gdf.det_class==id_cl])
148145
mismatched_fn_count = len(mismatch_gdf[mismatch_gdf.label_class==id_cl+1])
149146

150147
fp_count = pure_fp_count + mismatched_fp_count

helpers/misc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,13 @@ def check_validity(poly_gdf, correct=False):
5555
if correct:
5656
print("Correction of the invalid geometries with the shapely function 'make_valid'...")
5757
invalid_poly = poly_gdf.loc[invalid_condition, 'geometry']
58-
poly_gdf.loc[invalid_condition, 'geometry'] = [
59-
make_valid(poly) for poly in invalid_poly
60-
]
58+
try:
59+
poly_gdf.loc[invalid_condition, 'geometry'] = [
60+
make_valid(poly) for poly in invalid_poly
61+
]
62+
except ValueError:
63+
logger.info('Failed to fix geometries with "make_valid", try with a buffer of 0.')
64+
poly_gdf.loc[invalid_condition, 'geometry'] = [poly.buffer(0) for poly in invalid_poly]
6165
else:
6266
sys.exit(1)
6367

scripts/assess_detections.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,14 @@ def main(cfg_file_path):
392392
].sort_values(by=['class', 'dataset']).to_csv(file_to_write, index=False)
393393
written_files.append(file_to_write)
394394

395+
tmp_df = metrics_by_cl_df[['dataset', 'TP_k', 'FP_k', 'FN_k']].groupby(by='dataset', as_index=False).sum()
396+
tmp_df2 = metrics_by_cl_df[['dataset', 'precision_k', 'recall_k']].groupby(by='dataset', as_index=False).mean()
397+
global_metrics_df = tmp_df.merge(tmp_df2, on='dataset')
398+
399+
file_to_write = os.path.join(OUTPUT_DIR, 'global_metrics.csv')
400+
global_metrics_df.to_csv(file_to_write, index=False)
401+
written_files.append(file_to_write)
402+
395403
# Save the confusion matrix
396404
na_value_category = tagged_dets_gdf.CATEGORY.isna()
397405
sorted_classes = tagged_dets_gdf.loc[~na_value_category, 'CATEGORY'].sort_values().unique().tolist() + ['background']

0 commit comments

Comments
 (0)