@@ -150,15 +150,7 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
150
150
pw_k = by_class_dict .copy ()
151
151
rw_k = by_class_dict .copy ()
152
152
153
- by_class_dict = {key : None for key in id_classes }
154
- tp_k = by_class_dict .copy ()
155
- fp_k = by_class_dict .copy ()
156
- fn_k = by_class_dict .copy ()
157
- p_k = by_class_dict .copy ()
158
- r_k = by_class_dict .copy ()
159
- count_k = by_class_dict .copy ()
160
- pw_k = by_class_dict .copy ()
161
- rw_k = by_class_dict .copy ()
153
+ total_labels = len (tp_gdf ) + len (fn_gdf ) + len (mismatch_gdf )
162
154
163
155
for id_cl in id_classes :
164
156
@@ -176,25 +168,27 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
176
168
fn_count = pure_fn_count + mismatched_fn_count
177
169
tp_count = 0 if tp_gdf .empty else len (tp_gdf [tp_gdf .det_class == id_cl ])
178
170
179
- tp_k [id_cl ] = tp_count
180
171
fp_k [id_cl ] = fp_count
181
172
fn_k [id_cl ] = fn_count
182
-
183
- p_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fp_count )
184
- r_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fn_count )
185
- count_k [id_cl ] = 0 if tp_count == 0 else tp_count + fn_count
173
+ tp_k [id_cl ] = tp_count
174
+
175
+ count_k [id_cl ] = tp_count + fn_count
176
+ if tp_count > 0 :
177
+ p_k [id_cl ] = tp_count / (tp_count + fp_count )
178
+ r_k [id_cl ] = tp_count / (tp_count + fn_count )
179
+
180
+ if (method == 'macro-weighted-average' ) & (total_labels > 0 ):
181
+ pw_k [id_cl ] = (count_k [id_cl ] / total_labels ) * p_k [id_cl ]
182
+ rw_k [id_cl ] = (count_k [id_cl ] / total_labels ) * r_k [id_cl ]
186
183
187
184
if method == 'macro-average' :
188
185
precision = sum (p_k .values ()) / len (id_classes )
189
186
recall = sum (r_k .values ()) / len (id_classes )
190
- elif method == 'macro-weighted-average' :
191
- for id_cl in id_classes :
192
- pw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * p_k [id_cl ]
193
- rw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * r_k [id_cl ]
187
+ elif method == 'macro-weighted-average' :
194
188
precision = sum (pw_k .values ()) / len (id_classes )
195
189
recall = sum (rw_k .values ()) / len (id_classes )
196
190
elif method == 'micro-average' :
197
- if sum (tp_k .values ()) == 0 and sum ( fp_k . values ()) == 0 :
191
+ if sum (tp_k .values ()) == 0 :
198
192
precision = 0
199
193
recall = 0
200
194
else :
@@ -224,4 +218,9 @@ def intersection_over_union(polygon1_shape, polygon2_shape):
224
218
polygon_intersection = polygon1_shape .intersection (polygon2_shape ).area
225
219
polygon_union = polygon1_shape .area + polygon2_shape .area - polygon_intersection
226
220
227
- return polygon_intersection / polygon_union
221
+ if polygon_union != 0 :
222
+ iou = polygon_intersection / polygon_union
223
+ else :
224
+ iou = 0
225
+
226
+ return iou
0 commit comments