1
- from labelbox .data .annotation_types .metrics import ScalarMetricAggregation
2
- from typing import Union , Optional
1
+ from typing import Optional , Union , Type
3
2
4
3
from labelbox .data .annotation_types .data import ImageData , TextData
5
- from labelbox .data .annotation_types .metrics import ScalarMetric
6
4
from labelbox .data .serialization .ndjson .base import NDJsonBase
5
+ from labelbox .data .annotation_types .metrics .scalar import (
6
+ ScalarMetric , ScalarMetricAggregation , ScalarMetricValue ,
7
+ ScalarMetricConfidenceValue )
8
+ from labelbox .data .annotation_types .metrics .confusion_matrix import (
9
+ ConfusionMatrixAggregation , ConfusionMatrixMetric ,
10
+ ConfusionMatrixMetricValue , ConfusionMatrixMetricConfidenceValue )
7
11
8
12
9
- class NDScalarMetric (NDJsonBase ):
13
+ class BaseNDMetric (NDJsonBase ):
10
14
metric_value : float
11
- metric_name : Optional [str ]
12
15
feature_name : Optional [str ] = None
13
16
subclass_name : Optional [str ] = None
14
- aggregation : ScalarMetricAggregation = ScalarMetricAggregation .ARITHMETIC_MEAN .value
17
+
18
+ class Config :
19
+ use_enum_values = True
20
+
21
+ def dict (self , * args , ** kwargs ):
22
+ res = super ().dict (* args , ** kwargs )
23
+ for field in ['featureName' , 'subclassName' ]:
24
+ if res [field ] is None :
25
+ res .pop (field )
26
+ return res
27
+
28
+
29
+ class NDConfusionMatrixMetric (BaseNDMetric ):
30
+ metric_value : Union [ConfusionMatrixMetricValue ,
31
+ ConfusionMatrixMetricConfidenceValue ]
32
+ metric_name : str
33
+ aggregation : ConfusionMatrixAggregation
34
+
35
+ def to_common (self ) -> ConfusionMatrixMetric :
36
+ return ConfusionMatrixMetric (value = self .metric_value ,
37
+ metric_name = self .metric_name ,
38
+ feature_name = self .feature_name ,
39
+ subclass_name = self .subclass_name ,
40
+ aggregation = self .aggregation ,
41
+ extra = {'uuid' : self .uuid })
42
+
43
+ @classmethod
44
+ def from_common (
45
+ cls , metric : ConfusionMatrixMetric ,
46
+ data : Union [TextData , ImageData ]) -> "NDConfusionMatrixMetric" :
47
+ return cls (uuid = metric .extra .get ('uuid' ),
48
+ metric_value = metric .value ,
49
+ metric_name = metric .metric_name ,
50
+ feature_name = metric .feature_name ,
51
+ subclass_name = metric .subclass_name ,
52
+ aggregation = metric .aggregation ,
53
+ data_row = {'id' : data .uid })
54
+
55
+
56
+ class NDScalarMetric (BaseNDMetric ):
57
+ metric_value : Union [ScalarMetricValue , ScalarMetricConfidenceValue ]
58
+ metric_name : Optional [str ]
59
+ aggregation : ScalarMetricAggregation = ScalarMetricAggregation .ARITHMETIC_MEAN
15
60
16
61
def to_common (self ) -> ScalarMetric :
17
- return ScalarMetric (
18
- value = self .metric_value ,
19
- metric_name = self .metric_name ,
20
- feature_name = self .feature_name ,
21
- subclass_name = self .subclass_name ,
22
- aggregation = ScalarMetricAggregation [self .aggregation ],
23
- extra = {'uuid' : self .uuid })
62
+ return ScalarMetric (value = self .metric_value ,
63
+ metric_name = self .metric_name ,
64
+ feature_name = self .feature_name ,
65
+ subclass_name = self .subclass_name ,
66
+ aggregation = self .aggregation ,
67
+ extra = {'uuid' : self .uuid })
24
68
25
69
@classmethod
26
70
def from_common (cls , metric : ScalarMetric ,
@@ -35,38 +79,39 @@ def from_common(cls, metric: ScalarMetric,
35
79
36
80
def dict (self , * args , ** kwargs ):
37
81
res = super ().dict (* args , ** kwargs )
38
- for field in ['featureName' , 'subclassName' ]:
39
- if res [field ] is None :
40
- res .pop (field )
41
-
42
82
# For backwards compatibility.
43
83
if res ['metricName' ] is None :
44
84
res .pop ('metricName' )
45
85
res .pop ('aggregation' )
46
86
return res
47
87
48
- class Config :
49
- use_enum_values = True
50
-
51
88
52
89
class NDMetricAnnotation :
53
90
54
91
@classmethod
55
- def to_common (cls , annotation : "NDScalarMetric" ) -> ScalarMetric :
92
+ def to_common (
93
+ cls , annotation : Union [NDScalarMetric , NDConfusionMatrixMetric ]
94
+ ) -> Union [ScalarMetric , ConfusionMatrixMetric ]:
56
95
return annotation .to_common ()
57
96
58
97
@classmethod
59
- def from_common (cls , annotation : ScalarMetric ,
60
- data : Union [TextData , ImageData ]) -> "NDScalarMetric" :
98
+ def from_common (
99
+ cls , annotation : Union [ScalarMetric ,
100
+ ConfusionMatrixMetric ], data : Union [TextData ,
101
+ ImageData ]
102
+ ) -> Union [NDScalarMetric , NDConfusionMatrixMetric ]:
61
103
obj = cls .lookup_object (annotation )
62
104
return obj .from_common (annotation , data )
63
105
64
106
@staticmethod
65
- def lookup_object (metric : ScalarMetric ) -> "NDScalarMetric" :
107
+ def lookup_object (
108
+ annotation : Union [ScalarMetric , ConfusionMatrixMetric ]
109
+ ) -> Union [Type [NDScalarMetric ], Type [NDConfusionMatrixMetric ]]:
66
110
result = {
67
111
ScalarMetric : NDScalarMetric ,
68
- }.get (type (metric ))
112
+ ConfusionMatrixMetric : NDConfusionMatrixMetric ,
113
+ }.get (type (annotation ))
69
114
if result is None :
70
115
raise TypeError (
71
- f"Unable to convert object to MAL format. `{ type (metric )} `" )
116
+ f"Unable to convert object to MAL format. `{ type (annotation )} `" )
72
117
return result
0 commit comments