Skip to content

Commit 5d57910

Browse files
author
Matt Sokoloff
committed
tests passing
1 parent e75e5d8 commit 5d57910

File tree

7 files changed

+62
-23
lines changed

7 files changed

+62
-23
lines changed

labelbox/data/annotation_types/metrics/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class ScalarMetric(BaseModel):
1111

1212
class CustomScalarMetric(BaseModel):
1313
metric_name: str
14-
value: float
14+
metric_value: float
1515
feature_name: Optional[str] = None
1616
subclass_name: Optional[str] = None
1717
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN

labelbox/data/serialization/ndjson/converter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
2020
LabelGenerator containing the ndjson data.
2121
"""
2222
data = NDLabel(**{'annotations': json_data})
23-
return data.to_common()
23+
res = data.to_common()
24+
return res
2425

2526
@staticmethod
2627
def serialize(

labelbox/data/serialization/ndjson/label.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import groupby
2+
from labelbox.data.annotation_types.metrics.scalar import CustomScalarMetric
23
from labelbox.data.annotation_types.metrics import ScalarMetric
34

45
from operator import itemgetter
@@ -12,7 +13,7 @@
1213
from ...annotation_types.data import ImageData, TextData, VideoData
1314
from ...annotation_types.label import Label
1415
from ...annotation_types.ner import TextEntity
15-
from .metric import NDMetricAnnotation, NDMetricType
16+
from .metric import NDCustomScalarMetric, NDMetricAnnotation, NDMetricType
1617
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass
1718
from .objects import NDObject, NDObjectType
1819

@@ -39,14 +40,16 @@ def _generate_annotations(
3940
NDClassificationType,
4041
NDMetricType]]]
4142
) -> Generator[Label, None, None]:
43+
4244
for data_row_id, annotations in grouped_annotations.items():
4345
annots = []
4446
for annotation in annotations:
4547
if isinstance(annotation, NDObjectType.__args__):
4648
annots.append(NDObject.to_common(annotation))
4749
elif isinstance(annotation, NDClassificationType.__args__):
4850
annots.extend(NDClassification.to_common(annotation))
49-
elif isinstance(annotation, NDMetricType):
51+
elif isinstance(annotation, NDMetricType.__args__):
52+
5053
annots.append(NDMetricAnnotation.to_common(annotation))
5154
else:
5255
raise TypeError(
@@ -105,9 +108,9 @@ def _create_non_video_annotations(cls, label: Label):
105108
yield NDClassification.from_common(annotation, label.data)
106109
elif isinstance(annotation, ObjectAnnotation):
107110
yield NDObject.from_common(annotation, label.data)
108-
elif isinstance(annotation, ScalarMetric):
111+
elif isinstance(annotation, (ScalarMetric, CustomScalarMetric)):
109112
yield NDMetricAnnotation.from_common(annotation, label.data)
110113
else:
111114
raise TypeError(
112-
f"Unable to convert object to MAL format. `{type(annotation.value)}`"
115+
f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`"
113116
)

labelbox/data/serialization/ndjson/metric.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
12
from labelbox.data.annotation_types.metrics.scalar import CustomScalarMetric
2-
from typing import Literal, Union
3-
4-
from fiftyone.core.collections import aggregation
3+
from typing import Union, Optional
54

65
from labelbox.data.annotation_types.data import ImageData, TextData
76
from labelbox.data.annotation_types.metrics import ScalarMetric
@@ -24,34 +23,54 @@ def from_common(
2423

2524

2625
class NDCustomScalarMetric(NDJsonBase):
27-
metric_name: float
26+
metric_name: str
2827
metric_value: float
29-
sublcass_name: float
30-
aggregation: Union[Literal["ARITHMETIC_MEAN"], Literal["GEOMETRIC_MEAN"],
31-
Literal["HARMONIC_MEAN"], Literal["SUM"]]
28+
feature_name: Optional[str] = None
29+
subclass_name: Optional[str] = None
30+
aggregation: MetricAggregation
3231

3332
def to_common(self) -> CustomScalarMetric:
34-
return ScalarMetric(value=self.metric_value, extra={'uuid': self.uuid})
33+
return CustomScalarMetric(
34+
metric_value=self.metric_value,
35+
metric_name=self.metric_name,
36+
feature_name=self.feature_name,
37+
subclass_name=self.subclass_name,
38+
aggregation=MetricAggregation[self.aggregation],
39+
extra={'uuid': self.uuid})
3540

3641
@classmethod
37-
def from_common(cls, metric: ScalarMetric,
42+
def from_common(cls, metric: CustomScalarMetric,
3843
data: Union[TextData, ImageData]) -> "NDCustomScalarMetric":
3944
return NDCustomScalarMetric(uuid=metric.extra.get('uuid'),
40-
metric_value=metric.value,
45+
metric_value=metric.metric_value,
46+
metric_name=metric.metric_name,
47+
feature_name=metric.feature_name,
48+
subclass_name=metric.subclass_name,
49+
aggregation=metric.aggregation.value,
4150
data_row={'id': data.uid})
4251

52+
def dict(self, *args, **kwargs):
53+
res = super().dict(*args, **kwargs)
54+
for field in ['featureName', 'subclassName']:
55+
if res[field] is None:
56+
res.pop(field)
57+
return res
58+
59+
class Config:
60+
use_enum_values = True
61+
4362

4463
class NDMetricAnnotation:
4564

4665
@classmethod
4766
def to_common(cls, annotation: "NDMetricType") -> ScalarMetric:
48-
4967
return annotation.to_common()
5068

5169
@classmethod
5270
def from_common(cls, annotation: Union[ScalarMetric, CustomScalarMetric],
5371
data: Union[TextData, ImageData]) -> "NDMetricType":
54-
return NDDataRowScalarMetric.from_common(annotation, data)
72+
obj = cls.lookup_object(annotation)
73+
return obj.from_common(annotation, data)
5574

5675
@staticmethod
5776
def lookup_object(
@@ -66,4 +85,4 @@ def lookup_object(
6685
return result
6786

6887

69-
NDMetricType = Union[NDDataRowScalarMetric, NDCustomScalarMetric]
88+
NDMetricType = Union[NDCustomScalarMetric, NDDataRowScalarMetric]

tests/data/annotation_types/test_metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
4747
value = 0.5
4848
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
4949
metric = CustomScalarMetric(metric_name="iou",
50-
value=value,
50+
metric_value=value,
5151
feature_name=feature_name,
5252
subclass_name=subclass_name,
5353
**kwargs)
54-
assert metric.value == value
54+
assert metric.metric_value == value
5555

5656
label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"),
5757
annotations=[metric])
@@ -65,7 +65,7 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
6565
'arr': None
6666
},
6767
'annotations': [{
68-
'value': value,
68+
'metric_value': value,
6969
'metric_name': 'iou',
7070
'feature_name': feature_name,
7171
'subclass_name': subclass_name,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
2+
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "aggregation" : "SUM"},
3+
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "aggregation" : "SUM"}]

tests/data/serialization/ndjson/test_metric.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
2-
from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter
32

3+
from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter
44
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
55

66

@@ -14,3 +14,16 @@ def test_metric():
1414

1515
# Just make sure that this doesn't break
1616
list(LBV1Converter.serialize(label_list))
17+
18+
19+
def test_metric():
20+
with open('tests/data/assets/ndjson/custom_scalar_import.json',
21+
'r') as file:
22+
data = json.load(file)
23+
24+
label_list = NDJsonConverter.deserialize(data).as_list()
25+
reserialized = list(NDJsonConverter.serialize(label_list))
26+
assert reserialized == data
27+
28+
# Just make sure that this doesn't break
29+
list(LBV1Converter.serialize(label_list))

0 commit comments

Comments
 (0)