Skip to content

Commit 045e7ea

Browse files
author
Matt Sokoloff
committed
add missing files
1 parent 2a49913 commit 045e7ea

File tree

3 files changed

+69
-1
lines changed

3 files changed

+69
-1
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC
2+
from pydantic import ValidationError, confloat, BaseModel, validator
3+
from pydantic.error_wrappers import ErrorWrapper
4+
from typing import Dict, Optional, Any, Union
5+
6+
ConfidenceValue = confloat(ge=0, le=1)
7+
8+
MIN_CONFIDENCE_SCORES = 2
9+
MAX_CONFIDENCE_SCORES = 15
10+
11+
12+
class BaseMetric(BaseModel, ABC):
13+
value: Union[Any, Dict[float, Any]]
14+
feature_name: Optional[str] = None
15+
subclass_name: Optional[str] = None
16+
extra: Dict[str, Any] = {}
17+
18+
def dict(self, *args, **kwargs):
19+
res = super().dict(*args, **kwargs)
20+
return {k: v for k, v in res.items() if v is not None}
21+
22+
@validator('value')
23+
def validate_value(cls, value):
24+
if isinstance(value, Dict):
25+
if not (MIN_CONFIDENCE_SCORES <= len(value) <=
26+
MAX_CONFIDENCE_SCORES):
27+
raise ValidationError([
28+
ErrorWrapper(ValueError(
29+
"Number of confidence scores must be greater"
30+
f" than or equal to {MIN_CONFIDENCE_SCORES} and"
31+
f" less than or equal to {MAX_CONFIDENCE_SCORES}. Found {len(value)}"
32+
),
33+
loc='value')
34+
], cls)
35+
return value
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from enum import Enum
2+
from typing import Tuple, Dict, Union
3+
4+
from pydantic import conint, Field
5+
6+
from .base import ConfidenceValue, BaseMetric
7+
8+
Count = conint(ge=0, le=10_000)
9+
ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count]
10+
ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue,
11+
ConfusionMatrixMetricValue]
12+
13+
14+
class ConfusionMatrixAggregation(Enum):
15+
CONFUSION_MATRIX = "CONFUSION_MATRIX"
16+
17+
18+
class ConfusionMatrixMetric(BaseMetric):
19+
""" Class representing confusion matrix metrics.
20+
21+
In the editor, this provides precision, recall, and f-scores.
22+
This should be used over multiple scalar metrics so that aggregations are accurate.
23+
24+
value should be a tuple representing:
25+
[True Positive Count, False Positive Count, True Negative Count, False Negative Count]
26+
27+
aggregation cannot be adjusted for confusion matrix metrics.
28+
"""
29+
metric_name: str
30+
value: Union[ConfusionMatrixMetricValue,
31+
ConfusionMatrixMetricConfidenceValue]
32+
aggregation: ConfusionMatrixAggregation = Field(
33+
ConfusionMatrixAggregation.CONFUSION_MATRIX, const=True)

labelbox/data/annotation_types/metrics/scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ScalarMetric(BaseMetric):
2525
aggregation will be ignored without providing a metric name.
2626
"""
2727
metric_name: Optional[str] = None
28-
value: Union[float, ScalarMetricConfidenceValue]
28+
value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
2929
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
3030

3131
def dict(self, *args, **kwargs):

0 commit comments

Comments
 (0)