Commit dc6310b

Author: Matt Sokoloff
Commit message: tested
1 parent b1f9bf6 commit dc6310b

12 files changed (+614, -479 lines)


labelbox/data/annotation_types/metrics/confusion_matrix.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 
 Count = conint(ge=0, le=1e10)
 
-
 ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count]
 ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue,
                                             ConfusionMatrixMetricValue]
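
For orientation: throughout this commit the metric value is handled as a plain list of four counts in [TP, FP, TN, FN] order (object_pair_confusion_matrix in the next file returns [tps, fps, tns, fns], and the repeated early-exit value [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)] counts an unmatched prediction as a false positive and an unmatched ground truth as a false negative). A minimal sketch of consuming such a value, assuming that ordering and made-up counts:

tp, fp, tn, fn = [3, 1, 0, 2]   # hypothetical ConfusionMatrixMetricValue
precision = tp / (tp + fp)      # 0.75
recall = tp / (tp + fn)         # 0.6
print(precision, recall)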
Lines changed: 85 additions & 134 deletions
@@ -1,65 +1,61 @@
-
-
-
-from labelbox.data.metrics.iou.calculation import _mask_iou, miou
+from labelbox.data.metrics.iou.calculation import _get_mask_pairs, _get_vector_pairs, miou
 
 from labelbox.data.annotation_types.metrics.confusion_matrix import \
     ConfusionMatrixMetricValue
 
-
 from labelbox.data.annotation_types.metrics.scalar import ScalarMetricValue
 from typing import List, Optional, Tuple, Union
-from shapely.geometry import Polygon
-from itertools import product
 import numpy as np
 from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
-                                 Mask, Geometry, Point, Line, Checklist, Text,
-                                 Radio)
-from ..group import get_feature_pairs, get_identifying_key
+                                 Mask, Geometry, Checklist, Radio)
+from ..processing import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations
 
 
 def confusion_matrix(ground_truths: List[Union[ObjectAnnotation,
-                                               ClassificationAnnotation]],
-                     predictions: List[Union[ObjectAnnotation,
-                                             ClassificationAnnotation]],
-                     iou: float,
-                     include_subclasses: bool) -> ConfusionMatrixMetricValue:
+                                               ClassificationAnnotation]],
+                     predictions: List[Union[ObjectAnnotation,
+                                             ClassificationAnnotation]],
+                     include_subclasses: bool,
+                     iou: float) -> ConfusionMatrixMetricValue:
 
     annotation_pairs = get_feature_pairs(predictions, ground_truths)
     ious = [
-        feature_confusion_matrix(annotation_pair[0], annotation_pair[1], iou, include_subclasses)
+        feature_confusion_matrix(annotation_pair[0], annotation_pair[1],
+                                 include_subclasses, iou)
        for annotation_pair in annotation_pairs.values()
    ]
    ious = [iou for iou in ious if iou is not None]
 
-    return None if not len(ious) else np.sum(ious, axis = 0 ).tolist()
+    return None if not len(ious) else np.sum(ious, axis=0).tolist()
 
 
-
-def feature_confusion_matrix(ground_truths: List[Union[ObjectAnnotation,
-                                                        ClassificationAnnotation]],
-                             predictions: List[Union[ObjectAnnotation,
-                                                     ClassificationAnnotation]],
-                             iou: float,
-                             include_subclasses: bool) -> Optional[ConfusionMatrixMetricValue]:
-    if _no_matching_annotations(ground_truths, predictions):
-        return [0,int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
-    elif _no_annotations(ground_truths, predictions):
+def feature_confusion_matrix(
+        ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]],
+        predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]],
+        include_subclasses: bool,
+        iou: float) -> Optional[ConfusionMatrixMetricValue]:
+    if has_no_matching_annotations(ground_truths, predictions):
+        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
+    elif has_no_annotations(ground_truths, predictions):
        # Note that we could return [0,0,0,0] but that will bloat the imports for no reason
        return None
    elif isinstance(predictions[0].value, Mask):
-        return mask_confusion_matrix(ground_truths, predictions, iou, include_subclasses)
+        return mask_confusion_matrix(ground_truths, predictions, iou,
+                                     include_subclasses)
    elif isinstance(predictions[0].value, Geometry):
-        return vector_confusion_matrix(ground_truths, predictions, iou, include_subclasses)
+        return vector_confusion_matrix(ground_truths, predictions, iou,
+                                       include_subclasses)
    elif isinstance(predictions[0], ClassificationAnnotation):
        return classification_confusion_matrix(ground_truths, predictions)
    else:
        raise ValueError(
            f"Unexpected annotation found. Found {type(predictions[0].value)}")
 
 
-def classification_confusion_matrix(ground_truths: List[ClassificationAnnotation],
-                         predictions: List[ClassificationAnnotation]) -> ConfusionMatrixMetricValue:
+def classification_confusion_matrix(
+        ground_truths: List[ClassificationAnnotation],
+        predictions: List[ClassificationAnnotation]
+) -> ConfusionMatrixMetricValue:
    """
    Computes iou score for all features with the same feature schema id.
 
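
As a standalone illustration of the aggregation in confusion_matrix above: the per-feature results are summed element-wise with np.sum(..., axis=0), so the overall value is the column-wise total of the per-feature [TP, FP, TN, FN] lists. A sketch with made-up per-feature values (not Labelbox objects):

import numpy as np

per_feature = [
    [3, 1, 0, 2],   # hypothetical result for one object feature schema
    [1, 0, 0, 0],   # hypothetical result for a radio classification
]
print(np.sum(per_feature, axis=0).tolist())   # [4, 1, 0, 2]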
@@ -70,9 +66,11 @@ def classification_confusion_matrix(ground_truths: List[ClassificationAnnotation
        float representing the iou score for the classification
    """
 
-    if _no_matching_annotations(ground_truths, predictions):
-        return [0,int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
-    elif _no_annotations(ground_truths, predictions) or len(predictions) > 1 or len(ground_truths) > 1:
+    if has_no_matching_annotations(ground_truths, predictions):
+        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
+    elif has_no_annotations(
+            ground_truths,
+            predictions) or len(predictions) > 1 or len(ground_truths) > 1:
        # Note that we could return [0,0,0,0] but that will bloat the imports for no reason
        return None
 
@@ -91,22 +89,24 @@ def classification_confusion_matrix(ground_truths: List[ClassificationAnnotation
        raise ValueError(f"Unsupported subclass. {prediction}.")
 
 
-
 def vector_confusion_matrix(ground_truths: List[ObjectAnnotation],
-                 predictions: List[ObjectAnnotation],
-                 iou,
-                 include_subclasses: bool,
-                 buffer=70.) -> Optional[ConfusionMatrixMetricValue]:
-    if _no_matching_annotations(ground_truths, predictions):
-        return [0,int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
-    elif _no_annotations(ground_truths, predictions):
+                            predictions: List[ObjectAnnotation],
+                            iou: float,
+                            include_subclasses: bool,
+                            buffer=70.) -> Optional[ConfusionMatrixMetricValue]:
+    if has_no_matching_annotations(ground_truths, predictions):
+        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
+    elif has_no_annotations(ground_truths, predictions):
        return None
 
    pairs = _get_vector_pairs(ground_truths, predictions, buffer=buffer)
    return object_pair_confusion_matrix(pairs, iou, include_subclasses)
 
 
-def object_pair_confusion_matrix(pairs : List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]], iou, include_subclasses) -> ConfusionMatrixMetricValue:
+def object_pair_confusion_matrix(
+        pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation,
+                          ScalarMetricValue]], iou,
+        include_subclasses) -> ConfusionMatrixMetricValue:
    pairs.sort(key=lambda triplet: triplet[2], reverse=True)
    prediction_ids = set()
    ground_truth_ids = set()
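
The body of object_pair_confusion_matrix (continued in the next hunk) matches predictions to ground truths greedily: pairs are sorted by IoU descending, and a pair only counts when its IoU clears the threshold and neither side has been matched yet. A simplified standalone sketch of that matching step, using plain ids and hypothetical scores instead of annotation objects:

pairs = [("p1", "g1", 0.9), ("p2", "g1", 0.8), ("p2", "g2", 0.4)]
iou_threshold = 0.5
matched_predictions, matched_ground_truths = set(), set()
tps = 0
for pred_id, gt_id, score in sorted(pairs, key=lambda t: t[2], reverse=True):
    if (score > iou_threshold and pred_id not in matched_predictions and
            gt_id not in matched_ground_truths):
        matched_predictions.add(pred_id)
        matched_ground_truths.add(gt_id)
        tps += 1
fps = len({p for p, _, _ in pairs}) - len(matched_predictions)    # unmatched predictions
fns = len({g for _, g, _ in pairs}) - len(matched_ground_truths)  # unmatched ground truths
print([tps, fps, 0, fns])   # [1, 1, 0, 1]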
@@ -122,9 +122,12 @@ def object_pair_confusion_matrix(pairs : List[Tuple[ObjectAnnotation, ObjectAnno
        if agreement > iou and \
           prediction_id not in matched_predictions and \
           ground_truth_id not in matched_ground_truths:
-            if include_subclasses and (ground_truth.classifications or prediction.classifications):
-                if miou(prediction.classifications, ground_truth.classifications) < 1.:
-                    # Incorrect if the subclasses don't 100% agree
+            if include_subclasses and (ground_truth.classifications or
+                                       prediction.classifications):
+                if miou(prediction.classifications,
+                        ground_truth.classifications,
+                        include_subclasses=False) < 1.:
+                    # Incorrect if the subclasses don't 100% agree then there is no match
                    continue
            matched_predictions.add(prediction_id)
            matched_ground_truths.add(ground_truth_id)
@@ -136,89 +139,56 @@ def object_pair_confusion_matrix(pairs : List[Tuple[ObjectAnnotation, ObjectAnno
    return [tps, fps, tns, fns]
 
 
-
-def _get_vector_pairs(
-        ground_truths: List[ObjectAnnotation],
-        predictions: List[ObjectAnnotation], buffer: float
-) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]:
-    """
-    # Get iou score for all pairs of ground truths and predictions
-    """
-    pairs = []
-    for prediction, ground_truth in product(predictions, ground_truths):
-        if isinstance(prediction.value, Geometry) and isinstance(
-                ground_truth.value, Geometry):
-            if isinstance(prediction.value, (Line, Point)):
-                score = _polygon_iou(prediction.value.shapely.buffer(buffer),
-                                     ground_truth.value.shapely.buffer(buffer))
-            else:
-                score = _polygon_iou(prediction.value.shapely,
-                                     ground_truth.value.shapely)
-            pairs.append((prediction, ground_truth, score))
-    return pairs
-
-def _get_mask_pairs(
-        ground_truths: List[ObjectAnnotation],
-        predictions: List[ObjectAnnotation]
-) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]:
-    """
-    # Get iou score for all pairs of ground truths and predictions
-    """
-    pairs = []
-    for prediction, ground_truth in product(predictions, ground_truths):
-        if isinstance(prediction.value, Mask) and isinstance(
-                ground_truth.value, Mask):
-            score = _mask_iou(prediction.value.draw(color = 1),
-                              ground_truth.value.draw(color = 1))
-            pairs.append((prediction, ground_truth, score))
-    return pairs
-
-def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue:
-    """Computes iou between two shapely polygons."""
-    if poly1.intersects(poly2):
-        return poly1.intersection(poly2).area / poly1.union(poly2).area
-    return 0.
-
-
-def radio_confusion_matrix(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue:
+def radio_confusion_matrix(ground_truth: Radio,
+                           prediction: Radio) -> ConfusionMatrixMetricValue:
    """
    Calculates confusion between ground truth and predicted radio values
 
    The way we are calculating confusion matrix metrics:
        - TNs aren't defined because we don't know how many other classes exist ... etc
 
-    We treat each example as 1 vs all
+    When P == L, then we get [1,0,0,0]
+    when P != L, we get [0,1,0,1]
 
+    This is because we are aggregating the stats for the entire radio. Not for each class.
+    Since we are not tracking TNs (P == L) only adds to TP.
+    We are not tracking TNs because the number of TNs is equal to the number of classes which we do not know
+    from just looking at the predictions and labels. Also TNs are necessary for precision/recall/f1.
    """
    key = get_identifying_key([prediction.answer], [ground_truth.answer])
    prediction_id = getattr(prediction.answer, key)
    ground_truth_id = getattr(ground_truth.answer, key)
 
+    if prediction_id == ground_truth_id:
+        return [1, 0, 0, 0]
+    else:
+        return [0, 1, 0, 1]
 
 
-
-    return float(getattr(prediction.answer, key) ==
-                 getattr(ground_truth.answer, key))
-
-
-
-
-def checklist_confusion_matrix(ground_truth: Checklist, prediction: Checklist) -> ScalarMetricValue:
+def checklist_confusion_matrix(
+        ground_truth: Checklist,
+        prediction: Checklist) -> ConfusionMatrixMetricValue:
    """
    Calculates agreement between ground truth and predicted checklist items
+
+    Also not tracking TNs
    """
    key = get_identifying_key(prediction.answer, ground_truth.answer)
    schema_ids_pred = {getattr(answer, key) for answer in prediction.answer}
-    schema_ids_label = {
-        getattr(answer, key) for answer in ground_truth.answer
-    }
-    return float(
-        len(schema_ids_label & schema_ids_pred) /
-        len(schema_ids_label | schema_ids_pred))
+    schema_ids_label = {getattr(answer, key) for answer in ground_truth.answer}
+    agree = schema_ids_label & schema_ids_pred
+    all_selected = schema_ids_label | schema_ids_pred
+    disagree = all_selected.difference(agree)
+    fps = len({x for x in disagree if x in schema_ids_pred})
+    fns = len({x for x in disagree if x in schema_ids_label})
+    tps = len(agree)
+    return [tps, fps, 0, fns]
 
 
-def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
-                   predictions: List[ObjectAnnotation], iou, include_subclasses: bool) -> Optional[ScalarMetricValue]:
+def mask_confusion_matrix(
+        ground_truths: List[ObjectAnnotation],
+        predictions: List[ObjectAnnotation], iou,
+        include_subclasses: bool) -> Optional[ScalarMetricValue]:
    """
    Computes iou score for all features with the same feature schema id.
    Calculation includes subclassifications.
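
A standalone walk-through of the checklist counting above, with plain strings standing in for the answer objects (hypothetical values): answers selected on both sides are true positives, extra predicted answers are false positives, missed label answers are false negatives, and true negatives stay at zero.

schema_ids_label = {"cat", "dog", "bird"}   # selected in the ground truth
schema_ids_pred = {"cat", "fish"}           # selected in the prediction
agree = schema_ids_label & schema_ids_pred  # {"cat"} -> 1 TP
disagree = (schema_ids_label | schema_ids_pred) - agree
fps = len({x for x in disagree if x in schema_ids_pred})   # {"fish"} -> 1
fns = len({x for x in disagree if x in schema_ids_label})  # {"dog", "bird"} -> 2
print([len(agree), fps, 0, fns])            # [1, 1, 0, 2]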
@@ -229,9 +199,9 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
    Returns:
        float representing the iou score for the masks
    """
-    if _no_matching_annotations(ground_truths, predictions):
-        return [0,int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
-    elif _no_annotations(ground_truths, predictions):
+    if has_no_matching_annotations(ground_truths, predictions):
+        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
+    elif has_no_annotations(ground_truths, predictions):
        return None
 
    if include_subclasses:
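
The has_no_matching_annotations / has_no_annotations helpers are imported from ..processing, and their implementations are not part of this diff. A hedged sketch of what they appear to check, mirroring the module-local _no_matching_annotations / _no_annotations functions removed in the final hunk below (the actual ..processing code may differ):

from typing import List

def has_no_matching_annotations(ground_truths: List, predictions: List) -> bool:
    # Exactly one side is empty, so nothing can possibly match.
    return bool(ground_truths) != bool(predictions)

def has_no_annotations(ground_truths: List, predictions: List) -> bool:
    # Neither side has anything to compare.
    return not ground_truths and not predictions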
@@ -240,7 +210,8 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
        # Otherwise this will flatten the masks.
        # TODO: Make this more apprent in the configuration.
        pairs = _get_mask_pairs(ground_truths, predictions)
-        return object_pair_confusion_matrix(pairs, iou, include_subclasses=include_subclasses)
+        return object_pair_confusion_matrix(
+            pairs, iou, include_subclasses=include_subclasses)
 
    prediction_np = np.max([pred.value.draw(color=1) for pred in predictions],
                           axis=0)
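
In the non-subclass branch, np.max over the stacked binary masks (the two lines just above) merges all prediction masks, and likewise all ground-truth masks, into one binary mask before the pixel-wise comparison. A tiny sketch with illustrative arrays:

import numpy as np

masks = [np.array([[1, 0],
                   [0, 0]]),
         np.array([[0, 0],
                   [0, 1]])]
print(np.max(masks, axis=0))   # [[1 0]
                               #  [0 1]]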
@@ -253,27 +224,7 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
            f" Found {prediction_np.shape}/{ground_truth_np.shape}.")
 
    tp_mask = prediction_np == ground_truth_np == 1
-    fp_mask = (prediction_np == 1) & (ground_truth_np==0)
-    fn_mask = (prediction_np == 0) & (ground_truth_np==1)
+    fp_mask = (prediction_np == 1) & (ground_truth_np == 0)
+    fn_mask = (prediction_np == 0) & (ground_truth_np == 1)
    tn_mask = prediction_np == ground_truth_np == 0
    return [np.sum(tp_mask), np.sum(fp_mask), np.sum(fn_mask), np.sum(tn_mask)]
-
-
-
-def _no_matching_annotations(ground_truths: List[ObjectAnnotation],
-                             predictions: List[ObjectAnnotation]):
-    if len(ground_truths) and not len(predictions):
-        # No existing predictions but existing ground truths means no matches.
-        return True
-    elif not len(ground_truths) and len(predictions):
-        # No ground truth annotations but there are predictions means no matches
-        return True
-    return False
-
-
-def _no_annotations(ground_truths: List[ObjectAnnotation],
-                    predictions: List[ObjectAnnotation]):
-    return not len(ground_truths) and not len(predictions)
-
-
-
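
A standalone sketch of the pixel-wise counting in the hunk above, written with explicit elementwise comparisons on small illustrative arrays (the result follows the [tp, fp, fn, tn] order of that return statement):

import numpy as np

prediction_np = np.array([[1, 1], [0, 0]])
ground_truth_np = np.array([[1, 0], [0, 1]])
tp_mask = (prediction_np == 1) & (ground_truth_np == 1)
fp_mask = (prediction_np == 1) & (ground_truth_np == 0)
fn_mask = (prediction_np == 0) & (ground_truth_np == 1)
tn_mask = (prediction_np == 0) & (ground_truth_np == 0)
print([int(np.sum(m)) for m in (tp_mask, fp_mask, fn_mask, tn_mask)])  # [1, 1, 1, 1]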
