Skip to content

Commit d153f7d

Browse files
committed
updates to code and test cases and fix some deprecation
1 parent e9d956d commit d153f7d

File tree

5 files changed

+157
-70
lines changed

5 files changed

+157
-70
lines changed

labelbox/data/metrics/group.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from collections import defaultdict
55
from typing import Dict, List, Tuple, Union
66

7-
from labelbox.data.annotation_types.annotation import ClassificationAnnotation, Checklist, Radio
7+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation, Checklist, Radio, Text
8+
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer
89
try:
910
from typing import Literal
1011
except ImportError:
@@ -60,7 +61,13 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
6061
for feature in features:
6162
if isinstance(feature, ClassificationAnnotation):
6263
if isinstance(feature.value, Checklist):
63-
all_names, all_schemas = all_have_key(feature.value.answer)
64+
all_schemas, all_names = all_have_key(feature.value.answer)
65+
#this code should be able to be refactored better
66+
elif isinstance(feature.value, Text):
67+
if feature.name is None:
68+
all_names = False
69+
if feature.feature_schema_id is None:
70+
all_schemas = False
6471
else:
6572
if feature.value.answer.name is None:
6673
all_names = False
@@ -169,8 +176,17 @@ def _create_feature_lookup(features: List[FeatureSchema],
169176
#checklists
170177
if isinstance(feature.value, Checklist):
171178
for answer in feature.value.answer:
172-
new_feature = Radio(answer=answer)
173-
grouped_features[getattr(answer, key)] = new_feature
179+
new_answer = Radio(answer=answer)
180+
new_annotation = ClassificationAnnotation(
181+
value=new_answer,
182+
name=answer.name,
183+
feature_schema_id=answer.feature_schema_id)
184+
185+
grouped_features[getattr(answer,
186+
key)].append(new_annotation)
187+
#likely can be refactored
188+
elif isinstance(feature.value, Text):
189+
grouped_features[getattr(feature, key)].append(feature)
174190
else:
175191
grouped_features[getattr(feature.value.answer,
176192
key)].append(feature)

tests/data/metrics/confusion_matrix/conftest.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,24 +183,27 @@ def radio_pairs():
183183
return [
184184
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
185185
ground_truths=[get_radio("is_animal", answer_name="yes")],
186-
expected={'is_animal': [1, 0, 0, 0]}),
186+
expected={'yes': [1, 0, 0, 0]}),
187187
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
188188
ground_truths=[get_radio("is_animal", answer_name="no")],
189-
expected={'is_animal': [0, 1, 0, 1]}),
189+
expected={
190+
'no': [0, 0, 0, 1],
191+
'yes': [0, 1, 0, 0]
192+
}),
190193
NameSpace(predictions=[get_radio("is_animal", answer_name="yes")],
191194
ground_truths=[],
192-
expected={'is_animal': [0, 1, 0, 0]}),
195+
expected={'yes': [0, 1, 0, 0]}),
193196
NameSpace(predictions=[],
194197
ground_truths=[get_radio("is_animal", answer_name="yes")],
195-
expected={'is_animal': [0, 0, 0, 1]}),
198+
expected={'yes': [0, 0, 0, 1]}),
196199
NameSpace(predictions=[
197200
get_radio("is_animal", answer_name="yes"),
198201
get_radio("is_short", answer_name="no")
199202
],
200203
ground_truths=[get_radio("is_animal", answer_name="yes")],
201204
expected={
202-
'is_animal': [1, 0, 0, 0],
203-
'is_short': [0, 1, 0, 0]
205+
'no': [0, 1, 0, 0],
206+
'yes': [1, 0, 0, 0]
204207
}),
205208
#Not supported yet:
206209
# NameSpace(
@@ -221,18 +224,18 @@ def checklist_pairs():
221224
get_checklist("animal_attributes",
222225
answer_names=["striped"])
223226
],
224-
expected={'animal_attributes': [1, 0, 0, 0]}),
227+
expected={'striped': [1, 0, 0, 0]}),
225228
NameSpace(predictions=[
226229
get_checklist("animal_attributes", answer_names=["striped"])
227230
],
228231
ground_truths=[],
229-
expected={'animal_attributes': [0, 1, 0, 0]}),
232+
expected={'striped': [0, 1, 0, 0]}),
230233
NameSpace(predictions=[],
231234
ground_truths=[
232235
get_checklist("animal_attributes",
233236
answer_names=["striped"])
234237
],
235-
expected={'animal_attributes': [0, 0, 0, 1]}),
238+
expected={'striped': [0, 0, 0, 1]}),
236239
NameSpace(predictions=[
237240
get_checklist("animal_attributes",
238241
answer_names=["striped", "short"])
@@ -241,15 +244,21 @@ def checklist_pairs():
241244
get_checklist("animal_attributes",
242245
answer_names=["striped"])
243246
],
244-
expected={'animal_attributes': [1, 1, 0, 0]}),
247+
expected={
248+
'short': [0, 1, 0, 0],
249+
'striped': [1, 0, 0, 0]
250+
}),
245251
NameSpace(predictions=[
246252
get_checklist("animal_attributes", answer_names=["striped"])
247253
],
248254
ground_truths=[
249255
get_checklist("animal_attributes",
250256
answer_names=["striped", "short"])
251257
],
252-
expected={'animal_attributes': [1, 0, 0, 1]}),
258+
expected={
259+
'short': [0, 0, 0, 1],
260+
'striped': [1, 0, 0, 0]
261+
}),
253262
NameSpace(predictions=[
254263
get_checklist("animal_attributes",
255264
answer_names=["striped", "short", "black"])
@@ -258,7 +267,11 @@ def checklist_pairs():
258267
get_checklist("animal_attributes",
259268
answer_names=["striped", "short"])
260269
],
261-
expected={'animal_attributes': [2, 1, 0, 0]}),
270+
expected={
271+
'black': [0, 1, 0, 0],
272+
'short': [1, 0, 0, 0],
273+
'striped': [1, 0, 0, 0]
274+
}),
262275
NameSpace(predictions=[
263276
get_checklist("animal_attributes",
264277
answer_names=["striped", "short", "black"]),
@@ -270,8 +283,11 @@ def checklist_pairs():
270283
get_checklist("animal_name", answer_names=["pup"])
271284
],
272285
expected={
273-
'animal_attributes': [2, 1, 0, 0],
274-
'animal_name': [1, 1, 0, 0]
286+
'black': [0, 1, 0, 0],
287+
'doggy': [0, 1, 0, 0],
288+
'pup': [1, 0, 0, 0],
289+
'short': [1, 0, 0, 0],
290+
'striped': [1, 0, 0, 0]
275291
})
276292

277293
#Not supported yet:

tests/data/metrics/iou/data_row/conftest.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
class NameSpace(SimpleNamespace):
1010

11-
def __init__(self, predictions, labels, expected, classifications=None):
11+
def __init__(self,
12+
predictions,
13+
labels,
14+
expected,
15+
data_row_expected=None,
16+
classifications=None):
1217
super(NameSpace,
1318
self).__init__(predictions=predictions,
1419
labels={
@@ -19,7 +24,8 @@ def __init__(self, predictions, labels, expected, classifications=None):
1924
'classifications': classifications or []
2025
}
2126
},
22-
expected=expected)
27+
expected=expected,
28+
data_row_expected=data_row_expected)
2329

2430

2531
@pytest.fixture
@@ -314,41 +320,45 @@ def empty_radio_prediction():
314320

315321
@pytest.fixture
316322
def matching_checklist():
317-
return NameSpace(labels=[],
318-
classifications=[{
319-
'featureId':
320-
'1234567890111213141516171',
321-
'schemaId':
322-
'ckppid25v0000aeyjmxfwlc7t',
323-
'uuid':
324-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
325-
'schemaId':
326-
'ckppid25v0000aeyjmxfwlc7t',
327-
'answers': [{
328-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
329-
}, {
330-
'schemaId': 'ckppide010001aeyj0yhiaghc'
331-
}, {
332-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
333-
}]
334-
}],
335-
predictions=[{
336-
'uuid':
337-
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
338-
'schemaId':
339-
'ckppid25v0000aeyjmxfwlc7t',
340-
'dataRow': {
341-
'id': 'ckppihxc10005aeyjen11h7jh'
342-
},
343-
'answers': [{
344-
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
345-
}, {
346-
'schemaId': 'ckppide010001aeyj0yhiaghc'
347-
}, {
348-
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
349-
}]
350-
}],
351-
expected=1.)
323+
return NameSpace(
324+
labels=[],
325+
classifications=[{
326+
'featureId':
327+
'1234567890111213141516171',
328+
'schemaId':
329+
'ckppid25v0000aeyjmxfwlc7t',
330+
'uuid':
331+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
332+
'schemaId':
333+
'ckppid25v0000aeyjmxfwlc7t',
334+
'answers': [{
335+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
336+
}, {
337+
'schemaId': 'ckppide010001aeyj0yhiaghc'
338+
}, {
339+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
340+
}]
341+
}],
342+
predictions=[{
343+
'uuid':
344+
'76e0dcea-fe46-43e5-95f5-a5e3f378520a',
345+
'schemaId':
346+
'ckppid25v0000aeyjmxfwlc7t',
347+
'dataRow': {
348+
'id': 'ckppihxc10005aeyjen11h7jh'
349+
},
350+
'answers': [{
351+
'schemaId': 'ckppid25v0000aeyjmxfwlc7t',
352+
}, {
353+
'schemaId': 'ckppide010001aeyj0yhiaghc'
354+
}, {
355+
'schemaId': 'ckppidq4u0002aeyjmcc4toxw'
356+
}]
357+
}],
358+
data_row_expected=1.,
359+
# expected = [1.]
360+
# expected=[1., 1., 1.])
361+
expected={1.0: 3})
352362

353363

354364
@pytest.fixture
@@ -391,7 +401,11 @@ def partially_matching_checklist_1():
391401
'schemaId': 'ckppiebx80004aeyjuwvos69e'
392402
}]
393403
}],
394-
expected=0.6)
404+
data_row_expected=0.6,
405+
expected={
406+
0.0: 2,
407+
1.0: 3
408+
})
395409

396410

397411
@pytest.fixture
@@ -430,7 +444,11 @@ def partially_matching_checklist_2():
430444
'schemaId': 'ckppiebx80004aeyjuwvos69e'
431445
}]
432446
}],
433-
expected=0.5)
447+
data_row_expected=0.5,
448+
expected={
449+
1.0: 2,
450+
0.0: 2
451+
})
434452

435453

436454
@pytest.fixture
@@ -469,7 +487,11 @@ def partially_matching_checklist_3():
469487
'schemaId': 'ckppide010001aeyj0yhiaghc'
470488
}]
471489
}],
472-
expected=0.5)
490+
data_row_expected=0.5,
491+
expected={
492+
1.0: 2,
493+
0.0: 2
494+
})
473495

474496

475497
@pytest.fixture
@@ -485,7 +507,8 @@ def empty_checklist_label():
485507
'schemaId': 'ckppid25v0000aeyjmxfwlc7t'
486508
}]
487509
}],
488-
expected=0)
510+
data_row_expected=0.0,
511+
expected={0.0: 1})
489512

490513

491514
@pytest.fixture
@@ -502,7 +525,8 @@ def empty_checklist_prediction():
502525
}]
503526
}],
504527
predictions=[],
505-
expected=0)
528+
data_row_expected=0.0,
529+
expected={0.0: 1})
506530

507531

508532
@pytest.fixture

tests/data/metrics/iou/data_row/test_data_row_iou.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ def check_iou(pair, mask=None):
3131
assert math.isclose(feature_ious[0].value, pair.expected)
3232

3333

34+
def check_iou_checklist(pair, mask=None):
35+
"""specialized test since checklists have more than one feature ious """
36+
default = Label(data=ImageData(uid="ckppihxc10005aeyjen11h7jh"))
37+
prediction = next(NDJsonConverter.deserialize(pair.predictions), default)
38+
label = next(LBV1Converter.deserialize([pair.labels]))
39+
if mask:
40+
for annotation in [*prediction.annotations, *label.annotations]:
41+
if isinstance(annotation.value, Mask):
42+
annotation.value.mask.arr = np.frombuffer(
43+
base64.b64decode(annotation.value.mask.url.encode('utf-8')),
44+
dtype=np.uint8).reshape((32, 32, 3))
45+
# print("drowmiou\n\n", data_row_miou(label, prediction),
46+
# pair.data_row_expected, "\n")
47+
assert math.isclose(data_row_miou(label, prediction),
48+
pair.data_row_expected)
49+
assert math.isclose(
50+
miou_metric(label.annotations, prediction.annotations)[0].value,
51+
pair.data_row_expected)
52+
feature_ious = feature_miou_metric(label.annotations,
53+
prediction.annotations)
54+
b = [feature.value for feature in feature_ious]
55+
# print("fiou\n\n", b, pair.expected, "\n")
56+
mapping = {}
57+
for iou in feature_ious:
58+
print(iou)
59+
if not mapping.get(iou.value, None):
60+
mapping[iou.value] = 0
61+
mapping[iou.value] += 1
62+
assert mapping == pair.expected
63+
64+
3465
def strings_to_fixtures(strings):
3566
return [fixture_ref(x) for x in strings]
3667

@@ -70,7 +101,7 @@ def test_radio(pair):
70101
"empty_checklist_prediction",
71102
]))
72103
def test_checklist(pair):
73-
check_iou(pair)
104+
check_iou_checklist(pair)
74105

75106

76107
@parametrize("pair", strings_to_fixtures(["matching_text",

tests/integration/annotation_import/test_ndjson_validation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
import ndjson
3-
from pytest_cases import pytest_parametrize_plus, fixture_ref
3+
from pytest_cases import parametrize, fixture_ref
44

55
from labelbox.exceptions import MALValidationError
66
from labelbox.schema.bulk_import_request import (NDChecklist, NDClassification,
@@ -23,14 +23,14 @@ def test_subclassification_construction(rectangle_inference):
2323
assert isinstance(tool.classifications[0], NDRadio)
2424

2525

26-
@pytest_parametrize_plus("inference, expected_type",
27-
[(fixture_ref('polygon_inference'), NDPolygon),
28-
(fixture_ref('rectangle_inference'), NDRectangle),
29-
(fixture_ref('line_inference'), NDPolyline),
30-
(fixture_ref('entity_inference'), NDTextEntity),
31-
(fixture_ref('segmentation_inference'), NDMask),
32-
(fixture_ref('segmentation_inference_rle'), NDMask),
33-
(fixture_ref('segmentation_inference_png'), NDMask)])
26+
@parametrize("inference, expected_type",
27+
[(fixture_ref('polygon_inference'), NDPolygon),
28+
(fixture_ref('rectangle_inference'), NDRectangle),
29+
(fixture_ref('line_inference'), NDPolyline),
30+
(fixture_ref('entity_inference'), NDTextEntity),
31+
(fixture_ref('segmentation_inference'), NDMask),
32+
(fixture_ref('segmentation_inference_rle'), NDMask),
33+
(fixture_ref('segmentation_inference_png'), NDMask)])
3434
def test_tool_construction(inference, expected_type):
3535
assert isinstance(NDTool.build(inference), expected_type)
3636

0 commit comments

Comments
 (0)