Skip to content

Commit 526fb82

Browse files
authored
Merge pull request #486 from Labelbox/al-1723
[AL-1723] NER Confusion Matrix
2 parents 6b5d9e5 + 885e35b commit 526fb82

File tree

7 files changed

+218
-44
lines changed

7 files changed

+218
-44
lines changed

labelbox/data/metrics/confusion_matrix/calculation.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import numpy as np
44

5-
from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, miou
5+
from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou
66
from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
7-
Mask, Geometry, Checklist, Radio,
7+
Mask, Geometry, Checklist, Radio, TextEntity,
88
ScalarMetricValue, ConfusionMatrixMetricValue)
99
from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations,
1010
has_no_matching_annotations)
@@ -68,6 +68,9 @@ def feature_confusion_matrix(
6868
elif isinstance(predictions[0].value, Geometry):
6969
return vector_confusion_matrix(ground_truths, predictions,
7070
include_subclasses, iou)
71+
elif isinstance(predictions[0].value, TextEntity):
72+
return ner_confusion_matrix(ground_truths, predictions,
73+
include_subclasses, iou)
7174
elif isinstance(predictions[0], ClassificationAnnotation):
7275
return classification_confusion_matrix(ground_truths, predictions)
7376
else:
@@ -288,3 +291,23 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
288291
fn_mask = (prediction_np == 0) & (ground_truth_np == 1)
289292
tn_mask = prediction_np == ground_truth_np == 0
290293
return [np.sum(tp_mask), np.sum(fp_mask), np.sum(fn_mask), np.sum(tn_mask)]
294+
295+
296+
def ner_confusion_matrix(ground_truths: List[ObjectAnnotation],
                         predictions: List[ObjectAnnotation],
                         include_subclasses: bool,
                         iou: float) -> Optional[ConfusionMatrixMetricValue]:
    """Computes confusion matrix metric between two lists of TextEntity objects

    Args:
        ground_truths: List of ground truth text entity annotations
        predictions: List of prediction text entity annotations
        include_subclasses: Whether pairs must also agree on subclassifications
        iou: minimum iou for a pair to count as a true positive
    Returns:
        confusion matrix as a list: [TP,FP,TN,FN].
        None if neither list contains any annotations.
    """
    if has_no_matching_annotations(ground_truths, predictions):
        # One side is empty: every prediction is a FP, every ground truth a FN.
        return [0, int(len(predictions) > 0), 0, int(len(ground_truths) > 0)]
    elif has_no_annotations(ground_truths, predictions):
        return None
    pairs = _get_ner_pairs(ground_truths, predictions)
    return object_pair_confusion_matrix(pairs, include_subclasses, iou)

labelbox/data/metrics/iou/calculation.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations
88
from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
99
Mask, Geometry, Point, Line, Checklist, Text,
10-
Radio, ScalarMetricValue)
10+
TextEntity, Radio, ScalarMetricValue)
1111

1212

1313
def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]],
@@ -61,6 +61,8 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation,
6161
return vector_miou(ground_truths, predictions, include_subclasses)
6262
elif isinstance(predictions[0], ClassificationAnnotation):
6363
return classification_miou(ground_truths, predictions)
64+
elif isinstance(predictions[0].value, TextEntity):
65+
return ner_miou(ground_truths, predictions, include_subclasses)
6466
else:
6567
raise ValueError(
6668
f"Unexpected annotation found. Found {type(predictions[0].value)}")
@@ -269,3 +271,51 @@ def _ensure_valid_poly(poly):
269271
def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> ScalarMetricValue:
270272
"""Computes iou between two binary segmentation masks."""
271273
return np.sum(mask1 & mask2) / np.sum(mask1 | mask2)
274+
275+
276+
def _get_ner_pairs(
    ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation]
) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]:
    """Get iou score for all possible pairs of ground truths and predictions"""
    return [(ground_truth, prediction,
             _ner_iou(ground_truth.value, prediction.value))
            for ground_truth, prediction in product(ground_truths, predictions)]
285+
286+
287+
def _ner_iou(ner1: TextEntity, ner2: TextEntity):
288+
"""Computes iou between two text entity annotations"""
289+
intersection_start, intersection_end = max(ner1.start, ner2.start), min(
290+
ner1.end, ner2.end)
291+
union_start, union_end = min(ner1.start,
292+
ner2.start), max(ner1.end, ner2.end)
293+
#edge case of only one character in text
294+
if union_start == union_end:
295+
return 1
296+
#if there is no intersection
297+
if intersection_start > intersection_end:
298+
return 0
299+
return (intersection_end - intersection_start) / (union_end - union_start)
300+
301+
302+
def ner_miou(ground_truths: List[ObjectAnnotation],
             predictions: List[ObjectAnnotation],
             include_subclasses: bool) -> Optional[ScalarMetricValue]:
    """
    Computes iou score for all features with the same feature schema id.
    Calculation includes subclassifications only when `include_subclasses`
    is True.

    Args:
        ground_truths: List of ground truth ner annotations
        predictions: List of prediction ner annotations
        include_subclasses: Whether to factor subclassifications into the iou
    Returns:
        float representing the iou score for the feature type.
        If there are no matches then this returns none
    """
    if has_no_matching_annotations(ground_truths, predictions):
        return 0.
    elif has_no_annotations(ground_truths, predictions):
        return None
    pairs = _get_ner_pairs(ground_truths, predictions)
    return object_pair_miou(pairs, include_subclasses)

tests/data/metrics/confusion_matrix/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from labelbox.data.annotation_types import Polygon, Point, Rectangle, Mask, MaskData, Line, Radio, Text, Checklist, ClassificationAnswer
77
import numpy as np
88

9+
from labelbox.data.annotation_types.ner import TextEntity
10+
911

1012
class NameSpace(SimpleNamespace):
1113

@@ -84,6 +86,13 @@ def get_checklist(name, answer_names):
8486
]))
8587

8688

89+
def get_ner(name, start, end, subclasses=None):
    """Build an ObjectAnnotation wrapping a TextEntity over [start, end)."""
    classifications = subclasses if subclasses is not None else []
    return ObjectAnnotation(name=name,
                            value=TextEntity(start=start, end=end),
                            classifications=classifications)
94+
95+
8796
def get_object_pairs(tool_fn, **kwargs):
8897
return [
8998
NameSpace(predictions=[tool_fn("cat", **kwargs)],
@@ -326,6 +335,11 @@ def point_pairs():
326335
return get_object_pairs(get_point, x=0, y=0)
327336

328337

338+
@pytest.fixture
def ner_pairs():
    """Standard prediction/ground-truth pairings for a [0, 10) text entity."""
    ner_kwargs = {"start": 0, "end": 10}
    return get_object_pairs(get_ner, **ner_kwargs)
341+
342+
329343
@pytest.fixture()
330344
def pair_iou_thresholds():
331345
return [

tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
fixture_ref('rectangle_pairs'),
1010
fixture_ref('mask_pairs'),
1111
fixture_ref('line_pairs'),
12-
fixture_ref('point_pairs')
12+
fixture_ref('point_pairs'),
13+
fixture_ref('ner_pairs')
1314
])
1415
def test_overlapping_objects(tool_examples):
1516
for example in tool_examples:

tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
fixture_ref('rectangle_pairs'),
1010
fixture_ref('mask_pairs'),
1111
fixture_ref('line_pairs'),
12-
fixture_ref('point_pairs')
12+
fixture_ref('point_pairs'),
13+
fixture_ref('ner_pairs')
1314
])
1415
def test_overlapping_objects(tool_examples):
1516
for example in tool_examples:

tests/data/metrics/iou/data_row/conftest.py

Lines changed: 117 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -320,45 +320,42 @@ def empty_radio_prediction():
320320

321321
@pytest.fixture
def matching_checklist():
    """Checklist label and prediction with three identical answers.

    The original dict literal repeated the 'schemaId' key, so the first
    occurrence was silently discarded; the duplicate is removed here.
    """
    return NameSpace(labels=[],
                     classifications=[{
                         'featureId':
                             '1234567890111213141516171',
                         'uuid':
                             '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
                         'schemaId':
                             'ckppid25v0000aeyjmxfwlc7t',
                         'answers': [{
                             'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
                         }, {
                             'schemaId': 'ckppide010001aeyj0yhiaghc'
                         }, {
                             'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
                         }]
                     }],
                     predictions=[{
                         'uuid':
                             '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
                         'schemaId':
                             'ckppid25v0000aeyjmxfwlc7t',
                         'dataRow': {
                             'id': 'ckppihxc10005aeyjen11h7jh'
                         },
                         'answers': [{
                             'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
                         }, {
                             'schemaId': 'ckppide010001aeyj0yhiaghc'
                         }, {
                             'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
                         }]
                     }],
                     data_row_expected=1.,
                     expected={1.0: 3})
362359

363360

364361
@pytest.fixture
@@ -699,3 +696,84 @@ def point_pair():
699696
}
700697
}],
701698
expected=0.879113232477017)
699+
700+
701+
@pytest.fixture
def matching_ner():
    """Label and prediction cover the identical span [0, 10) -> iou of 1."""
    label = {
        'featureId': 'ckppivl7p0006aeyj92cezr9d',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        'format': "text.location",
        'data': {
            "location": {
                "start": 0,
                "end": 10
            }
        }
    }
    prediction = {
        'dataRow': {
            'id': 'ckppihxc10005aeyjen11h7jh'
        },
        'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        "location": {
            "start": 0,
            "end": 10
        }
    }
    return NameSpace(labels=[label], predictions=[prediction], expected=1)
726+
727+
728+
@pytest.fixture
def no_matching_ner():
    """Label [0, 5) and prediction [5, 10) are disjoint -> iou of 0."""
    label = {
        'featureId': 'ckppivl7p0006aeyj92cezr9d',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        'format': "text.location",
        'data': {
            "location": {
                "start": 0,
                "end": 5
            }
        }
    }
    prediction = {
        'dataRow': {
            'id': 'ckppihxc10005aeyjen11h7jh'
        },
        'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        "location": {
            "start": 5,
            "end": 10
        }
    }
    return NameSpace(labels=[label], predictions=[prediction], expected=0)
753+
754+
755+
@pytest.fixture
def partial_matching_ner():
    """Label [0, 7) and prediction [3, 5) overlap by 2 of 7 chars -> iou 2/7."""
    label = {
        'featureId': 'ckppivl7p0006aeyj92cezr9d',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        'format': "text.location",
        'data': {
            "location": {
                "start": 0,
                "end": 7
            }
        }
    }
    prediction = {
        'dataRow': {
            'id': 'ckppihxc10005aeyjen11h7jh'
        },
        'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
        'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
        "location": {
            "start": 3,
            "end": 5
        }
    }
    return NameSpace(labels=[label],
                     predictions=[prediction],
                     expected=0.2857142857142857)

tests/data/metrics/iou/data_row/test_data_row_iou.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,10 @@ def test_vector_with_subclass(pair):
115115
@parametrize("pair", strings_to_fixtures(["point_pair", "line_pair"]))
116116
def test_others(pair):
117117
check_iou(pair)
118+
119+
120+
_NER_FIXTURES = ["matching_ner", "no_matching_ner", "partial_matching_ner"]


@parametrize("pair", strings_to_fixtures(_NER_FIXTURES))
def test_ner(pair):
    """Check data-row iou values for NER (TextEntity) annotation pairs."""
    check_iou(pair)

0 commit comments

Comments
 (0)