Skip to content

Commit d30cc74

Browse files
author
Matt Sokoloff
committed
wip
1 parent 557f56f commit d30cc74

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

labelbox/data/metrics/confusion_matrix/calculation.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,7 @@ def classification_confusion_matrix(ground_truths: List[ClassificationAnnotation
8383
"Classification features must be the same type to compute agreement. "
8484
f"Found `{type(prediction)}` and `{type(ground_truth)}`")
8585

86-
if isinstance(prediction.value, Text):
87-
return text_confusion_matrix(ground_truth.value, prediction.value)
88-
elif isinstance(prediction.value, Radio):
86+
if isinstance(prediction.value, Radio):
8987
return radio_confusion_matrix(ground_truth.value, prediction.value)
9088
elif isinstance(prediction.value, Checklist):
9189
return checklist_confusion_matrix(ground_truth.value, prediction.value)
@@ -185,18 +183,24 @@ def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue:
185183
def radio_confusion_matrix(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue:
186184
"""
187185
Calculates confusion between ground truth and predicted radio values
186+
187+
The way we are calculating confusion matrix metrics:
188+
- TNs aren't defined because we don't know how many other classes exist ... etc
189+
190+
We treat each example as 1 vs all
191+
188192
"""
189193
key = get_identifying_key([prediction.answer], [ground_truth.answer])
194+
prediction_id = getattr(prediction.answer, key)
195+
ground_truth_id = getattr(ground_truth.answer, key)
196+
197+
198+
190199

191200
return float(getattr(prediction.answer, key) ==
192201
getattr(ground_truth.answer, key))
193202

194203

195-
def text_confusion_matrix(ground_truth: Text, prediction: Text) -> ScalarMetricValue:
196-
"""
197-
Calculates agreement between ground truth and predicted text
198-
"""
199-
return float(prediction.answer == ground_truth.answer)
200204

201205

202206
def checklist_confusion_matrix(ground_truth: Checklist, prediction: Checklist) -> ScalarMetricValue:

tests/data/metrics/confusion_matrix/data_row/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,4 @@ def radio_pairs():
242242
"""
243243

244244
# Current question.. how do we handle classification precision and recall...
245+

0 commit comments

Comments
 (0)