Skip to content

Commit 60e9f46

Browse files
committed
added in test cases and tested against subclasses as well as accounted for divide by 0
1 parent d406e06 commit 60e9f46

File tree

5 files changed

+24
-8
lines changed

5 files changed

+24
-8
lines changed

labelbox/data/metrics/confusion_matrix/calculation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def feature_confusion_matrix(
6969
return vector_confusion_matrix(ground_truths, predictions,
7070
include_subclasses, iou)
7171
elif isinstance(predictions[0].value, TextEntity):
72-
pass #TODO
72+
return ner_confusion_matrix(ground_truths, predictions,
73+
include_subclasses, iou)
7374
elif isinstance(predictions[0], ClassificationAnnotation):
7475
return classification_confusion_matrix(ground_truths, predictions)
7576
else:
@@ -293,12 +294,10 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
293294

294295

295296
def ner_confusion_matrix(ground_truths: List[ObjectAnnotation],
296-
predictions: list[ObjectAnnotation],
297+
predictions: List[ObjectAnnotation],
297298
include_subclasses: bool,
298299
iou: float) -> Optional[ConfusionMatrixMetricValue]:
299300
"""Computes confusion matrix metric between two lists of TextEntity objects
300-
301-
TODO: work on include_subclasses logic
302301
303302
Args:
304303
ground_truths: List of ground truth mask annotations

labelbox/data/metrics/iou/calculation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,6 @@ def _get_ner_pairs(
281281
for ground_truth, prediction in product(ground_truths, predictions):
282282
score = _ner_iou(ground_truth.value, prediction.value)
283283
pairs.append((ground_truth, prediction, score))
284-
# print(ground_truth.value.start, ground_truth.value.end,
285-
# prediction.value.start, prediction.value.end)
286284
return pairs
287285

288286

@@ -292,4 +290,7 @@ def _ner_iou(ner1: TextEntity, ner2: TextEntity):
292290
ner1.end, ner2.end)
293291
union_start, union_end = min(ner1.start,
294292
ner2.start), max(ner1.end, ner2.end)
293+
#edge case of only one character in text
294+
if union_start == union_end:
295+
return 1
295296
return (intersection_end - intersection_start) / (union_end - union_start)

tests/data/metrics/confusion_matrix/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from labelbox.data.annotation_types import Polygon, Point, Rectangle, Mask, MaskData, Line, Radio, Text, Checklist, ClassificationAnswer
77
import numpy as np
88

9+
from labelbox.data.annotation_types.ner import TextEntity
10+
911

1012
class NameSpace(SimpleNamespace):
1113

@@ -84,6 +86,13 @@ def get_checklist(name, answer_names):
8486
]))
8587

8688

89+
def get_ner(name, start, end, subclasses=None):
90+
return ObjectAnnotation(
91+
name=name,
92+
value=TextEntity(start=start, end=end),
93+
classifications=[] if subclasses is None else subclasses)
94+
95+
8796
def get_object_pairs(tool_fn, **kwargs):
8897
return [
8998
NameSpace(predictions=[tool_fn("cat", **kwargs)],
@@ -326,6 +335,11 @@ def point_pairs():
326335
return get_object_pairs(get_point, x=0, y=0)
327336

328337

338+
@pytest.fixture
339+
def ner_pairs():
340+
return get_object_pairs(get_ner, start=0, end=10)
341+
342+
329343
@pytest.fixture()
330344
def pair_iou_thresholds():
331345
return [

tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
fixture_ref('rectangle_pairs'),
1010
fixture_ref('mask_pairs'),
1111
fixture_ref('line_pairs'),
12-
fixture_ref('point_pairs')
12+
fixture_ref('point_pairs'),
13+
fixture_ref('ner_pairs')
1314
])
1415
def test_overlapping_objects(tool_examples):
1516
for example in tool_examples:

tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
fixture_ref('rectangle_pairs'),
1010
fixture_ref('mask_pairs'),
1111
fixture_ref('line_pairs'),
12-
fixture_ref('point_pairs')
12+
fixture_ref('point_pairs'),
13+
fixture_ref('ner_pairs')
1314
])
1415
def test_overlapping_objects(tool_examples):
1516
for example in tool_examples:

0 commit comments

Comments
 (0)