@@ -83,9 +83,7 @@ def classification_confusion_matrix(ground_truths: List[ClassificationAnnotation
             "Classification features must be the same type to compute agreement. "
             f"Found `{type(prediction)}` and `{type(ground_truth)}`")
 
-    if isinstance(prediction.value, Text):
-        return text_confusion_matrix(ground_truth.value, prediction.value)
-    elif isinstance(prediction.value, Radio):
+    if isinstance(prediction.value, Radio):
         return radio_confusion_matrix(ground_truth.value, prediction.value)
     elif isinstance(prediction.value, Checklist):
         return checklist_confusion_matrix(ground_truth.value, prediction.value)
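This hunk drops `Text` from the dispatch: a confusion matrix needs a fixed set of classes to confuse, and free-text answers have none, so only `Radio` and `Checklist` values still reach a confusion-matrix helper. A minimal sketch of the resulting dispatch, using hypothetical stand-in classes rather than the SDK's real types (the fallthrough behavior for unsupported types is an assumption here, not shown in the hunk):

```python
# Hypothetical stand-ins for the SDK's value types; not SDK imports.
class Radio: ...
class Checklist: ...
class Text: ...

def confusion_matrix_dispatch(value) -> str:
    # Mirrors the edited branch: only enumerable classification types
    # are routed to a confusion-matrix helper after this change.
    if isinstance(value, Radio):
        return "radio_confusion_matrix"      # single-choice answers
    elif isinstance(value, Checklist):
        return "checklist_confusion_matrix"  # multi-choice answers
    # Text falls through: free-form strings have no fixed class set,
    # so a confusion matrix is not well defined for them.
    raise ValueError(f"Unsupported value type `{type(value)}`")
```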
@@ -185,18 +183,24 @@ def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue:
 def radio_confusion_matrix(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue:
     """
     Calculates confusion between ground truth and predicted radio values
+
+    The way we are calculating confusion matrix metrics:
+        - TNs aren't defined because we don't know how many other classes exist ... etc
+
+    We treat each example as 1 vs all
+
     """
     key = get_identifying_key([prediction.answer], [ground_truth.answer])
+    prediction_id = getattr(prediction.answer, key)
+    ground_truth_id = getattr(ground_truth.answer, key)
+
+
+
 
     return float(getattr(prediction.answer, key) ==
                  getattr(ground_truth.answer, key))
 
 
-def text_confusion_matrix(ground_truth: Text, prediction: Text) -> ScalarMetricValue:
-    """
-    Calculates agreement between ground truth and predicted text
-    """
-    return float(prediction.answer == ground_truth.answer)
 
 
 def checklist_confusion_matrix(ground_truth: Checklist, prediction: Checklist) -> ScalarMetricValue:
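The new docstring leans on a 1-vs-all treatment without spelling it out (and the hunk binds `prediction_id`/`ground_truth_id` without yet using them). As a rough illustration of that counting scheme, here is a self-contained sketch; `Answer` and `one_vs_all_counts` are hypothetical stand-ins, not SDK code, and the real implementation resolves the comparison attribute via `get_identifying_key`:

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical stand-in for the SDK's answer objects.
@dataclass(frozen=True)
class Answer:
    name: str

def one_vs_all_counts(ground_truths: List[Answer],
                      predictions: List[Answer]) -> Tuple[int, int, int]:
    """Return (TP, FP, FN) under the 1-vs-all treatment described above.

    TNs are intentionally omitted: without the full ontology we cannot
    know how many classes were correctly *not* predicted.
    """
    gt = {a.name for a in ground_truths}
    pred = {a.name for a in predictions}
    tp = len(gt & pred)   # answer predicted and present in ground truth
    fp = len(pred - gt)   # answer predicted but absent from ground truth
    fn = len(gt - pred)   # answer in ground truth but never predicted
    return tp, fp, fn

# A mismatched radio answer counts once against each class involved:
# predicting "dog" is a FP for `dog`, and the missed "cat" is a FN for `cat`.
assert one_vs_all_counts([Answer("cat")], [Answer("dog")]) == (0, 1, 1)
assert one_vs_all_counts([Answer("cat")], [Answer("cat")]) == (1, 0, 0)
```

Under this scheme a wrong radio answer penalizes two classes at once (one FP, one FN), which is consistent with `radio_confusion_matrix` returning 1.0 only when the identifying keys match exactly.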