Skip to content

Commit 885e35b

Browse files
committed
addition to tests and update to existing iou calculations
1 parent 60e9f46 commit 885e35b

File tree

3 files changed

+153
-43
lines changed

3 files changed

+153
-43
lines changed

labelbox/data/metrics/iou/calculation.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
from shapely.geometry import Polygon
55
import numpy as np
66

7-
from labelbox.data.annotation_types.ner import TextEntity
8-
97
from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations
108
from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
119
Mask, Geometry, Point, Line, Checklist, Text,
12-
Radio, ScalarMetricValue)
10+
TextEntity, Radio, ScalarMetricValue)
1311

1412

1513
def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]],
@@ -63,6 +61,8 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation,
6361
return vector_miou(ground_truths, predictions, include_subclasses)
6462
elif isinstance(predictions[0], ClassificationAnnotation):
6563
return classification_miou(ground_truths, predictions)
64+
elif isinstance(predictions[0].value, TextEntity):
65+
return ner_miou(ground_truths, predictions, include_subclasses)
6666
else:
6767
raise ValueError(
6868
f"Unexpected annotation found. Found {type(predictions[0].value)}")
@@ -293,4 +293,29 @@ def _ner_iou(ner1: TextEntity, ner2: TextEntity):
293293
# edge case: the union has zero length (start == end), e.g. only one character in the text
294294
if union_start == union_end:
295295
return 1
296-
return (intersection_end - intersection_start) / (union_end - union_start)
296+
#if there is no intersection
297+
if intersection_start > intersection_end:
298+
return 0
299+
return (intersection_end - intersection_start) / (union_end - union_start)
300+
301+
302+
def ner_miou(ground_truths: List[ObjectAnnotation],
303+
predictions: List[ObjectAnnotation],
304+
include_subclasses: bool) -> Optional[ScalarMetricValue]:
305+
"""
306+
Computes iou score for all features with the same feature schema id.
307+
Subclassifications are included in the calculation only when include_subclasses is True.
308+
309+
Args:
310+
ground_truths: List of ground truth ner annotations
311+
predictions: List of prediction ner annotations
312+
Returns:
313+
float representing the iou score for the feature type.
314+
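Returns 0. when has_no_matching_annotations(ground_truths, predictions) holds, and None when has_no_annotations(ground_truths, predictions) holds.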
315+
"""
316+
if has_no_matching_annotations(ground_truths, predictions):
317+
return 0.
318+
elif has_no_annotations(ground_truths, predictions):
319+
return None
320+
pairs = _get_ner_pairs(ground_truths, predictions)
321+
return object_pair_miou(pairs, include_subclasses)

tests/data/metrics/iou/data_row/conftest.py

Lines changed: 117 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -320,45 +320,42 @@ def empty_radio_prediction():
320320

321321
@pytest.fixture
322322
def matching_checklist():
323-
return NameSpace(
324-
labels=[],
325-
classifications=[{
326-
'featureId':
327-
'1234567890111213141516171',
328-
'schemaId':
329-
'ckppid25v0000aeyjmxfwlc7t',
330-
'uuid':
331-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
332-
'schemaId':
333-
'ckppid25v0000aeyjmxfwlc7t',
334-
'answers': [{
335-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
336-
}, {
337-
'schemaId': 'ckppide010001aeyj0yhiaghc'
338-
}, {
339-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
340-
}]
341-
}],
342-
predictions=[{
343-
'uuid':
344-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
345-
'schemaId':
346-
'ckppid25v0000aeyjmxfwlc7t',
347-
'dataRow': {
348-
'id': 'ckppihxc10005aeyjen11h7jh'
349-
},
350-
'answers': [{
351-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
352-
}, {
353-
'schemaId': 'ckppide010001aeyj0yhiaghc'
354-
}, {
355-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
356-
}]
357-
}],
358-
data_row_expected=1.,
359-
# expected = [1.]
360-
# expected=[1., 1., 1.])
361-
expected={1.0: 3})
323+
return NameSpace(labels=[],
324+
classifications=[{
325+
'featureId':
326+
'1234567890111213141516171',
327+
'schemaId':
328+
'ckppid25v0000aeyjmxfwlc7t',
329+
'uuid':
330+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
331+
'schemaId':
332+
'ckppid25v0000aeyjmxfwlc7t',
333+
'answers': [{
334+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
335+
}, {
336+
'schemaId': 'ckppide010001aeyj0yhiaghc'
337+
}, {
338+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
339+
}]
340+
}],
341+
predictions=[{
342+
'uuid':
343+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
344+
'schemaId':
345+
'ckppid25v0000aeyjmxfwlc7t',
346+
'dataRow': {
347+
'id': 'ckppihxc10005aeyjen11h7jh'
348+
},
349+
'answers': [{
350+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
351+
}, {
352+
'schemaId': 'ckppide010001aeyj0yhiaghc'
353+
}, {
354+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
355+
}]
356+
}],
357+
data_row_expected=1.,
358+
expected={1.0: 3})
362359

363360

364361
@pytest.fixture
@@ -699,3 +696,84 @@ def point_pair():
699696
}
700697
}],
701698
expected=0.879113232477017)
699+
700+
701+
@pytest.fixture
702+
def matching_ner():
703+
return NameSpace(labels=[{
704+
'featureId': 'ckppivl7p0006aeyj92cezr9d',
705+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
706+
'format': "text.location",
707+
'data': {
708+
"location": {
709+
"start": 0,
710+
"end": 10
711+
}
712+
}
713+
}],
714+
predictions=[{
715+
'dataRow': {
716+
'id': 'ckppihxc10005aeyjen11h7jh'
717+
},
718+
'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
719+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
720+
"location": {
721+
"start": 0,
722+
"end": 10
723+
}
724+
}],
725+
expected=1)
726+
727+
728+
@pytest.fixture
729+
def no_matching_ner():
730+
return NameSpace(labels=[{
731+
'featureId': 'ckppivl7p0006aeyj92cezr9d',
732+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
733+
'format': "text.location",
734+
'data': {
735+
"location": {
736+
"start": 0,
737+
"end": 5
738+
}
739+
}
740+
}],
741+
predictions=[{
742+
'dataRow': {
743+
'id': 'ckppihxc10005aeyjen11h7jh'
744+
},
745+
'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
746+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
747+
"location": {
748+
"start": 5,
749+
"end": 10
750+
}
751+
}],
752+
expected=0)
753+
754+
755+
@pytest.fixture
756+
def partial_matching_ner():
757+
return NameSpace(labels=[{
758+
'featureId': 'ckppivl7p0006aeyj92cezr9d',
759+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
760+
'format': "text.location",
761+
'data': {
762+
"location": {
763+
"start": 0,
764+
"end": 7
765+
}
766+
}
767+
}],
768+
predictions=[{
769+
'dataRow': {
770+
'id': 'ckppihxc10005aeyjen11h7jh'
771+
},
772+
'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a',
773+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
774+
"location": {
775+
"start": 3,
776+
"end": 5
777+
}
778+
}],
779+
expected=0.2857142857142857)

tests/data/metrics/iou/data_row/test_data_row_iou.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,10 @@ def test_vector_with_subclass(pair):
115115
@parametrize("pair", strings_to_fixtures(["point_pair", "line_pair"]))
116116
def test_others(pair):
117117
check_iou(pair)
118+
119+
120+
@parametrize("pair",
121+
strings_to_fixtures(
122+
["matching_ner", "no_matching_ner", "partial_matching_ner"]))
123+
def test_ner(pair):
124+
check_iou(pair)

0 commit comments

Comments
 (0)