Skip to content

Commit 3fe54f2

Browse files
committed
Simplify get_metrics to suit sonarcube v2
1 parent c6091b5 commit 3fe54f2

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

helpers/metrics.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
140140
- float: f1 score.
141141
"""
142142

143-
by_class_dict = {key: None for key in id_classes}
143+
by_class_dict = {key: 0 for key in id_classes}
144144
tp_k = by_class_dict.copy()
145145
fp_k = by_class_dict.copy()
146146
fn_k = by_class_dict.copy()
@@ -152,31 +152,36 @@ def get_metrics(tp_gdf, fp_gdf, fn_gdf, mismatch_gdf, id_classes=0, method='macr
152152

153153
for id_cl in id_classes:
154154

155-
tp_count = 0 if tp_gdf.empty else len(tp_gdf[tp_gdf.det_class==id_cl])
156155
pure_fp_count = 0 if fp_gdf.empty else len(fp_gdf[fp_gdf.det_class==id_cl])
157156
pure_fn_count = 0 if fn_gdf.empty else len(fn_gdf[fn_gdf.label_class==id_cl+1]) # label class starting at 1 and id class at 0
158157

159-
mismatched_fp_count = 0 if mismatch_gdf.empty else len(mismatch_gdf[mismatch_gdf.det_class==id_cl])
160-
mismatched_fn_count = 0 if mismatch_gdf.empty else len(mismatch_gdf[mismatch_gdf.label_class==id_cl+1])
158+
if mismatch_gdf.empty:
159+
mismatched_fp_count = 0
160+
mismatched_fn_count = 0
161+
else:
162+
mismatched_fp_count = len(mismatch_gdf[mismatch_gdf.det_class==id_cl])
163+
mismatched_fn_count = len(mismatch_gdf[mismatch_gdf.label_class==id_cl+1])
161164

162165
fp_count = pure_fp_count + mismatched_fp_count
163166
fn_count = pure_fn_count + mismatched_fn_count
167+
tp_count = 0 if tp_gdf.empty else len(tp_gdf[tp_gdf.det_class==id_cl])
164168

165-
tp_k[id_cl] = tp_count
166169
fp_k[id_cl] = fp_count
167170
fn_k[id_cl] = fn_count
168-
169-
p_k[id_cl] = 0 if tp_count == 0 else tp_count / (tp_count + fp_count)
170-
r_k[id_cl] = 0 if tp_count == 0 else tp_count / (tp_count + fn_count)
171-
count_k[id_cl] = 0 if tp_count == 0 else tp_count + fn_count
171+
tp_k[id_cl] = tp_count
172+
173+
if tp_count > 0:
174+
p_k[id_cl] = tp_count / (tp_count + fp_count)
175+
r_k[id_cl] = tp_count / (tp_count + fn_count)
176+
count_k[id_cl] = tp_count + fn_count
177+
if method == 'macro-weighted-average':
178+
pw_k[id_cl] = (count_k[id_cl] / sum(count_k.values())) * p_k[id_cl]
179+
rw_k[id_cl] = (count_k[id_cl] / sum(count_k.values())) * r_k[id_cl]
172180

173181
if method == 'macro-average':
174182
precision = sum(p_k.values()) / len(id_classes)
175183
recall = sum(r_k.values()) / len(id_classes)
176-
elif method == 'macro-weighted-average':
177-
for id_cl in id_classes:
178-
pw_k[id_cl] = 0 if sum(count_k.values()) == 0 else (count_k[id_cl] / sum(count_k.values())) * p_k[id_cl]
179-
rw_k[id_cl] = 0 if sum(count_k.values()) == 0 else (count_k[id_cl] / sum(count_k.values())) * r_k[id_cl]
184+
elif method == 'macro-weighted-average':
180185
precision = sum(pw_k.values()) / len(id_classes)
181186
recall = sum(rw_k.values()) / len(id_classes)
182187
elif method == 'micro-average':

0 commit comments

Comments
 (0)