Skip to content

Commit 3bb2aaf

Browse files
authored
Merge pull request #285 from Labelbox/ms/compute-conf-matrix
confusion matrix metrics
2 parents c18c0f4 + 183ba64 commit 3bb2aaf

File tree

25 files changed

+1243
-161
lines changed

25 files changed

+1243
-161
lines changed

examples/model_diagnostics/custom-metrics.ipynb

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@
164164
"metadata": {},
165165
"outputs": [],
166166
"source": [
167-
"from labelbox.data.annotation_types import ScalarMetric, MetricAggregation"
167+
"from labelbox.data.annotation_types import ScalarMetric, ScalarMetricAggregation"
168168
]
169169
},
170170
{
@@ -226,14 +226,14 @@
226226
" metric_name = \"true_positives\",\n",
227227
" feature_name = \"cat\",\n",
228228
" value = 3,\n",
229-
" aggregation = MetricAggregation.SUM\n",
229+
" aggregation = ScalarMetricAggregation.SUM\n",
230230
")\n",
231231
"\n",
232232
"feature_metric = ScalarMetric(\n",
233233
" metric_name = \"true_positives\",\n",
234234
" feature_name = \"dog\",\n",
235235
" value = 4,\n",
236-
" aggregation = MetricAggregation.SUM\n",
236+
" aggregation = ScalarMetricAggregation.SUM\n",
237237
")\n"
238238
]
239239
},

labelbox/data/annotation_types/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,4 +29,8 @@
2929
from .collection import LabelGenerator
3030

3131
from .metrics import ScalarMetric
32-
from .metrics import MetricAggregation
32+
from .metrics import ScalarMetricAggregation
33+
from .metrics import ConfusionMatrixMetric
34+
from .metrics import ConfusionMatrixAggregation
35+
from .metrics import ScalarMetricValue
36+
from .metrics import ConfusionMatrixMetricValue

labelbox/data/annotation_types/label.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from .classification import ClassificationAnswer
1212
from .data import VideoData, TextData, ImageData
1313
from .geometry import Mask
14-
from .metrics import ScalarMetric
14+
from .metrics import ScalarMetric, ConfusionMatrixMetric
1515
from .types import Cuid
1616
from .annotation import (ClassificationAnnotation, ObjectAnnotation,
1717
VideoClassificationAnnotation, VideoObjectAnnotation)
@@ -23,7 +23,7 @@ class Label(BaseModel):
2323
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
2424
VideoObjectAnnotation,
2525
VideoClassificationAnnotation, ScalarMetric,
26-
ScalarMetric]] = []
26+
ConfusionMatrixMetric]] = []
2727
extra: Dict[str, Any] = {}
2828

2929
def object_annotations(self) -> List[ObjectAnnotation]:
Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,2 @@
1-
from .scalar import ScalarMetric
2-
from .aggregations import MetricAggregation
1+
from .scalar import ScalarMetric, ScalarMetricAggregation, ScalarMetricValue
2+
from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrixAggregation, ConfusionMatrixMetricValue

labelbox/data/annotation_types/metrics/aggregations.py

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC
2+
from pydantic import ValidationError, confloat, BaseModel, validator
3+
from pydantic.error_wrappers import ErrorWrapper
4+
from typing import Dict, Optional, Any, Union
5+
6+
ConfidenceValue = confloat(ge=0, le=1)
7+
8+
MIN_CONFIDENCE_SCORES = 2
9+
MAX_CONFIDENCE_SCORES = 15
10+
11+
12+
class BaseMetric(BaseModel, ABC):
13+
value: Union[Any, Dict[float, Any]]
14+
feature_name: Optional[str] = None
15+
subclass_name: Optional[str] = None
16+
extra: Dict[str, Any] = {}
17+
18+
def dict(self, *args, **kwargs):
19+
res = super().dict(*args, **kwargs)
20+
return {k: v for k, v in res.items() if v is not None}
21+
22+
@validator('value')
23+
def validate_value(cls, value):
24+
if isinstance(value, Dict):
25+
if not (MIN_CONFIDENCE_SCORES <= len(value) <=
26+
MAX_CONFIDENCE_SCORES):
27+
raise ValidationError([
28+
ErrorWrapper(ValueError(
29+
"Number of confidence scores must be greater"
30+
f" than or equal to {MIN_CONFIDENCE_SCORES} and"
31+
f" less than or equal to {MAX_CONFIDENCE_SCORES}. Found {len(value)}"
32+
),
33+
loc='value')
34+
], cls)
35+
return value
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
from enum import Enum
2+
from typing import Tuple, Dict, Union
3+
4+
from pydantic import conint, Field
5+
from pydantic.main import BaseModel
6+
7+
from .base import ConfidenceValue, BaseMetric
8+
9+
Count = conint(ge=0, le=1e10)
10+
11+
ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count]
12+
ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue,
13+
ConfusionMatrixMetricValue]
14+
15+
16+
class ConfusionMatrixAggregation(Enum):
17+
CONFUSION_MATRIX = "CONFUSION_MATRIX"
18+
19+
20+
class ConfusionMatrixMetric(BaseMetric):
21+
""" Class representing confusion matrix metrics.
22+
23+
In the editor, this provides precision, recall, and f-scores.
24+
This should be used over multiple scalar metrics so that aggregations are accurate.
25+
26+
Value should be a tuple representing:
27+
[True Positive Count, False Positive Count, True Negative Count, False Negative Count]
28+
29+
aggregation cannot be adjusted for confusion matrix metrics.
30+
"""
31+
metric_name: str
32+
value: Union[ConfusionMatrixMetricValue,
33+
ConfusionMatrixMetricConfidenceValue]
34+
aggregation: ConfusionMatrixAggregation = Field(
35+
ConfusionMatrixAggregation.CONFUSION_MATRIX, const=True)
Lines changed: 25 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,26 +1,35 @@
1-
from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
2-
from typing import Any, Dict, Optional
3-
from pydantic import BaseModel
1+
from typing import Dict, Optional, Union
2+
from enum import Enum
43

4+
from pydantic import confloat
55

6-
class ScalarMetric(BaseModel):
7-
""" Class representing metrics
6+
from .base import ConfidenceValue, BaseMetric
87

9-
# For backwards compatibility, metric_name is optional. This will eventually be deprecated
10-
# The metric_name will be set to a default name in the editor if it is not set.
8+
ScalarMetricValue = confloat(ge=0, le=10_000)
9+
ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue]
1110

12-
# aggregation will be ignored wihtout providing a metric name.
13-
# Not providing a metric name is deprecated.
11+
12+
class ScalarMetricAggregation(Enum):
13+
ARITHMETIC_MEAN = "ARITHMETIC_MEAN"
14+
GEOMETRIC_MEAN = "GEOMETRIC_MEAN"
15+
HARMONIC_MEAN = "HARMONIC_MEAN"
16+
SUM = "SUM"
17+
18+
19+
class ScalarMetric(BaseMetric):
20+
""" Class representing scalar metrics
21+
22+
For backwards compatibility, metric_name is optional.
23+
The metric_name will be set to a default name in the editor if it is not set.
24+
This is not recommended and support for empty metric_name fields will be removed.
25+
aggregation will be ignored without providing a metric name.
1426
"""
15-
value: float
1627
metric_name: Optional[str] = None
17-
feature_name: Optional[str] = None
18-
subclass_name: Optional[str] = None
19-
aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
20-
extra: Dict[str, Any] = {}
28+
value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
29+
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
2130

2231
def dict(self, *args, **kwargs):
2332
res = super().dict(*args, **kwargs)
24-
if res['metric_name'] is None:
33+
if res.get('metric_name') is None:
2534
res.pop('aggregation')
26-
return {k: v for k, v in res.items() if v is not None}
35+
return res

labelbox/data/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
from .confusion_matrix import confusion_matrix_metric, feature_confusion_matrix_metric
2+
from .iou import miou_metric, feature_miou_metric
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
from .calculation import *
2+
from .confusion_matrix import *

0 commit comments

Comments
 (0)