Skip to content

Commit e9d956d

Browse files
committed
update to have metrics calculated at the answer level
1 parent 16288d7 commit e9d956d

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

labelbox/data/metrics/group.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
"""
44
from collections import defaultdict
55
from typing import Dict, List, Tuple, Union
6+
7+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation, Checklist, Radio
68
try:
79
from typing import Literal
810
except ImportError:
911
from typing_extensions import Literal
1012

1113
from ..annotation_types.feature import FeatureSchema
12-
from ..annotation_types import ObjectAnnotation, Label, LabelList
14+
from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label, LabelList
1315

1416

1517
def get_identifying_key(
@@ -56,6 +58,14 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
5658
all_names = True
5759
all_schemas = True
5860
for feature in features:
61+
if isinstance(feature, ClassificationAnnotation):
62+
if isinstance(feature.value, Checklist):
63+
all_names, all_schemas = all_have_key(feature.value.answer)
64+
else:
65+
if feature.value.answer.name is None:
66+
all_names = False
67+
if feature.value.answer.feature_schema_id is None:
68+
all_schemas = False
5969
if feature.name is None:
6070
all_names = False
6171
if feature.feature_schema_id is None:
@@ -155,7 +165,17 @@ def _create_feature_lookup(features: List[FeatureSchema],
155165
"""
156166
grouped_features = defaultdict(list)
157167
for feature in features:
158-
grouped_features[getattr(feature, key)].append(feature)
168+
if isinstance(feature, ClassificationAnnotation):
169+
#checklists
170+
if isinstance(feature.value, Checklist):
171+
for answer in feature.value.answer:
172+
new_feature = Radio(answer=answer)
173+
grouped_features[getattr(answer, key)] = new_feature
174+
else:
175+
grouped_features[getattr(feature.value.answer,
176+
key)].append(feature)
177+
else:
178+
grouped_features[getattr(feature, key)].append(feature)
159179
return grouped_features
160180

161181

0 commit comments

Comments
 (0)