Skip to content

Commit e6ceb7e

Browse files
authored
Merge pull request #428 from Labelbox/al-1461
working progress of updating metrics for classifications
2 parents 6b6bfa6 + 44b1175 commit e6ceb7e

File tree

5 files changed

+160
-60
lines changed

5 files changed

+160
-60
lines changed

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def draw(self,
9393

9494
canvas = canvas if canvas is not None else np.zeros(tuple(dims),
9595
dtype=np.uint8)
96-
canvas[mask.astype(np.bool)] = color
96+
canvas[mask.astype(bool)] = color
9797
return canvas
9898

9999
def _extract_polygons_from_contours(self, contours: List) -> MultiPolygon:

labelbox/data/metrics/group.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
"""
44
from collections import defaultdict
55
from typing import Dict, List, Tuple, Union
6+
7+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation, Checklist, Radio, Text
8+
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer
69
try:
710
from typing import Literal
811
except ImportError:
912
from typing_extensions import Literal
1013

1114
from ..annotation_types.feature import FeatureSchema
12-
from ..annotation_types import ObjectAnnotation, Label, LabelList
15+
from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label, LabelList
1316

1417

1518
def get_identifying_key(
@@ -56,6 +59,19 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
5659
all_names = True
5760
all_schemas = True
5861
for feature in features:
62+
if isinstance(feature, ClassificationAnnotation):
63+
if isinstance(feature.value, Checklist):
64+
all_schemas, all_names = all_have_key(feature.value.answer)
65+
elif isinstance(feature.value, Text):
66+
if feature.name is None:
67+
all_names = False
68+
if feature.feature_schema_id is None:
69+
all_schemas = False
70+
else:
71+
if feature.value.answer.name is None:
72+
all_names = False
73+
if feature.value.answer.feature_schema_id is None:
74+
all_schemas = False
5975
if feature.name is None:
6076
all_names = False
6177
if feature.feature_schema_id is None:
@@ -155,7 +171,25 @@ def _create_feature_lookup(features: List[FeatureSchema],
155171
"""
156172
grouped_features = defaultdict(list)
157173
for feature in features:
158-
grouped_features[getattr(feature, key)].append(feature)
174+
if isinstance(feature, ClassificationAnnotation):
175+
#checklists
176+
if isinstance(feature.value, Checklist):
177+
for answer in feature.value.answer:
178+
new_answer = Radio(answer=answer)
179+
new_annotation = ClassificationAnnotation(
180+
value=new_answer,
181+
name=answer.name,
182+
feature_schema_id=answer.feature_schema_id)
183+
184+
grouped_features[getattr(answer,
185+
key)].append(new_annotation)
186+
elif isinstance(feature.value, Text):
187+
grouped_features[getattr(feature, key)].append(feature)
188+
else:
189+
grouped_features[getattr(feature.value.answer,
190+
key)].append(feature)
191+
else:
192+
grouped_features[getattr(feature, key)].append(feature)
159193
return grouped_features
160194

161195

tests/data/metrics/confusion_matrix/conftest.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,24 +183,27 @@ def radio_pairs():
183183
return [
184184
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
185185
ground_truths=[get_radio("is_animal", answer_name="yes")],
186-
expected={'is_animal': [1, 0, 0, 0]}),
186+
expected={'yes': [1, 0, 0, 0]}),
187187
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
188188
ground_truths=[get_radio("is_animal", answer_name="no")],
189-
expected={'is_animal': [0, 1, 0, 1]}),
189+
expected={
190+
'no': [0, 0, 0, 1],
191+
'yes': [0, 1, 0, 0]
192+
}),
190193
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
191194
ground_truths=[],
192-
expected={'is_animal': [0, 1, 0, 0]}),
195+
expected={'yes': [0, 1, 0, 0]}),
193196
NameSpace(predictions=[],
194197
ground_truths=[get_radio("is_animal", answer_name="yes")],
195-
expected={'is_animal': [0, 0, 0, 1]}),
198+
expected={'yes': [0, 0, 0, 1]}),
196199
NameSpace(predictions=[
197200
get_radio("is_animal", answer_name="yes"),
198201
get_radio("is_short", answer_name="no")
199202
],
200203
ground_truths=[get_radio("is_animal", answer_name="yes")],
201204
expected={
202-
'is_animal': [1, 0, 0, 0],
203-
'is_short': [0, 1, 0, 0]
205+
'no': [0, 1, 0, 0],
206+
'yes': [1, 0, 0, 0]
204207
}),
205208
#Not supported yet:
206209
# NameSpace(
@@ -221,18 +224,18 @@ def checklist_pairs():
221224
get_checklist("animal_attributes",
222225
answer_names=["striped"])
223226
],
224-
expected={'animal_attributes': [1, 0, 0, 0]}),
227+
expected={'striped': [1, 0, 0, 0]}),
225228
NameSpace(predictions=[
226229
get_checklist("animal_attributes", answer_names=["striped"])
227230
],
228231
ground_truths=[],
229-
expected={'animal_attributes': [0, 1, 0, 0]}),
232+
expected={'striped': [0, 1, 0, 0]}),
230233
NameSpace(predictions=[],
231234
ground_truths=[
232235
get_checklist("animal_attributes",
233236
answer_names=["striped"])
234237
],
235-
expected={'animal_attributes': [0, 0, 0, 1]}),
238+
expected={'striped': [0, 0, 0, 1]}),
236239
NameSpace(predictions=[
237240
get_checklist("animal_attributes",
238241
answer_names=["striped", "short"])
@@ -241,15 +244,21 @@ def checklist_pairs():
241244
get_checklist("animal_attributes",
242245
answer_names=["striped"])
243246
],
244-
expected={'animal_attributes': [1, 1, 0, 0]}),
247+
expected={
248+
'short': [0, 1, 0, 0],
249+
'striped': [1, 0, 0, 0]
250+
}),
245251
NameSpace(predictions=[
246252
get_checklist("animal_attributes", answer_names=["striped"])
247253
],
248254
ground_truths=[
249255
get_checklist("animal_attributes",
250256
answer_names=["striped", "short"])
251257
],
252-
expected={'animal_attributes': [1, 0, 0, 1]}),
258+
expected={
259+
'short': [0, 0, 0, 1],
260+
'striped': [1, 0, 0, 0]
261+
}),
253262
NameSpace(predictions=[
254263
get_checklist("animal_attributes",
255264
answer_names=["striped", "short", "black"])
@@ -258,7 +267,11 @@ def checklist_pairs():
258267
get_checklist("animal_attributes",
259268
answer_names=["striped", "short"])
260269
],
261-
expected={'animal_attributes': [2, 1, 0, 0]}),
270+
expected={
271+
'black': [0, 1, 0, 0],
272+
'short': [1, 0, 0, 0],
273+
'striped': [1, 0, 0, 0]
274+
}),
262275
NameSpace(predictions=[
263276
get_checklist("animal_attributes",
264277
answer_names=["striped", "short", "black"]),
@@ -270,8 +283,11 @@ def checklist_pairs():
270283
get_checklist("animal_name", answer_names=["pup"])
271284
],
272285
expected={
273-
'animal_attributes': [2, 1, 0, 0],
274-
'animal_name': [1, 1, 0, 0]
286+
'black': [0, 1, 0, 0],
287+
'doggy': [0, 1, 0, 0],
288+
'pup': [1, 0, 0, 0],
289+
'short': [1, 0, 0, 0],
290+
'striped': [1, 0, 0, 0]
275291
})
276292

277293
#Not supported yet:

tests/data/metrics/iou/data_row/conftest.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
class NameSpace(SimpleNamespace):
1010

11-
def __init__(self, predictions, labels, expected, classifications=None):
11+
def __init__(self,
12+
predictions,
13+
labels,
14+
expected,
15+
data_row_expected=None,
16+
classifications=None):
1217
super(NameSpace,
1318
self).__init__(predictions=predictions,
1419
labels={
@@ -19,7 +24,8 @@ def __init__(self, predictions, labels, expected, classifications=None):
1924
'classifications': classifications or []
2025
}
2126
},
22-
expected=expected)
27+
expected=expected,
28+
data_row_expected=data_row_expected)
2329

2430

2531
@pytest.fixture
@@ -314,41 +320,45 @@ def empty_radio_prediction():
314320

315321
@pytest.fixture
316322
def matching_checklist():
317-
return NameSpace(labels=[],
318-
classifications=[{
319-
'featureId':
320-
'1234567890111213141516171',
321-
'schemaId':
322-
'ckppid25v0000aeyjmxfwlc7t',
323-
'uuid':
324-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
325-
'schemaId':
326-
'ckppid25v0000aeyjmxfwlc7t',
327-
'answers': [{
328-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
329-
}, {
330-
'schemaId': 'ckppide010001aeyj0yhiaghc'
331-
}, {
332-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
333-
}]
334-
}],
335-
predictions=[{
336-
'uuid':
337-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
338-
'schemaId':
339-
'ckppid25v0000aeyjmxfwlc7t',
340-
'dataRow': {
341-
'id': 'ckppihxc10005aeyjen11h7jh'
342-
},
343-
'answers': [{
344-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
345-
}, {
346-
'schemaId': 'ckppide010001aeyj0yhiaghc'
347-
}, {
348-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
349-
}]
350-
}],
351-
expected=1.)
323+
return NameSpace(
324+
labels=[],
325+
classifications=[{
326+
'featureId':
327+
'1234567890111213141516171',
328+
'schemaId':
329+
'ckppid25v0000aeyjmxfwlc7t',
330+
'uuid':
331+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
332+
'schemaId':
333+
'ckppid25v0000aeyjmxfwlc7t',
334+
'answers': [{
335+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
336+
}, {
337+
'schemaId': 'ckppide010001aeyj0yhiaghc'
338+
}, {
339+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
340+
}]
341+
}],
342+
predictions=[{
343+
'uuid':
344+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
345+
'schemaId':
346+
'ckppid25v0000aeyjmxfwlc7t',
347+
'dataRow': {
348+
'id': 'ckppihxc10005aeyjen11h7jh'
349+
},
350+
'answers': [{
351+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
352+
}, {
353+
'schemaId': 'ckppide010001aeyj0yhiaghc'
354+
}, {
355+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
356+
}]
357+
}],
358+
data_row_expected=1.,
359+
# expected = [1.]
360+
# expected=[1., 1., 1.])
361+
expected={1.0: 3})
352362

353363

354364
@pytest.fixture
@@ -391,7 +401,11 @@ def partially_matching_checklist_1():
391401
'schemaId': 'ckppiebx80004aeyjuwvos69e'
392402
}]
393403
}],
394-
expected=0.6)
404+
data_row_expected=0.6,
405+
expected={
406+
0.0: 2,
407+
1.0: 3
408+
})
395409

396410

397411
@pytest.fixture
@@ -430,7 +444,11 @@ def partially_matching_checklist_2():
430444
'schemaId': 'ckppiebx80004aeyjuwvos69e'
431445
}]
432446
}],
433-
expected=0.5)
447+
data_row_expected=0.5,
448+
expected={
449+
1.0: 2,
450+
0.0: 2
451+
})
434452

435453

436454
@pytest.fixture
@@ -469,7 +487,11 @@ def partially_matching_checklist_3():
469487
'schemaId': 'ckppide010001aeyj0yhiaghc'
470488
}]
471489
}],
472-
expected=0.5)
490+
data_row_expected=0.5,
491+
expected={
492+
1.0: 2,
493+
0.0: 2
494+
})
473495

474496

475497
@pytest.fixture
@@ -485,7 +507,8 @@ def empty_checklist_label():
485507
'schemaId': 'ckppid25v0000aeyjmxfwlc7t'
486508
}]
487509
}],
488-
expected=0)
510+
data_row_expected=0.0,
511+
expected={0.0: 1})
489512

490513

491514
@pytest.fixture
@@ -502,7 +525,8 @@ def empty_checklist_prediction():
502525
}]
503526
}],
504527
predictions=[],
505-
expected=0)
528+
data_row_expected=0.0,
529+
expected={0.0: 1})
506530

507531

508532
@pytest.fixture

tests/data/metrics/iou/data_row/test_data_row_iou.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,32 @@ def check_iou(pair, mask=None):
3131
assert math.isclose(feature_ious[0].value, pair.expected)
3232

3333

34+
def check_iou_checklist(pair, mask=None):
35+
"""specialized test since checklists have more than one feature ious """
36+
default = Label(data=ImageData(uid="ckppihxc10005aeyjen11h7jh"))
37+
prediction = next(NDJsonConverter.deserialize(pair.predictions), default)
38+
label = next(LBV1Converter.deserialize([pair.labels]))
39+
if mask:
40+
for annotation in [*prediction.annotations, *label.annotations]:
41+
if isinstance(annotation.value, Mask):
42+
annotation.value.mask.arr = np.frombuffer(
43+
base64.b64decode(annotation.value.mask.url.encode('utf-8')),
44+
dtype=np.uint8).reshape((32, 32, 3))
45+
assert math.isclose(data_row_miou(label, prediction),
46+
pair.data_row_expected)
47+
assert math.isclose(
48+
miou_metric(label.annotations, prediction.annotations)[0].value,
49+
pair.data_row_expected)
50+
feature_ious = feature_miou_metric(label.annotations,
51+
prediction.annotations)
52+
mapping = {}
53+
for iou in feature_ious:
54+
if not mapping.get(iou.value, None):
55+
mapping[iou.value] = 0
56+
mapping[iou.value] += 1
57+
assert mapping == pair.expected
58+
59+
3460
def strings_to_fixtures(strings):
3561
return [fixture_ref(x) for x in strings]
3662

@@ -70,7 +96,7 @@ def test_radio(pair):
7096
"empty_checklist_prediction",
7197
]))
7298
def test_checklist(pair):
73-
check_iou(pair)
99+
check_iou_checklist(pair)
74100

75101

76102
@parametrize("pair", strings_to_fixtures(["matching_text",

0 commit comments

Comments
 (0)