
Commit 574780e

Matt Sokoloff committed: debug
1 parent 97f6a71 commit 574780e

File tree: 5 files changed (+116 / -14 lines)

Makefile

Lines changed: 9 additions & 0 deletions
@@ -20,3 +20,12 @@ test-prod: build
 		-e LABELBOX_TEST_ENVIRON="prod" \
 		-e LABELBOX_TEST_API_KEY_PROD=${LABELBOX_TEST_API_KEY_PROD} \
 		local/labelbox-python:test pytest $(PATH_TO_TEST) -svvx
+
+
+
+test-dev: build
+	docker run -it -v ${PWD}:/usr/src -w /usr/src \
+		-e LABELBOX_TEST_ENVIRON="staging" \
+		-e LABELBOX_TEST_API_KEY_PROD=${LABELBOX_TEST_API_KEY_PROD} \
+		local/labelbox-python:test pytest $(PATH_TO_TEST) -svv
+
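The new test-dev target mirrors test-prod but sets LABELBOX_TEST_ENVIRON to "staging", reuses LABELBOX_TEST_API_KEY_PROD, and drops pytest's -x flag so the run continues past the first failure. It would be invoked along the lines of "make test-dev PATH_TO_TEST=tests/integration" with LABELBOX_TEST_API_KEY_PROD exported in the calling shell; the test path here is only an illustrative value.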

labelbox/data/annotation_types/metrics/aggregations.py

Lines changed: 2 additions & 0 deletions
@@ -2,7 +2,9 @@
 
 
 class MetricAggregation(Enum):
+    CONFUSION_MATRIX = "CONFUSION_MATRIX"
     ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
     GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
     HARMONIC_MEAN = "HARMONIC_MEAN"
     SUM = "SUM"
+
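For orientation, a small sketch of how the extended enum can be inspected; the assertion below only restates what the diff above defines (Enum iteration follows definition order):

from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation

# CONFUSION_MATRIX is the new member; the ScalarMetric validator added in this
# commit rejects it, so it is effectively reserved for ConfusionMatrixMetric.
assert [a.name for a in MetricAggregation] == [
    "CONFUSION_MATRIX", "ARITHMETIC_MEAN", "GEOMETRIC_MEAN", "HARMONIC_MEAN", "SUM"
]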
Lines changed: 39 additions & 11 deletions
@@ -1,26 +1,54 @@
 from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
-from typing import Any, Dict, Optional
-from pydantic import BaseModel
+from typing import Any, Dict, Optional, Tuple, Union
+from pydantic import BaseModel, Field, validator, confloat
 
 
-class ScalarMetric(BaseModel):
-    """ Class representing metrics
+ScalarMetricConfidenceValue = Dict[confloat(ge=0, le=1), float]
+ConfusionMatrixMetricConfidenceValue = Dict[confloat(ge=0, le=1), Tuple[int,int,int,int]]
 
-    # For backwards compatibility, metric_name is optional. This will eventually be deprecated
-    # The metric_name will be set to a default name in the editor if it is not set.
 
-    # aggregation will be ignored wihtout providing a metric name.
-    # Not providing a metric name is deprecated.
-    """
-    value: float
+class BaseMetric(BaseModel):
     metric_name: Optional[str] = None
     feature_name: Optional[str] = None
     subclass_name: Optional[str] = None
-    aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
     extra: Dict[str, Any] = {}
 
+
+class ScalarMetric(BaseMetric):
+    """ Class representing scalar metrics
+
+    For backwards compatibility, metric_name is optional.
+    The metric_name will be set to a default name in the editor if it is not set.
+    This is not recommended and support for empty metric_name fields will be removed.
+    aggregation will be ignored wihtout providing a metric name.
+    """
+    value: Union[float, ScalarMetricConfidenceValue]
+    aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
+
     def dict(self, *args, **kwargs):
         res = super().dict(*args, **kwargs)
         if res['metric_name'] is None:
             res.pop('aggregation')
         return {k: v for k, v in res.items() if v is not None}
+
+    @validator('aggregation')
+    def validate_aggregation(cls, aggregation):
+        if aggregation == MetricAggregation.CONFUSION_MATRIX:
+            raise ValueError("Cannot assign `MetricAggregation.CONFUSION_MATRIX` to `ScalarMetric.aggregation`")
+
+
+
+class ConfusionMatrixMetric(BaseMetric):
+    """ Class representing confusion matrix metrics.
+
+    In the editor, this provides precision, recall, and f-scores.
+    This should be used over multiple scalar metrics so that aggregations are accurate.
+
+    value should be a tuple representing:
+    [True Positive Count, False Positive Count, True Negative Count, False Negative Count]
+
+    aggregation cannot be adjusted for confusion matrix metrics.
+    """
+    value: Union[Tuple[int,int,int,int], ConfusionMatrixMetricConfidenceValue]
+    aggregation: MetricAggregation = Field(MetricAggregation.CONFUSION_MATRIX, const = True)
+
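Taken together, these classes accept both plain values and confidence-keyed values. A minimal usage sketch follows; the import path for ScalarMetric and ConfusionMatrixMetric is an assumption (this diff does not name the file that defines them), and all metric names and numbers are illustrative only:

# Sketch only: the package path below is assumed, not shown in this diff.
from labelbox.data.annotation_types import ScalarMetric, ConfusionMatrixMetric
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation

# Plain scalar metric; aggregation defaults to ARITHMETIC_MEAN.
iou = ScalarMetric(metric_name="iou", feature_name="cat", value=0.5)

# Scalar metric keyed by confidence threshold (keys must satisfy confloat(ge=0, le=1)).
iou_by_confidence = ScalarMetric(metric_name="iou",
                                 value={0.25: 0.6, 0.5: 0.55, 0.75: 0.4})

# Confusion matrix metric: (True Positive, False Positive, True Negative, False Negative)
# counts, either as a single tuple or keyed by confidence threshold.
cm = ConfusionMatrixMetric(metric_name="iou_confusion", value=(12, 3, 0, 5))
cm_by_confidence = ConfusionMatrixMetric(metric_name="iou_confusion",
                                         value={0.5: (12, 3, 0, 5), 0.75: (9, 2, 0, 8)})

# The new validator rejects the confusion-matrix aggregation on ScalarMetric
# (pydantic surfaces the ValueError as a ValidationError, a ValueError subclass).
try:
    ScalarMetric(metric_name="iou", value=0.5,
                 aggregation=MetricAggregation.CONFUSION_MATRIX)
except ValueError:
    pass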

tests/data/annotation_types/test_metrics.py

Lines changed: 56 additions & 0 deletions
@@ -33,6 +33,62 @@ def test_legacy_scalar_metric():
     next(LabelList([label])).dict() == expected
 
 
+# TODO: Test with confidence
+
+@pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
+    ("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
+    ("cat", None, MetricAggregation.ARITHMETIC_MEAN),
+    (None, None, MetricAggregation.ARITHMETIC_MEAN),
+    (None, None, None),
+    ("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
+    ("cat", None, MetricAggregation.HARMONIC_MEAN),
+    (None, None, MetricAggregation.GEOMETRIC_MEAN),
+    (None, None, MetricAggregation.SUM)
+])
+def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
+    value = 0.5
+    kwargs = {'aggregation': aggregation} if aggregation is not None else {}
+    metric = ScalarMetric(metric_name="iou",
+                          value=value,
+                          feature_name=feature_name,
+                          subclass_name=subclass_name,
+                          **kwargs)
+    assert metric.value == value
+
+    label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"),
+                  annotations=[metric])
+    expected = {
+        'data': {
+            'external_id': None,
+            'uid': 'ckrmd9q8g000009mg6vej7hzg',
+            'im_bytes': None,
+            'file_path': None,
+            'url': None,
+            'arr': None
+        },
+        'annotations': [{
+            'value':
+                value,
+            'metric_name':
+                'iou',
+            **({
+                'feature_name': feature_name
+            } if feature_name else {}),
+            **({
+                'subclass_name': subclass_name
+            } if subclass_name else {}), 'aggregation':
+                aggregation or MetricAggregation.ARITHMETIC_MEAN,
+            'extra': {}
+        }],
+        'extra': {},
+        'uid': None
+    }
+    assert label.dict() == expected
+    next(LabelList([label])).dict() == expected
+
+
+
+
 @pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
     ("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
     ("cat", None, MetricAggregation.ARITHMETIC_MEAN),

tests/integration/conftest.py

Lines changed: 10 additions & 3 deletions
@@ -47,19 +47,25 @@ def environ() -> Environ:
 
 
 def graphql_url(environ: str) -> str:
+    return "https://app.replicated-6bd9012.labelbox.dev/api/_gql"
+    """
     if environ == Environ.PROD:
         return 'https://api.labelbox.com/graphql'
     elif environ == Environ.STAGING:
         return 'https://staging-api.labelbox.com/graphql'
     return 'http://host.docker.internal:8080/graphql'
+    """
 
 
 def testing_api_key(environ: str) -> str:
+    return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VySWQiOiJja3RhbXFnN2MwMDAxMHljaTAxZDFhaGVqIiwib3JnYW5pemF0aW9uSWQiOiJja3RhbXFnNm4wMDAwMHljaTgwaDI1NXZ3IiwiYXBpS2V5SWQiOiJja3RhbXZmdm4wMDhrMHljaTh1MG40a2hzIiwic2VjcmV0IjoiNzcxOWViMjgyNjUyMWMxODQ5MmJhMjg1NzhmY2FmNDEiLCJpYXQiOjE2MzEwNTMwNDksImV4cCI6MjI2MjIwNTA0OX0.yBLurIRB3xYQkV8MEBm0_LxmdqP9U-8aMj25kASmGLw"
+    """
    if environ == Environ.PROD:
         return os.environ["LABELBOX_TEST_API_KEY_PROD"]
     elif environ == Environ.STAGING:
         return os.environ["LABELBOX_TEST_API_KEY_STAGING"]
     return os.environ["LABELBOX_TEST_API_KEY_LOCAL"]
+    """
 
 
 def cancel_invite(client, invite_id):

@@ -135,9 +141,10 @@ def client(environ: str):
 
 @pytest.fixture(scope="session")
 def image_url(client, environ: str):
-    if environ == Environ.LOCAL:
-        return IMG_URL
-    return client.upload_data(requests.get(IMG_URL).content, sign=True)
+    return IMG_URL
+    #if environ == Environ.LOCAL:
+    #    return IMG_URL
+    #return client.upload_data(requests.get(IMG_URL).content, sign=True)
 
 
 @pytest.fixture
