
Commit 4e78990

dubininsergey, kozikkamil, and github-actions[bot] authored
MODEL-1489: Allow marking Label with "is_benchmark_reference" flag (#1718)
Co-authored-by: kozikkamil <kkozik@labelbox.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 37f4386 commit 4e78990
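
This change lets callers mark a Label as a benchmark reference; the NDJSON serializer then stamps isBenchmarkReferenceLabel: true on every row it emits for that label. A minimal sketch of the flow, mirroring the new tests below (the global key and annotation values are placeholders):

import labelbox.types as lb_types
from labelbox.data.serialization import NDJsonConverter

# Placeholder data; any supported data type and annotation combination works.
label = lb_types.Label(
    data=lb_types.ConversationData(global_key="my_global_key"),
    annotations=[
        lb_types.ClassificationAnnotation(
            name="free_text",
            message_id="0",
            value=lb_types.Text(answer="sample text"),
        )
    ],
    is_benchmark_reference=True,  # the new flag; defaults to False
)

# Every serialized NDJSON row for a flagged label carries the marker.
for row in NDJsonConverter.serialize([label]):
    assert row["isBenchmarkReferenceLabel"] is True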

File tree: 7 files changed, +102 −35 lines


libs/labelbox/src/labelbox/client.py

Lines changed: 14 additions & 14 deletions
All hunks in this file appear to be whitespace-only cleanups (trailing whitespace stripped), so each removed line and its replacement look identical below.

@@ -874,26 +874,26 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project:
         kwargs.pop("data_row_count", None)
 
         return self._create_project(**kwargs)
-
+
 
     def create_prompt_response_generation_project(self,
                                                   dataset_id: Optional[str] = None,
                                                   dataset_name: Optional[str] = None,
                                                   data_row_count: int = 100,
                                                   **kwargs) -> Project:
         """
-        Use this method exclusively to create a prompt and response generation project.
-
+        Use this method exclusively to create a prompt and response generation project.
+
         Args:
             dataset_name: When creating a new dataset, pass the name
             dataset_id: When using an existing dataset, pass the id
             data_row_count: The number of data row assets to use for the project
             **kwargs: Additional parameters to pass see the create_project method
         Returns:
             Project: The created project
-
-        NOTE: Only a dataset_name or dataset_id should be included
-
+
+        NOTE: Only a dataset_name or dataset_id should be included
+
         Examples:
             >>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", media_type=MediaType.LLMPromptResponseCreation)
             >>> This creates a new dataset with a default number of rows (100), creates new prompt and response creation project and assigns a batch of the newly created data rows to the project.

@@ -912,12 +912,12 @@ def create_prompt_response_generation_project(self,
             raise ValueError(
                 "dataset_name or dataset_id must be present and not be an empty string."
             )
-
+
         if dataset_id and dataset_name:
             raise ValueError(
                 "Only provide a dataset_name or dataset_id, not both."
-            )
-
+            )
+
         if data_row_count <= 0:
             raise ValueError("data_row_count must be a positive integer.")
 

@@ -927,7 +927,7 @@ def create_prompt_response_generation_project(self,
         else:
             append_to_existing_dataset = False
             dataset_name_or_id = dataset_name
-
+
         if "media_type" in kwargs and kwargs.get("media_type") not in [MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation]:
             raise ValueError(
                 "media_type must be either LLMPromptCreation or LLMPromptResponseCreation"

@@ -936,11 +936,11 @@ def create_prompt_response_generation_project(self,
         kwargs["dataset_name_or_id"] = dataset_name_or_id
         kwargs["append_to_existing_dataset"] = append_to_existing_dataset
         kwargs["data_row_count"] = data_row_count
-
+
         kwargs.pop("editor_task_type", None)
-
+
         return self._create_project(**kwargs)
-
+
     def create_response_creation_project(self, **kwargs) -> Project:
         """
         Creates a project for response creation.

@@ -1280,7 +1280,7 @@ def create_ontology_from_feature_schemas(
                 leave as None otherwise.
         Returns:
             The created Ontology
-
+
         NOTE for chat evaluation, we currently force media_type to Conversational and for response creation, we force media_type to Text.
         """
         tools, classifications = [], []
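
The hunks above are formatting-only, but they surround create_prompt_response_generation_project, whose docstring they touch; a hedged usage sketch of its argument rules, assuming an authenticated Client bound to client and placeholder names:

from labelbox import Client
from labelbox.schema.media_type import MediaType

client = Client(api_key="<YOUR_API_KEY>")  # placeholder credential

# Pass exactly one of dataset_name / dataset_id; data_row_count must be
# positive; media_type, if given, must be LLMPromptCreation or
# LLMPromptResponseCreation (each rule raises ValueError otherwise).
project = client.create_prompt_response_generation_project(
    name="my prompt project",     # placeholder project name
    dataset_name="new data set",  # creates a new dataset for the project
    data_row_count=100,
    media_type=MediaType.LLMPromptResponseCreation,
)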

libs/labelbox/src/labelbox/data/annotation_types/label.py

Lines changed: 6 additions & 7 deletions
@@ -50,10 +50,10 @@ class Label(pydantic_compat.BaseModel):
     data: DataType
     annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
                             VideoMaskAnnotation, ScalarMetric,
-                            ConfusionMatrixMetric,
-                            RelationshipAnnotation,
+                            ConfusionMatrixMetric, RelationshipAnnotation,
                             PromptClassificationAnnotation]] = []
     extra: Dict[str, Any] = {}
+    is_benchmark_reference: Optional[bool] = False
 
     @pydantic_compat.root_validator(pre=True)
     def validate_data(cls, label):

@@ -219,9 +219,8 @@ def validate_union(cls, value):
                 )
             # Validates only one prompt annotation is included
             if isinstance(v, PromptClassificationAnnotation):
-                prompt_count+=1
-                if prompt_count > 1:
-                    raise TypeError(
-                        f"Only one prompt annotation is allowed per label"
-                    )
+                prompt_count += 1
+                if prompt_count > 1:
+                    raise TypeError(
+                        f"Only one prompt annotation is allowed per label")
         return value
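
A short sketch of how the new field behaves; the data is a placeholder, and the TypeError text comes from the validator above:

import labelbox.types as lb_types

# The flag defaults to False, so existing Label construction is unaffected.
label = lb_types.Label(
    data=lb_types.ConversationData(global_key="my_global_key"),
    annotations=[],  # assumed fine to leave empty for illustration
)
assert label.is_benchmark_reference is False

# The reformatted validator still allows at most one
# PromptClassificationAnnotation per label; adding a second raises
# TypeError("Only one prompt annotation is allowed per label").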

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 5 additions & 3 deletions
@@ -106,14 +106,16 @@ def serialize(
                 if not isinstance(annotation, RelationshipAnnotation):
                     uuid_safe_annotations.append(annotation)
             label.annotations = uuid_safe_annotations
-            for example in NDLabel.from_common([label]):
-                annotation_uuid = getattr(example, "uuid", None)
+            for annotation in NDLabel.from_common([label]):
+                annotation_uuid = getattr(annotation, "uuid", None)
 
-                res = example.dict(
+                res = annotation.dict(
                     by_alias=True,
                     exclude={"uuid"} if annotation_uuid == "None" else None,
                 )
                 for k, v in list(res.items()):
                     if k in IGNORE_IF_NONE and v is None:
                         del res[k]
+                if getattr(label, 'is_benchmark_reference'):
+                    res['isBenchmarkReferenceLabel'] = True
                 yield res
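
Because the serializer only adds the key when the flag is truthy, unflagged labels yield rows without it; a brief sketch with placeholder data, mirroring the tests below:

import labelbox.types as lb_types
from labelbox.data.serialization import NDJsonConverter

label = lb_types.Label(
    data=lb_types.ConversationData(global_key="my_global_key"),
    annotations=[
        lb_types.ClassificationAnnotation(
            name="free_text",
            message_id="0",
            value=lb_types.Text(answer="sample text"),
        )
    ],
)  # is_benchmark_reference left at its False default

rows = list(NDJsonConverter.serialize([label]))
assert "isBenchmarkReferenceLabel" not in rows[0]  # key omitted unless flagged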

libs/labelbox/tests/data/annotation_import/test_data_types.py

Lines changed: 1 addition & 8 deletions
Unused imports are dropped here; the final hunk appears to be a whitespace or end-of-file newline fix, since its text is otherwise identical.

@@ -1,12 +1,8 @@
-import datetime
-from labelbox.schema.label import Label
 import pytest
-import uuid
 
 from labelbox.data.annotation_types.data import (
     AudioData,
     ConversationData,
-    DicomData,
     DocumentData,
     HTMLData,
     ImageData,

@@ -15,11 +11,8 @@
 from labelbox.data.serialization import NDJsonConverter
 from labelbox.data.annotation_types.data.video import VideoData
 
-import labelbox as lb
 import labelbox.types as lb_types
 from labelbox.schema.media_type import MediaType
-from labelbox.schema.annotation_import import AnnotationImportState
-from labelbox import Project, Client
 
 # Unit test for label based on data type.
 # TODO: Dicom removed it is unstable when you deserialize and serialize on label import. If we intend to keep this library this needs add generic data types tests work with this data type.

@@ -83,4 +76,4 @@ def test_data_row_type_by_global_key(
         annotations=label.annotations)
 
     assert data_label.data.global_key == label.data.global_key
-    assert label.annotations == data_label.annotations
+    assert label.annotations == data_label.annotations

libs/labelbox/tests/data/annotation_types/test_metrics.py

Lines changed: 6 additions & 3 deletions
@@ -31,7 +31,8 @@ def test_legacy_scalar_metric():
             'extra': {},
         }],
         'extra': {},
-        'uid': None
+        'uid': None,
+        'is_benchmark_reference': False
     }
     assert label.dict() == expected
 

@@ -92,7 +93,8 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
             'extra': {}
         }],
         'extra': {},
-        'uid': None
+        'uid': None,
+        'is_benchmark_reference': False
     }
 
     assert label.dict() == expected

@@ -149,7 +151,8 @@ def test_custom_confusison_matrix_metric(feature_name, subclass_name,
             'extra': {}
         }],
         'extra': {},
-        'uid': None
+        'uid': None,
+        'is_benchmark_reference': False
     }
     assert label.dict() == expected
 
libs/labelbox/tests/data/serialization/ndjson/test_conversation.py

Lines changed: 30 additions & 0 deletions
@@ -101,3 +101,33 @@ def test_conversation_entity_import(filename: str):
     res = list(NDJsonConverter.deserialize(data))
     res = list(NDJsonConverter.serialize(res))
     assert res == data
+
+
+def test_benchmark_reference_label_flag_enabled():
+    label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'),
+                           annotations=[
+                               lb_types.ClassificationAnnotation(
+                                   name='free_text',
+                                   message_id="0",
+                                   value=lb_types.Text(answer="sample text"))
+                           ],
+                           is_benchmark_reference=True
+                           )
+
+    res = list(NDJsonConverter.serialize([label]))
+    assert res[0]["isBenchmarkReferenceLabel"]
+
+
+def test_benchmark_reference_label_flag_disabled():
+    label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'),
+                           annotations=[
+                               lb_types.ClassificationAnnotation(
+                                   name='free_text',
+                                   message_id="0",
+                                   value=lb_types.Text(answer="sample text"))
+                           ],
+                           is_benchmark_reference=False
+                           )
+
+    res = list(NDJsonConverter.serialize([label]))
+    assert not res[0].get("isBenchmarkReferenceLabel")

libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py

Lines changed: 40 additions & 0 deletions
@@ -85,3 +85,43 @@ def test_rectangle_mixed_start_end_points():
 
     res = list(NDJsonConverter.deserialize(res))
     assert res == [label]
+
+
+def test_benchmark_reference_label_flag_enabled():
+    bbox = lb_types.ObjectAnnotation(
+        name="bbox",
+        value=lb_types.Rectangle(
+            start=lb_types.Point(x=81, y=28),
+            end=lb_types.Point(x=38, y=69),
+        ),
+        extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}
+    )
+
+    label = lb_types.Label(
+        data={"uid":DATAROW_ID},
+        annotations=[bbox],
+        is_benchmark_reference=True
+    )
+
+    res = list(NDJsonConverter.serialize([label]))
+    assert res[0]["isBenchmarkReferenceLabel"]
+
+
+def test_benchmark_reference_label_flag_disabled():
+    bbox = lb_types.ObjectAnnotation(
+        name="bbox",
+        value=lb_types.Rectangle(
+            start=lb_types.Point(x=81, y=28),
+            end=lb_types.Point(x=38, y=69),
+        ),
+        extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}
+    )
+
+    label = lb_types.Label(
+        data={"uid":DATAROW_ID},
+        annotations=[bbox],
+        is_benchmark_reference=False
+    )
+
+    res = list(NDJsonConverter.serialize([label]))
+    assert not res[0].get("isBenchmarkReferenceLabel")
