Skip to content

Commit 1a0f537

Browse files
authored
Merge pull request #278 from Labelbox/ms/confusion-matrix-metrics
Metric serialization
2 parents 045e7ea + e4fd74e commit 1a0f537

File tree

5 files changed

+101
-36
lines changed

5 files changed

+101
-36
lines changed

labelbox/data/serialization/ndjson/label.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
from itertools import groupby
2-
3-
from labelbox.data.annotation_types.metrics import ScalarMetric
4-
52
from operator import itemgetter
63
from typing import Dict, Generator, List, Tuple, Union
74
from collections import defaultdict
@@ -14,13 +11,16 @@
1411
from ...annotation_types.label import Label
1512
from ...annotation_types.ner import TextEntity
1613
from ...annotation_types.classification import Dropdown
17-
from .metric import NDScalarMetric, NDMetricAnnotation
14+
from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric
15+
16+
from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
1817
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
1918
from .objects import NDObject, NDObjectType
2019

2120

2221
class NDLabel(BaseModel):
23-
annotations: List[Union[NDObjectType, NDClassificationType, NDScalarMetric]]
22+
annotations: List[Union[NDObjectType, NDClassificationType,
23+
NDConfusionMatrixMetric, NDScalarMetric]]
2424

2525
def to_common(self) -> LabelGenerator:
2626
grouped_annotations = defaultdict(list)
@@ -39,6 +39,7 @@ def from_common(cls,
3939
def _generate_annotations(
4040
self, grouped_annotations: Dict[str, List[Union[NDObjectType,
4141
NDClassificationType,
42+
NDConfusionMatrixMetric,
4243
NDScalarMetric]]]
4344
) -> Generator[Label, None, None]:
4445

@@ -49,7 +50,8 @@ def _generate_annotations(
4950
annots.append(NDObject.to_common(annotation))
5051
elif isinstance(annotation, NDClassificationType.__args__):
5152
annots.extend(NDClassification.to_common(annotation))
52-
elif isinstance(annotation, NDScalarMetric):
53+
elif isinstance(annotation,
54+
(NDScalarMetric, NDConfusionMatrixMetric)):
5355
annots.append(NDMetricAnnotation.to_common(annotation))
5456
else:
5557
raise TypeError(
@@ -113,7 +115,7 @@ def _create_non_video_annotations(cls, label: Label):
113115
yield NDClassification.from_common(annotation, label.data)
114116
elif isinstance(annotation, ObjectAnnotation):
115117
yield NDObject.from_common(annotation, label.data)
116-
elif isinstance(annotation, ScalarMetric):
118+
elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)):
117119
yield NDMetricAnnotation.from_common(annotation, label.data)
118120
else:
119121
raise TypeError(
Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,70 @@
1-
from labelbox.data.annotation_types.metrics import ScalarMetricAggregation
2-
from typing import Union, Optional
1+
from typing import Optional, Union, Type
32

43
from labelbox.data.annotation_types.data import ImageData, TextData
5-
from labelbox.data.annotation_types.metrics import ScalarMetric
64
from labelbox.data.serialization.ndjson.base import NDJsonBase
5+
from labelbox.data.annotation_types.metrics.scalar import (
6+
ScalarMetric, ScalarMetricAggregation, ScalarMetricValue,
7+
ScalarMetricConfidenceValue)
8+
from labelbox.data.annotation_types.metrics.confusion_matrix import (
9+
ConfusionMatrixAggregation, ConfusionMatrixMetric,
10+
ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue)
711

812

9-
class NDScalarMetric(NDJsonBase):
13+
class BaseNDMetric(NDJsonBase):
1014
metric_value: float
11-
metric_name: Optional[str]
1215
feature_name: Optional[str] = None
1316
subclass_name: Optional[str] = None
14-
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN.value
17+
18+
class Config:
19+
use_enum_values = True
20+
21+
def dict(self, *args, **kwargs):
22+
res = super().dict(*args, **kwargs)
23+
for field in ['featureName', 'subclassName']:
24+
if res[field] is None:
25+
res.pop(field)
26+
return res
27+
28+
29+
class NDConfusionMatrixMetric(BaseNDMetric):
30+
metric_value: Union[ConfusionMatrixMetricValue,
31+
ConfusionMatrixMetricConfidenceValue]
32+
metric_name: str
33+
aggregation: ConfusionMatrixAggregation
34+
35+
def to_common(self) -> ConfusionMatrixMetric:
36+
return ConfusionMatrixMetric(value=self.metric_value,
37+
metric_name=self.metric_name,
38+
feature_name=self.feature_name,
39+
subclass_name=self.subclass_name,
40+
aggregation=self.aggregation,
41+
extra={'uuid': self.uuid})
42+
43+
@classmethod
44+
def from_common(
45+
cls, metric: ConfusionMatrixMetric,
46+
data: Union[TextData, ImageData]) -> "NDConfusionMatrixMetric":
47+
return cls(uuid=metric.extra.get('uuid'),
48+
metric_value=metric.value,
49+
metric_name=metric.metric_name,
50+
feature_name=metric.feature_name,
51+
subclass_name=metric.subclass_name,
52+
aggregation=metric.aggregation,
53+
data_row={'id': data.uid})
54+
55+
56+
class NDScalarMetric(BaseNDMetric):
57+
metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
58+
metric_name: Optional[str]
59+
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
1560

1661
def to_common(self) -> ScalarMetric:
17-
return ScalarMetric(
18-
value=self.metric_value,
19-
metric_name=self.metric_name,
20-
feature_name=self.feature_name,
21-
subclass_name=self.subclass_name,
22-
aggregation=ScalarMetricAggregation[self.aggregation],
23-
extra={'uuid': self.uuid})
62+
return ScalarMetric(value=self.metric_value,
63+
metric_name=self.metric_name,
64+
feature_name=self.feature_name,
65+
subclass_name=self.subclass_name,
66+
aggregation=self.aggregation,
67+
extra={'uuid': self.uuid})
2468

2569
@classmethod
2670
def from_common(cls, metric: ScalarMetric,
@@ -35,38 +79,39 @@ def from_common(cls, metric: ScalarMetric,
3579

3680
def dict(self, *args, **kwargs):
3781
res = super().dict(*args, **kwargs)
38-
for field in ['featureName', 'subclassName']:
39-
if res[field] is None:
40-
res.pop(field)
41-
4282
# For backwards compatibility.
4383
if res['metricName'] is None:
4484
res.pop('metricName')
4585
res.pop('aggregation')
4686
return res
4787

48-
class Config:
49-
use_enum_values = True
50-
5188

5289
class NDMetricAnnotation:
5390

5491
@classmethod
55-
def to_common(cls, annotation: "NDScalarMetric") -> ScalarMetric:
92+
def to_common(
93+
cls, annotation: Union[NDScalarMetric, NDConfusionMatrixMetric]
94+
) -> Union[ScalarMetric, ConfusionMatrixMetric]:
5695
return annotation.to_common()
5796

5897
@classmethod
59-
def from_common(cls, annotation: ScalarMetric,
60-
data: Union[TextData, ImageData]) -> "NDScalarMetric":
98+
def from_common(
99+
cls, annotation: Union[ScalarMetric,
100+
ConfusionMatrixMetric], data: Union[TextData,
101+
ImageData]
102+
) -> Union[NDScalarMetric, NDConfusionMatrixMetric]:
61103
obj = cls.lookup_object(annotation)
62104
return obj.from_common(annotation, data)
63105

64106
@staticmethod
65-
def lookup_object(metric: ScalarMetric) -> "NDScalarMetric":
107+
def lookup_object(
108+
annotation: Union[ScalarMetric, ConfusionMatrixMetric]
109+
) -> Union[Type[NDScalarMetric], Type[NDConfusionMatrixMetric]]:
66110
result = {
67111
ScalarMetric: NDScalarMetric,
68-
}.get(type(metric))
112+
ConfusionMatrixMetric: NDConfusionMatrixMetric,
113+
}.get(type(annotation))
69114
if result is None:
70115
raise TypeError(
71-
f"Unable to convert object to MAL format. `{type(metric)}`")
116+
f"Unable to convert object to MAL format. `{type(annotation)}`")
72117
return result
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : [1,1,2,3], "metricName" : "50%_iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "CONFUSION_MATRIX"},
2+
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : [0,1,2,5], "metricName" : "50%_iou", "featureName" : "sample_class", "aggregation" : "CONFUSION_MATRIX"},
3+
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : {"0.1" : [0,1,2,3], "0.2" : [5,3,4,3]}, "metricName" : "50%_iou", "aggregation" : "CONFUSION_MATRIX"}]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
22
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "aggregation" : "SUM"},
3-
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "aggregation" : "SUM"}]
3+
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : { "0.1" : 0.1, "0.2" : 0.5}, "metricName" : "iou", "aggregation" : "SUM"}]

tests/data/serialization/ndjson/test_metric.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,29 @@ def test_metric():
1616
list(LBV1Converter.serialize(label_list))
1717

1818

19-
def test_custom_metric():
19+
def test_custom_scalar_metric():
2020
with open('tests/data/assets/ndjson/custom_scalar_import.json',
2121
'r') as file:
2222
data = json.load(file)
2323

2424
label_list = NDJsonConverter.deserialize(data).as_list()
2525
reserialized = list(NDJsonConverter.serialize(label_list))
26-
assert reserialized == data
26+
assert json.dumps(reserialized,
27+
sort_keys=True) == json.dumps(data, sort_keys=True)
28+
29+
# Just make sure that this doesn't break
30+
list(LBV1Converter.serialize(label_list))
31+
32+
33+
def test_custom_confusion_matrix_metric():
34+
with open('tests/data/assets/ndjson/custom_confusion_matrix_import.json',
35+
'r') as file:
36+
data = json.load(file)
37+
38+
label_list = NDJsonConverter.deserialize(data).as_list()
39+
reserialized = list(NDJsonConverter.serialize(label_list))
40+
assert json.dumps(reserialized,
41+
sort_keys=True) == json.dumps(data, sort_keys=True)
2742

2843
# Just make sure that this doesn't break
2944
list(LBV1Converter.serialize(label_list))

0 commit comments

Comments
 (0)