Skip to content

Commit 603bb72

Browse files
committed
working progress of updating metrics for classifications
1 parent 6269b98 commit 603bb72

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

labelbox/data/metrics/group.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing_extensions import Literal
1010

1111
from ..annotation_types.feature import FeatureSchema
12-
from ..annotation_types import ObjectAnnotation, Label, LabelList
12+
from ..annotation_types import ObjectAnnotation, Label, LabelList, ClassificationAnnotation
1313

1414

1515
def get_identifying_key(
@@ -63,6 +63,18 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
6363
return all_schemas, all_names
6464

6565

66+
def update_classification_answers(label: Label):
67+
"""
68+
This function is to update classifications' names to the answers. This prevents
69+
metrics from being calculated only at the description of the classification, and
70+
instead at the answer level.
71+
"""
72+
for annotation in label.annotations:
73+
if isinstance(annotation, ClassificationAnnotation):
74+
annotation.name = annotation.value.answer.name
75+
return label
76+
77+
6678
def get_label_pairs(labels_a: LabelList,
6779
labels_b: LabelList,
6880
match_on="uid",
@@ -112,6 +124,8 @@ def get_label_pairs(labels_a: LabelList,
112124
)
113125
else:
114126
continue
127+
a, b = update_classification_answers(a), update_classification_answers(
128+
b)
115129
pairs[key].extend([a, b])
116130
return pairs
117131

0 commit comments

Comments
 (0)