Skip to content

Commit b0d884f

Browse files
author
Matt Sokoloff
committed
fix ground-truth and prediction argument order
1 parent ab01346 commit b0d884f

File tree

7 files changed

+23
-22
lines changed

7 files changed

+23
-22
lines changed

labelbox/data/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .confusion_matrix import confusion_matrix_metric, feature_confusion_matrix_metric
2+
from .iou import miou_metric, feature_miou_metric
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .calculation import *
2+
from .confusion_matrix import *

labelbox/data/metrics/confusion_matrix/calculation.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def confusion_matrix(ground_truths: List[Union[ObjectAnnotation,
3131
Returns None if there are no annotations in ground_truth or prediction annotations
3232
"""
3333

34-
annotation_pairs = get_feature_pairs(predictions, ground_truths)
34+
annotation_pairs = get_feature_pairs(ground_truths, predictions)
3535
conf_matrix = [
3636
feature_confusion_matrix(annotation_pair[0], annotation_pair[1],
3737
include_subclasses, iou)
@@ -166,7 +166,7 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation,
166166
matched_predictions = set()
167167
matched_ground_truths = set()
168168

169-
for prediction, ground_truth, agreement in pairs:
169+
for ground_truth, prediction, agreement in pairs:
170170
prediction_id = id(prediction)
171171
ground_truth_id = id(ground_truth)
172172
prediction_ids.add(prediction_id)
@@ -177,14 +177,15 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation,
177177
ground_truth_id not in matched_ground_truths:
178178
if include_subclasses and (ground_truth.classifications or
179179
prediction.classifications):
180-
if miou(prediction.classifications,
181-
ground_truth.classifications,
180+
if miou(ground_truth.classifications,
181+
prediction.classifications,
182182
include_subclasses=False) < 1.:
183183
# Incorrect if the subclasses don't 100% agree then there is no match
184184
continue
185185
matched_predictions.add(prediction_id)
186186
matched_ground_truths.add(ground_truth_id)
187187
tps = len(matched_ground_truths)
188+
188189
fps = len(prediction_ids.difference(matched_predictions))
189190
fns = len(ground_truth_ids.difference(matched_ground_truths))
190191
# Not defined for object detection.
@@ -210,11 +211,7 @@ def radio_confusion_matrix(ground_truth: Radio,
210211
key = get_identifying_key([prediction.answer], [ground_truth.answer])
211212
prediction_id = getattr(prediction.answer, key)
212213
ground_truth_id = getattr(ground_truth.answer, key)
213-
214-
if prediction_id == ground_truth_id:
215-
return [1, 0, 0, 0]
216-
else:
217-
return [0, 1, 0, 1]
214+
return [1, 0, 0, 0] if prediction_id == ground_truth_id else [0, 1, 0, 1]
218215

219216

220217
def checklist_confusion_matrix(

labelbox/data/metrics/confusion_matrix/confusion_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def feature_confusion_matrix_metric(
6666
There will be one metric for each class in the union of ground truth and prediction classes.
6767
"""
6868
# Classifications are supported because we just take a naive approach to them..
69-
annotation_pairs = get_feature_pairs(predictions, ground_truths)
69+
annotation_pairs = get_feature_pairs(ground_truths, predictions)
7070
metrics = []
7171
for key in annotation_pairs:
7272
value = feature_confusion_matrix(annotation_pairs[key][0],

labelbox/data/metrics/iou/calculation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _get_vector_pairs(
222222
# Get iou score for all pairs of ground truths and predictions
223223
"""
224224
pairs = []
225-
for prediction, ground_truth in product(predictions, ground_truths):
225+
for ground_truth, prediction in product(ground_truths, predictions):
226226
if isinstance(prediction.value, Geometry) and isinstance(
227227
ground_truth.value, Geometry):
228228
if isinstance(prediction.value, (Line, Point)):
@@ -232,7 +232,7 @@ def _get_vector_pairs(
232232
else:
233233
score = _polygon_iou(prediction.value.shapely,
234234
ground_truth.value.shapely)
235-
pairs.append((prediction, ground_truth, score))
235+
pairs.append((ground_truth, prediction, score))
236236
return pairs
237237

238238

@@ -243,12 +243,12 @@ def _get_mask_pairs(
243243
# Get iou score for all pairs of ground truths and predictions
244244
"""
245245
pairs = []
246-
for prediction, ground_truth in product(predictions, ground_truths):
246+
for ground_truth, prediction in product(ground_truths, predictions):
247247
if isinstance(prediction.value, Mask) and isinstance(
248248
ground_truth.value, Mask):
249249
score = _mask_iou(prediction.value.draw(color=1),
250250
ground_truth.value.draw(color=1))
251-
pairs.append((prediction, ground_truth, score))
251+
pairs.append((ground_truth, prediction, score))
252252
return pairs
253253

254254

tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
])
1414
def test_overlapping_objects(tool_examples):
1515
for example in tool_examples:
16-
score = confusion_matrix_metric(example.predictions,
17-
example.ground_truths)
16+
score = confusion_matrix_metric(example.ground_truths,
17+
example.predictions)
1818

1919
if len(example.expected) == 0:
2020
assert len(score) == 0
@@ -33,8 +33,8 @@ def test_overlapping_objects(tool_examples):
3333
fixture_ref('radio_pairs')])
3434
def test_overlapping_classifications(tool_examples):
3535
for example in tool_examples:
36-
score = confusion_matrix_metric(example.predictions,
37-
example.ground_truths)
36+
score = confusion_matrix_metric(example.ground_truths,
37+
example.predictions)
3838
if len(example.expected) == 0:
3939
assert len(score) == 0
4040
else:

tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
])
1414
def test_overlapping_objects(tool_examples):
1515
for example in tool_examples:
16-
metrics = feature_confusion_matrix_metric(example.predictions,
17-
example.ground_truths)
16+
metrics = feature_confusion_matrix_metric(example.ground_truths,
17+
example.predictions)
1818

1919
metrics = {r.feature_name: list(r.value) for r in metrics}
2020
if len(example.expected) == 0:
@@ -30,8 +30,8 @@ def test_overlapping_objects(tool_examples):
3030
def test_overlapping_classifications(tool_examples):
3131
for example in tool_examples:
3232

33-
metrics = feature_confusion_matrix_metric(example.predictions,
34-
example.ground_truths)
33+
metrics = feature_confusion_matrix_metric(example.ground_truths,
34+
example.predictions)
3535

3636
metrics = {r.feature_name: list(r.value) for r in metrics}
3737
if len(example.expected) == 0:

0 commit comments

Comments (0)