@@ -140,7 +140,7 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
140
140
- float: f1 score.
141
141
"""
142
142
143
- by_class_dict = {key : None for key in id_classes }
143
+ by_class_dict = {key : 0 for key in id_classes }
144
144
tp_k = by_class_dict .copy ()
145
145
fp_k = by_class_dict .copy ()
146
146
fn_k = by_class_dict .copy ()
@@ -152,31 +152,36 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
152
152
153
153
for id_cl in id_classes :
154
154
155
- tp_count = 0 if tp_gdf .empty else len (tp_gdf [tp_gdf .det_class == id_cl ])
156
155
pure_fp_count = 0 if fp_gdf .empty else len (fp_gdf [fp_gdf .det_class == id_cl ])
157
156
pure_fn_count = 0 if fn_gdf .empty else len (fn_gdf [fn_gdf .label_class == id_cl + 1 ]) # label class starting at 1 and id class at 0
158
157
159
- mismatched_fp_count = 0 if mismatch_gdf .empty else len (mismatch_gdf [mismatch_gdf .det_class == id_cl ])
160
- mismatched_fn_count = 0 if mismatch_gdf .empty else len (mismatch_gdf [mismatch_gdf .label_class == id_cl + 1 ])
158
+ if mismatch_gdf .empty :
159
+ mismatched_fp_count = 0
160
+ mismatched_fn_count = 0
161
+ else :
162
+ mismatched_fp_count = len (mismatch_gdf [mismatch_gdf .det_class == id_cl ])
163
+ mismatched_fn_count = len (mismatch_gdf [mismatch_gdf .label_class == id_cl + 1 ])
161
164
162
165
fp_count = pure_fp_count + mismatched_fp_count
163
166
fn_count = pure_fn_count + mismatched_fn_count
167
+ tp_count = 0 if tp_gdf .empty else len (tp_gdf [tp_gdf .det_class == id_cl ])
164
168
165
- tp_k [id_cl ] = tp_count
166
169
fp_k [id_cl ] = fp_count
167
170
fn_k [id_cl ] = fn_count
168
-
169
- p_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fp_count )
170
- r_k [id_cl ] = 0 if tp_count == 0 else tp_count / (tp_count + fn_count )
171
- count_k [id_cl ] = 0 if tp_count == 0 else tp_count + fn_count
171
+ tp_k [id_cl ] = tp_count
172
+
173
+ if tp_count > 0 :
174
+ p_k [id_cl ] = tp_count / (tp_count + fp_count )
175
+ r_k [id_cl ] = tp_count / (tp_count + fn_count )
176
+ count_k [id_cl ] = tp_count + fn_count
177
+ if method == 'macro-weighted-average' :
178
+ pw_k [id_cl ] = (count_k [id_cl ] / sum (count_k .values ())) * p_k [id_cl ]
179
+ rw_k [id_cl ] = (count_k [id_cl ] / sum (count_k .values ())) * r_k [id_cl ]
172
180
173
181
if method == 'macro-average' :
174
182
precision = sum (p_k .values ()) / len (id_classes )
175
183
recall = sum (r_k .values ()) / len (id_classes )
176
- elif method == 'macro-weighted-average' :
177
- for id_cl in id_classes :
178
- pw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * p_k [id_cl ]
179
- rw_k [id_cl ] = 0 if sum (count_k .values ()) == 0 else (count_k [id_cl ] / sum (count_k .values ())) * r_k [id_cl ]
184
+ elif method == 'macro-weighted-average' :
180
185
precision = sum (pw_k .values ()) / len (id_classes )
181
186
recall = sum (rw_k .values ()) / len (id_classes )
182
187
elif method == 'micro-average' :
0 commit comments