Skip to content

Commit 1254354

Browse files
authored
Merge pull request #265 from Labelbox/ms/custom-metric-models
custom scalar metric
2 parents 1038a42 + d5a8379 commit 1254354

File tree

8 files changed

+85
-12
lines changed

8 files changed

+85
-12
lines changed

labelbox/data/annotation_types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@
2929
from .collection import LabelGenerator
3030

3131
from .metrics import ScalarMetric
32+
from .metrics import CustomScalarMetric
33+
from .metrics import MetricAggregation

labelbox/data/annotation_types/label.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import defaultdict
2+
from labelbox.data.annotation_types.metrics.scalar import CustomScalarMetric
23

34
from typing import Any, Callable, Dict, List, Union, Optional
45

@@ -21,7 +22,8 @@ class Label(BaseModel):
2122
data: Union[VideoData, ImageData, TextData]
2223
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
2324
VideoObjectAnnotation,
24-
VideoClassificationAnnotation, ScalarMetric]] = []
25+
VideoClassificationAnnotation, CustomScalarMetric,
26+
ScalarMetric]] = []
2527
extra: Dict[str, Any] = {}
2628

2729
def object_annotations(self) -> List[ObjectAnnotation]:

labelbox/data/annotation_types/metrics.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .scalar import ScalarMetric, CustomScalarMetric
2+
from .aggregations import MetricAggregation
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from enum import Enum
2+
3+
4+
class MetricAggregation(Enum):
5+
ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
6+
GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
7+
HARMONIC_MEAN = "HARMONIC_MEAN"
8+
SUM = "SUM"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
2+
from typing import Any, Dict, Optional
3+
from pydantic import BaseModel
4+
5+
6+
class ScalarMetric(BaseModel):
7+
""" Class representing metrics """
8+
value: float
9+
extra: Dict[str, Any] = {}
10+
11+
12+
class CustomScalarMetric(BaseModel):
13+
metric_name: str
14+
value: float
15+
feature_name: Optional[str] = None
16+
subclass_name: Optional[str] = None
17+
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
18+
extra: Dict[str, Any] = {}

labelbox/schema/bulk_import_request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def create_from_local_file(cls,
388388

389389
def delete(self) -> None:
390390
""" Deletes the import job and also any annotations created by this import.
391-
392-
Returns:
391+
392+
Returns:
393393
None
394394
"""
395395
id_param = "bulk_request_id"

tests/data/annotation_types/test_metrics.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
2+
from labelbox.data.annotation_types.metrics.scalar import CustomScalarMetric
13
from labelbox.data.annotation_types.collection import LabelList
24
from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
35

6+
import pytest
7+
48

59
def test_scalar_metric():
610
value = 10
@@ -27,3 +31,49 @@ def test_scalar_metric():
2731
}
2832
assert label.dict() == expected
2933
next(LabelList([label])).dict() == expected
34+
35+
36+
@pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
37+
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
38+
("cat", None, MetricAggregation.ARITHMETIC_MEAN),
39+
(None, None, MetricAggregation.ARITHMETIC_MEAN),
40+
(None, None, None),
41+
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
42+
("cat", None, MetricAggregation.HARMONIC_MEAN),
43+
(None, None, MetricAggregation.GEOMETRIC_MEAN),
44+
(None, None, MetricAggregation.SUM),
45+
])
46+
def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
47+
value = 0.5
48+
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
49+
metric = CustomScalarMetric(metric_name="iou",
50+
value=value,
51+
feature_name=feature_name,
52+
subclass_name=subclass_name,
53+
**kwargs)
54+
assert metric.value == value
55+
56+
label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"),
57+
annotations=[metric])
58+
expected = {
59+
'data': {
60+
'external_id': None,
61+
'uid': 'ckrmd9q8g000009mg6vej7hzg',
62+
'im_bytes': None,
63+
'file_path': None,
64+
'url': None,
65+
'arr': None
66+
},
67+
'annotations': [{
68+
'value': value,
69+
'metric_name': 'iou',
70+
'feature_name': feature_name,
71+
'subclass_name': subclass_name,
72+
'aggregation': aggregation or MetricAggregation.ARITHMETIC_MEAN,
73+
'extra': {}
74+
}],
75+
'extra': {},
76+
'uid': None
77+
}
78+
assert label.dict() == expected
79+
next(LabelList([label])).dict() == expected

0 commit comments

Comments
 (0)