Skip to content

Commit d5a8379

Browse files
author
Matt Sokoloff
committed
add custom scalar metric
1 parent f30d572 commit d5a8379

File tree

7 files changed

+53
-16
lines changed

7 files changed

+53
-16
lines changed

labelbox/data/annotation_types/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,3 @@
3131
from .metrics import ScalarMetric
3232
from .metrics import CustomScalarMetric
3333
from .metrics import MetricAggregation
34-

labelbox/data/annotation_types/label.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ class Label(BaseModel):
2222
data: Union[VideoData, ImageData, TextData]
2323
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
2424
VideoObjectAnnotation,
25-
VideoClassificationAnnotation, ScalarMetric, CustomScalarMetric]] = []
25+
VideoClassificationAnnotation, CustomScalarMetric,
26+
ScalarMetric]] = []
2627
extra: Dict[str, Any] = {}
2728

2829
def object_annotations(self) -> List[ObjectAnnotation]:
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .scalar import ScalarMetric, CustomScalarMetric
2+
from .aggregations import MetricAggregation
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from enum import Enum
2+
3+
4+
class MetricAggregation(Enum):
5+
ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
6+
GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
7+
HARMONIC_MEAN = "HARMONIC_MEAN"
8+
SUM = "SUM"
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
2+
from typing import Any, Dict, Optional
3+
from pydantic import BaseModel
4+
5+
6+
class ScalarMetric(BaseModel):
7+
""" Class representing metrics """
8+
value: float
9+
extra: Dict[str, Any] = {}
10+
11+
12+
class CustomScalarMetric(BaseModel):
13+
metric_name: str
14+
value: float
15+
feature_name: Optional[str] = None
16+
subclass_name: Optional[str] = None
17+
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
18+
extra: Dict[str, Any] = {}

labelbox/schema/bulk_import_request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ def determinants(parent_cls) -> List[str]:
635635

636636
###### Classifications ######
637637

638+
638639
class NDText(NDBase):
639640
ontology_type: Literal["text"] = "text"
640641
answer: str = pydantic.Field(determinant=True)

tests/data/annotation_types/test_metrics.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77

8+
89
def test_scalar_metric():
910
value = 10
1011
metric = ScalarMetric(value=value)
@@ -32,17 +33,24 @@ def test_scalar_metric():
3233
next(LabelList([label])).dict() == expected
3334

3435

35-
@pytest.mark.parametrize(
36-
'feature_name,subclass_name,aggregation',
37-
[
38-
("cat", "orange" , MetricAggregation.ARITHMETIC_MEAN),
39-
("cat", None, MetricAggregation.ARITHMETIC_MEAN),
40-
(None, None, MetricAggregation.ARITHMETIC_MEAN),
41-
(None, None, None),
42-
])
36+
@pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
37+
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
38+
("cat", None, MetricAggregation.ARITHMETIC_MEAN),
39+
(None, None, MetricAggregation.ARITHMETIC_MEAN),
40+
(None, None, None),
41+
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
42+
("cat", None, MetricAggregation.HARMONIC_MEAN),
43+
(None, None, MetricAggregation.GEOMETRIC_MEAN),
44+
(None, None, MetricAggregation.SUM),
45+
])
4346
def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
4447
value = 0.5
45-
metric = CustomScalarMetric(metric_name = "iou", value=value, feature_name=feature_name, subclass_name = subclass_name, aggregation = aggregation)
48+
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
49+
metric = CustomScalarMetric(metric_name="iou",
50+
value=value,
51+
feature_name=feature_name,
52+
subclass_name=subclass_name,
53+
**kwargs)
4654
assert metric.value == value
4755

4856
label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"),
@@ -57,15 +65,15 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
5765
'arr': None
5866
},
5967
'annotations': [{
60-
'value': 10.0,
61-
68+
'value': value,
69+
'metric_name': 'iou',
70+
'feature_name': feature_name,
71+
'subclass_name': subclass_name,
72+
'aggregation': aggregation or MetricAggregation.ARITHMETIC_MEAN,
6273
'extra': {}
6374
}],
6475
'extra': {},
6576
'uid': None
6677
}
6778
assert label.dict() == expected
6879
next(LabelList([label])).dict() == expected
69-
70-
71-

0 commit comments

Comments
 (0)