@@ -150,6 +150,16 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
150
150
pw_k = by_class_dict .copy ()
151
151
rw_k = by_class_dict .copy ()
152
152
153
+ by_class_dict = {key : None for key in id_classes }
154
+ tp_k = by_class_dict .copy ()
155
+ fp_k = by_class_dict .copy ()
156
+ fn_k = by_class_dict .copy ()
157
+ p_k = by_class_dict .copy ()
158
+ r_k = by_class_dict .copy ()
159
+ count_k = by_class_dict .copy ()
160
+ pw_k = by_class_dict .copy ()
161
+ rw_k = by_class_dict .copy ()
162
+
153
163
for id_cl in id_classes :
154
164
155
165
pure_fp_count = 0 if fp_gdf .empty else len (fp_gdf [fp_gdf .det_class == id_cl ])
@@ -166,22 +176,21 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
166
176
fn_count = pure_fn_count + mismatched_fn_count
167
177
tp_count = 0 if tp_gdf .empty else len (tp_gdf [tp_gdf .det_class == id_cl ])
168
178
179
+ tp_k [id_cl ] = tp_count
169
180
fp_k [id_cl ] = fp_count
170
181
fn_k [id_cl ] = fn_count
171
- tp_k [id_cl ] = tp_count
172
-
173
- if tp_count > 0 :
174
- p_k [id_cl ] = tp_count / (tp_count + fp_count )
175
- r_k [id_cl ] = tp_count / (tp_count + fn_count )
176
- count_k [id_cl ] = tp_count + fn_count
177
- if method == 'macro-weighted-average' :
178
- pw_k [id_cl ] = (count_k [id_cl ] / sum (count_k .values ())) * p_k [id_cl ]
179
- rw_k [id_cl ] = (count_k [id_cl ] / sum (count_k .values ())) * r_k [id_cl ]
182
+
183
+ p_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fp_count )
184
+ r_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fn_count )
185
+ count_k [id_cl ] = 0 if tp_count == 0 else tp_count + fn_count
180
186
181
187
if method == 'macro-average' :
182
188
precision = sum (p_k .values ()) / len (id_classes )
183
189
recall = sum (r_k .values ()) / len (id_classes )
184
- elif method == 'macro-weighted-average' :
190
+ elif method == 'macro-weighted-average' :
191
+ for id_cl in id_classes :
192
+ pw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * p_k [id_cl ]
193
+ rw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * r_k [id_cl ]
185
194
precision = sum (pw_k .values ()) / len (id_classes )
186
195
recall = sum (rw_k .values ()) / len (id_classes )
187
196
elif method == 'micro-average' :
0 commit comments