1
+ from pydantic import ValidationError
1
2
import pytest
2
3
3
- from labelbox .data .annotation_types .metrics . aggregations import MetricAggregation
4
- from labelbox .data .annotation_types .metrics . scalar import ScalarMetric
4
+ from labelbox .data .annotation_types .metrics import ConfusionMatrixAggregation , ScalarMetricAggregation
5
+ from labelbox .data .annotation_types .metrics import ConfusionMatrixMetric , ScalarMetric
5
6
from labelbox .data .annotation_types .collection import LabelList
6
7
from labelbox .data .annotation_types import ScalarMetric , Label , ImageData
7
8
@@ -30,23 +31,28 @@ def test_legacy_scalar_metric():
30
31
'uid' : None
31
32
}
32
33
assert label .dict () == expected
33
- next (LabelList ([label ])).dict () == expected
34
+ assert next (LabelList ([label ])).dict () == expected
34
35
35
36
36
37
# TODO: Test with confidence
37
38
38
- @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
39
- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
40
- ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
41
- (None , None , MetricAggregation .ARITHMETIC_MEAN ),
42
- (None , None , None ),
43
- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
44
- ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
45
- (None , None , MetricAggregation .GEOMETRIC_MEAN ),
46
- (None , None , MetricAggregation .SUM )
39
+
40
+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation,value' , [
41
+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
42
+ ("cat" , None , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
43
+ (None , None , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
44
+ (None , None , None , 0.5 ),
45
+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
46
+ ("cat" , None , ScalarMetricAggregation .HARMONIC_MEAN , 0.5 ),
47
+ (None , None , ScalarMetricAggregation .GEOMETRIC_MEAN , 0.5 ),
48
+ (None , None , ScalarMetricAggregation .SUM , 0.5 ),
49
+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , {
50
+ 0.1 : 0.2 ,
51
+ 0.3 : 0.5 ,
52
+ 0.4 : 0.8
53
+ }),
47
54
])
48
- def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
49
- value = 0.5
55
+ def test_custom_scalar_metric (feature_name , subclass_name , aggregation , value ):
50
56
kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
51
57
metric = ScalarMetric (metric_name = "iou" ,
52
58
value = value ,
@@ -77,36 +83,37 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
77
83
** ({
78
84
'subclass_name' : subclass_name
79
85
} if subclass_name else {}), 'aggregation' :
80
- aggregation or MetricAggregation .ARITHMETIC_MEAN ,
86
+ aggregation or ScalarMetricAggregation .ARITHMETIC_MEAN ,
81
87
'extra' : {}
82
88
}],
83
89
'extra' : {},
84
90
'uid' : None
85
91
}
86
- assert label .dict () == expected
87
- next (LabelList ([label ])).dict () == expected
88
-
89
92
93
+ assert label .dict () == expected
94
+ assert next (LabelList ([label ])).dict () == expected
90
95
91
96
92
- @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
93
- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
94
- ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
95
- (None , None , MetricAggregation .ARITHMETIC_MEAN ),
96
- (None , None , None ),
97
- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
98
- ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
99
- (None , None , MetricAggregation .GEOMETRIC_MEAN ),
100
- (None , None , MetricAggregation .SUM ),
97
+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation,value' , [
98
+ ("cat" , "orange" , ConfusionMatrixAggregation .CONFUSION_MATRIX ,
99
+ (0 , 1 , 2 , 3 )),
100
+ ("cat" , None , ConfusionMatrixAggregation .CONFUSION_MATRIX , (0 , 1 , 2 , 3 )),
101
+ (None , None , ConfusionMatrixAggregation .CONFUSION_MATRIX , (0 , 1 , 2 , 3 )),
102
+ (None , None , None , (0 , 1 , 2 , 3 )),
103
+ ("cat" , "orange" , ConfusionMatrixAggregation .CONFUSION_MATRIX , {
104
+ 0.1 : (0 , 1 , 2 , 3 ),
105
+ 0.3 : (0 , 1 , 2 , 3 ),
106
+ 0.4 : (0 , 1 , 2 , 3 )
107
+ }),
101
108
])
102
- def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
103
- value = 0.5
109
+ def test_custom_confusison_matrix_metric (feature_name , subclass_name ,
110
+ aggregation , value ):
104
111
kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
105
- metric = ScalarMetric (metric_name = "iou " ,
106
- value = value ,
107
- feature_name = feature_name ,
108
- subclass_name = subclass_name ,
109
- ** kwargs )
112
+ metric = ConfusionMatrixMetric (metric_name = "confusion_matrix_50_pct_iou " ,
113
+ value = value ,
114
+ feature_name = feature_name ,
115
+ subclass_name = subclass_name ,
116
+ ** kwargs )
110
117
assert metric .value == value
111
118
112
119
label = Label (data = ImageData (uid = "ckrmd9q8g000009mg6vej7hzg" ),
@@ -124,18 +131,58 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
124
131
'value' :
125
132
value ,
126
133
'metric_name' :
127
- 'iou ' ,
134
+ 'confusion_matrix_50_pct_iou ' ,
128
135
** ({
129
136
'feature_name' : feature_name
130
137
} if feature_name else {}),
131
138
** ({
132
139
'subclass_name' : subclass_name
133
140
} if subclass_name else {}), 'aggregation' :
134
- aggregation or MetricAggregation . ARITHMETIC_MEAN ,
141
+ aggregation or ConfusionMatrixAggregation . CONFUSION_MATRIX ,
135
142
'extra' : {}
136
143
}],
137
144
'extra' : {},
138
145
'uid' : None
139
146
}
140
147
assert label .dict () == expected
141
- next (LabelList ([label ])).dict () == expected
148
+ assert next (LabelList ([label ])).dict () == expected
149
+
150
+
151
def test_name_exists():
    # metric_name is mandatory on ConfusionMatrixMetric (unlike ScalarMetric,
    # where it is currently optional), so omitting it must fail validation.
    with pytest.raises(ValidationError) as exc_info:
        ConfusionMatrixMetric(value=[0, 1, 2, 3])
    assert "field required (type=value_error.missing)" in str(exc_info.value)
156
+
157
+
158
def test_invalid_aggregations():
    # Each metric type only accepts its own aggregation enum: a scalar metric
    # rejects the confusion-matrix aggregation, and vice versa.
    with pytest.raises(ValidationError) as exc_info:
        ScalarMetric(
            metric_name="invalid aggregation",
            value=0.1,
            aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
        )
    assert "value is not a valid enumeration member" in str(exc_info.value)

    with pytest.raises(ValidationError) as exc_info:
        ConfusionMatrixMetric(
            metric_name="invalid aggregation",
            value=[0, 1, 2, 3],
            aggregation=ScalarMetricAggregation.SUM,
        )
    assert "value is not a valid enumeration member" in str(exc_info.value)
170
+
171
+
172
def test_invalid_number_of_confidence_scores():
    # Confidence-keyed metric values are bounded: a single score is too few
    # and twenty scores are too many, for both metric kinds. Exact bounds are
    # enforced by the metric validators — presumably 2..15; confirm in model.
    cases = [
        (ScalarMetric, "too few scores", {0.1: 0.1}),
        (ConfusionMatrixMetric, "too few scores", {0.1: [0, 1, 2, 3]}),
        (ScalarMetric, "too many scores",
         {i / 20.: 0.1 for i in range(20)}),
        (ConfusionMatrixMetric, "too many scores",
         {i / 20.: [0, 1, 2, 3] for i in range(20)}),
    ]
    for metric_cls, name, bad_value in cases:
        with pytest.raises(ValidationError) as exc_info:
            metric_cls(metric_name=name, value=bad_value)
        assert "Number of confidence scores must be greater" in str(
            exc_info.value)
0 commit comments