Skip to content

Commit 2a49913

Browse files
author
Matt Sokoloff
committed
add confusion matrix metric and confidence
1 parent 23b9375 commit 2a49913

File tree

9 files changed

+119
-103
lines changed

9 files changed

+119
-103
lines changed

examples/model_diagnostics/custom-metrics.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
"metadata": {},
165165
"outputs": [],
166166
"source": [
167-
"from labelbox.data.annotation_types import ScalarMetric, MetricAggregation"
167+
"from labelbox.data.annotation_types import ScalarMetric, ScalarMetricAggregation"
168168
]
169169
},
170170
{
@@ -226,14 +226,14 @@
226226
" metric_name = \"true_positives\",\n",
227227
" feature_name = \"cat\",\n",
228228
" value = 3,\n",
229-
" aggregation = MetricAggregation.SUM\n",
229+
" aggregation = ScalarMetricAggregation.SUM\n",
230230
")\n",
231231
"\n",
232232
"feature_metric = ScalarMetric(\n",
233233
" metric_name = \"true_positives\",\n",
234234
" feature_name = \"dog\",\n",
235235
" value = 4,\n",
236-
" aggregation = MetricAggregation.SUM\n",
236+
" aggregation = ScalarMetricAggregation.SUM\n",
237237
")\n"
238238
]
239239
},

labelbox/data/annotation_types/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,6 @@
2929
from .collection import LabelGenerator
3030

3131
from .metrics import ScalarMetric
32-
from .metrics import MetricAggregation
32+
from .metrics import ScalarMetricAggregation
33+
from .metrics import ConfusionMatrixMetric
34+
from .metrics import ConfusionMatrixAggregation

labelbox/data/annotation_types/label.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .classification import ClassificationAnswer
1212
from .data import VideoData, TextData, ImageData
1313
from .geometry import Mask
14-
from .metrics import ScalarMetric
14+
from .metrics import ScalarMetric, ConfusionMatrixMetric
1515
from .types import Cuid
1616
from .annotation import (ClassificationAnnotation, ObjectAnnotation,
1717
VideoClassificationAnnotation, VideoObjectAnnotation)
@@ -23,7 +23,7 @@ class Label(BaseModel):
2323
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
2424
VideoObjectAnnotation,
2525
VideoClassificationAnnotation, ScalarMetric,
26-
ScalarMetric]] = []
26+
ConfusionMatrixMetric]] = []
2727
extra: Dict[str, Any] = {}
2828

2929
def object_annotations(self) -> List[ObjectAnnotation]:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .scalar import ScalarMetric
2-
from .aggregations import MetricAggregation
1+
from .scalar import ScalarMetric, ScalarMetricAggregation
2+
from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrixAggregation

labelbox/data/annotation_types/metrics/aggregations.py

Lines changed: 0 additions & 10 deletions
This file was deleted.
Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
2-
from typing import Any, Dict, Optional, Tuple, Union
3-
from pydantic import BaseModel, Field, validator, confloat
1+
from typing import Dict, Optional, Union
2+
from enum import Enum
43

4+
from pydantic import confloat
55

6-
ScalarMetricConfidenceValue = Dict[confloat(ge=0, le=1), float]
7-
ConfusionMatrixMetricConfidenceValue = Dict[confloat(ge=0, le=1), Tuple[int,int,int,int]]
6+
from .base import ConfidenceValue, BaseMetric
87

8+
ScalarMetricValue = confloat(ge=0, le=10_000)
9+
ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue]
910

10-
class BaseMetric(BaseModel):
11-
metric_name: Optional[str] = None
12-
feature_name: Optional[str] = None
13-
subclass_name: Optional[str] = None
14-
extra: Dict[str, Any] = {}
11+
12+
class ScalarMetricAggregation(Enum):
13+
ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
14+
GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
15+
HARMONIC_MEAN = "HARMONIC_MEAN"
16+
SUM = "SUM"
1517

1618

1719
class ScalarMetric(BaseMetric):
@@ -22,33 +24,12 @@ class ScalarMetric(BaseMetric):
2224
This is not recommended and support for empty metric_name fields will be removed.
2325
aggregation will be ignored without providing a metric name.
2426
"""
27+
metric_name: Optional[str] = None
2528
value: Union[float, ScalarMetricConfidenceValue]
26-
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
29+
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
2730

2831
def dict(self, *args, **kwargs):
2932
res = super().dict(*args, **kwargs)
30-
if res['metric_name'] is None:
33+
if res.get('metric_name') is None:
3134
res.pop('aggregation')
32-
return {k: v for k, v in res.items() if v is not None}
33-
34-
@validator('aggregation')
35-
def validate_aggregation(cls, aggregation):
36-
if aggregation == MetricAggregation.CONFUSION_MATRIX:
37-
raise ValueError("Cannot assign `MetricAggregation.CONFUSION_MATRIX` to `ScalarMetric.aggregation`")
38-
39-
40-
41-
class ConfusionMatrixMetric(BaseMetric):
42-
""" Class representing confusion matrix metrics.
43-
44-
In the editor, this provides precision, recall, and f-scores.
45-
This should be used over multiple scalar metrics so that aggregations are accurate.
46-
47-
value should be a tuple representing:
48-
[True Positive Count, False Positive Count, True Negative Count, False Negative Count]
49-
50-
aggregation cannot be adjusted for confusion matrix metrics.
51-
"""
52-
value: Union[Tuple[int,int,int,int], ConfusionMatrixMetricConfidenceValue]
53-
aggregation: MetricAggregation = Field(MetricAggregation.CONFUSION_MATRIX, const = True)
54-
35+
return res

labelbox/data/serialization/ndjson/metric.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
1+
from labelbox.data.annotation_types.metrics import ScalarMetricAggregation
22
from typing import Union, Optional
33

44
from labelbox.data.annotation_types.data import ImageData, TextData
@@ -11,15 +11,16 @@ class NDScalarMetric(NDJsonBase):
1111
metric_name: Optional[str]
1212
feature_name: Optional[str] = None
1313
subclass_name: Optional[str] = None
14-
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN.value
14+
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN.value
1515

1616
def to_common(self) -> ScalarMetric:
17-
return ScalarMetric(value=self.metric_value,
18-
metric_name=self.metric_name,
19-
feature_name=self.feature_name,
20-
subclass_name=self.subclass_name,
21-
aggregation=MetricAggregation[self.aggregation],
22-
extra={'uuid': self.uuid})
17+
return ScalarMetric(
18+
value=self.metric_value,
19+
metric_name=self.metric_name,
20+
feature_name=self.feature_name,
21+
subclass_name=self.subclass_name,
22+
aggregation=ScalarMetricAggregation[self.aggregation],
23+
extra={'uuid': self.uuid})
2324

2425
@classmethod
2526
def from_common(cls, metric: ScalarMetric,

tests/data/annotation_types/test_metrics.py

Lines changed: 84 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from pydantic import ValidationError
12
import pytest
23

3-
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
4-
from labelbox.data.annotation_types.metrics.scalar import ScalarMetric
4+
from labelbox.data.annotation_types.metrics import ConfusionMatrixAggregation, ScalarMetricAggregation
5+
from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric
56
from labelbox.data.annotation_types.collection import LabelList
67
from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
78

@@ -30,23 +31,28 @@ def test_legacy_scalar_metric():
3031
'uid': None
3132
}
3233
assert label.dict() == expected
33-
next(LabelList([label])).dict() == expected
34+
assert next(LabelList([label])).dict() == expected
3435

3536

3637
# TODO: Test with confidence
3738

38-
@pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
39-
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
40-
("cat", None, MetricAggregation.ARITHMETIC_MEAN),
41-
(None, None, MetricAggregation.ARITHMETIC_MEAN),
42-
(None, None, None),
43-
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
44-
("cat", None, MetricAggregation.HARMONIC_MEAN),
45-
(None, None, MetricAggregation.GEOMETRIC_MEAN),
46-
(None, None, MetricAggregation.SUM)
39+
40+
@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [
41+
("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5),
42+
("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5),
43+
(None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5),
44+
(None, None, None, 0.5),
45+
("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5),
46+
("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5),
47+
(None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5),
48+
(None, None, ScalarMetricAggregation.SUM, 0.5),
49+
("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, {
50+
0.1: 0.2,
51+
0.3: 0.5,
52+
0.4: 0.8
53+
}),
4754
])
48-
def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
49-
value = 0.5
55+
def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
5056
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
5157
metric = ScalarMetric(metric_name="iou",
5258
value=value,
@@ -77,36 +83,37 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
7783
**({
7884
'subclass_name': subclass_name
7985
} if subclass_name else {}), 'aggregation':
80-
aggregation or MetricAggregation.ARITHMETIC_MEAN,
86+
aggregation or ScalarMetricAggregation.ARITHMETIC_MEAN,
8187
'extra': {}
8288
}],
8389
'extra': {},
8490
'uid': None
8591
}
86-
assert label.dict() == expected
87-
next(LabelList([label])).dict() == expected
88-
8992

93+
assert label.dict() == expected
94+
assert next(LabelList([label])).dict() == expected
9095

9196

92-
@pytest.mark.parametrize('feature_name,subclass_name,aggregation', [
93-
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
94-
("cat", None, MetricAggregation.ARITHMETIC_MEAN),
95-
(None, None, MetricAggregation.ARITHMETIC_MEAN),
96-
(None, None, None),
97-
("cat", "orange", MetricAggregation.ARITHMETIC_MEAN),
98-
("cat", None, MetricAggregation.HARMONIC_MEAN),
99-
(None, None, MetricAggregation.GEOMETRIC_MEAN),
100-
(None, None, MetricAggregation.SUM),
97+
@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [
98+
("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX,
99+
(0, 1, 2, 3)),
100+
("cat", None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)),
101+
(None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)),
102+
(None, None, None, (0, 1, 2, 3)),
103+
("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, {
104+
0.1: (0, 1, 2, 3),
105+
0.3: (0, 1, 2, 3),
106+
0.4: (0, 1, 2, 3)
107+
}),
101108
])
102-
def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
103-
value = 0.5
109+
def test_custom_confusion_matrix_metric(feature_name, subclass_name,
110+
aggregation, value):
104111
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
105-
metric = ScalarMetric(metric_name="iou",
106-
value=value,
107-
feature_name=feature_name,
108-
subclass_name=subclass_name,
109-
**kwargs)
112+
metric = ConfusionMatrixMetric(metric_name="confusion_matrix_50_pct_iou",
113+
value=value,
114+
feature_name=feature_name,
115+
subclass_name=subclass_name,
116+
**kwargs)
110117
assert metric.value == value
111118

112119
label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"),
@@ -124,18 +131,58 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
124131
'value':
125132
value,
126133
'metric_name':
127-
'iou',
134+
'confusion_matrix_50_pct_iou',
128135
**({
129136
'feature_name': feature_name
130137
} if feature_name else {}),
131138
**({
132139
'subclass_name': subclass_name
133140
} if subclass_name else {}), 'aggregation':
134-
aggregation or MetricAggregation.ARITHMETIC_MEAN,
141+
aggregation or ConfusionMatrixAggregation.CONFUSION_MATRIX,
135142
'extra': {}
136143
}],
137144
'extra': {},
138145
'uid': None
139146
}
140147
assert label.dict() == expected
141-
next(LabelList([label])).dict() == expected
148+
assert next(LabelList([label])).dict() == expected
149+
150+
151+
def test_name_exists():
152+
# Name is only required for ConfusionMatrixMetric for now.
153+
with pytest.raises(ValidationError) as exc_info:
154+
metric = ConfusionMatrixMetric(value=[0, 1, 2, 3])
155+
assert "field required (type=value_error.missing)" in str(exc_info.value)
156+
157+
158+
def test_invalid_aggregations():
159+
with pytest.raises(ValidationError) as exc_info:
160+
metric = ScalarMetric(
161+
metric_name="invalid aggregation",
162+
value=0.1,
163+
aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX)
164+
assert "value is not a valid enumeration member" in str(exc_info.value)
165+
with pytest.raises(ValidationError) as exc_info:
166+
metric = ConfusionMatrixMetric(metric_name="invalid aggregation",
167+
value=[0, 1, 2, 3],
168+
aggregation=ScalarMetricAggregation.SUM)
169+
assert "value is not a valid enumeration member" in str(exc_info.value)
170+
171+
172+
def test_invalid_number_of_confidence_scores():
173+
with pytest.raises(ValidationError) as exc_info:
174+
metric = ScalarMetric(metric_name="too few scores", value={0.1: 0.1})
175+
assert "Number of confidence scores must be greater" in str(exc_info.value)
176+
with pytest.raises(ValidationError) as exc_info:
177+
metric = ConfusionMatrixMetric(metric_name="too few scores",
178+
value={0.1: [0, 1, 2, 3]})
179+
assert "Number of confidence scores must be greater" in str(exc_info.value)
180+
with pytest.raises(ValidationError) as exc_info:
181+
metric = ScalarMetric(metric_name="too many scores",
182+
value={i / 20.: 0.1 for i in range(20)})
183+
assert "Number of confidence scores must be greater" in str(exc_info.value)
184+
with pytest.raises(ValidationError) as exc_info:
185+
metric = ConfusionMatrixMetric(
186+
metric_name="too many scores",
187+
value={i / 20.: [0, 1, 2, 3] for i in range(20)})
188+
assert "Number of confidence scores must be greater" in str(exc_info.value)

tests/integration/conftest.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,14 +212,9 @@ def datarow(dataset, image_url):
212212
@pytest.fixture
213213
def label_pack(project, rand_gen, image_url):
214214
client = project.client
215-
<<<<<<< HEAD
216215
dataset = client.create_dataset(name=rand_gen(str))
217216
project.datasets.connect(dataset)
218217
data_row = dataset.create_data_row(row_data=IMG_URL)
219-
=======
220-
dataset = client.create_dataset(name=rand_gen(str), projects=project)
221-
data_row = dataset.create_data_row(row_data=image_url)
222-
>>>>>>> 6970d60beebc6c969a81c891b4c88db7c57f98df
223218
label = project.create_label(data_row=data_row, label=rand_gen(str))
224219
time.sleep(10)
225220
yield LabelPack(project, dataset, data_row, label)

0 commit comments

Comments
 (0)