Skip to content

Commit 29deb0c

Browse files
committed
MODEL-1489: Allow marking Label with "is_benchmark_reference" flag
1 parent 48bcfb3 commit 29deb0c

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

libs/labelbox/src/labelbox/data/annotation_types/label.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class Label(pydantic_compat.BaseModel):
5050
data: DataType
5151
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
5252
VideoMaskAnnotation, ScalarMetric,
53-
ConfusionMatrixMetric,
54-
RelationshipAnnotation,
53+
ConfusionMatrixMetric, RelationshipAnnotation,
5554
PromptClassificationAnnotation]] = []
5655
extra: Dict[str, Any] = {}
56+
is_benchmark_reference: Optional[bool] = False
5757

5858
@pydantic_compat.root_validator(pre=True)
5959
def validate_data(cls, label):
@@ -219,9 +219,8 @@ def validate_union(cls, value):
219219
)
220220
# Validates only one prompt annotation is included
221221
if isinstance(v, PromptClassificationAnnotation):
222-
prompt_count+=1
223-
if prompt_count > 1:
224-
raise TypeError(
225-
f"Only one prompt annotation is allowed per label"
226-
)
222+
prompt_count += 1
223+
if prompt_count > 1:
224+
raise TypeError(
225+
f"Only one prompt annotation is allowed per label")
227226
return value

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,16 @@ def serialize(
106106
if not isinstance(annotation, RelationshipAnnotation):
107107
uuid_safe_annotations.append(annotation)
108108
label.annotations = uuid_safe_annotations
109-
for example in NDLabel.from_common([label]):
110-
annotation_uuid = getattr(example, "uuid", None)
109+
for annotation in NDLabel.from_common([label]):
110+
annotation_uuid = getattr(annotation, "uuid", None)
111111

112-
res = example.dict(
112+
res = annotation.dict(
113113
by_alias=True,
114114
exclude={"uuid"} if annotation_uuid == "None" else None,
115115
)
116116
for k, v in list(res.items()):
117117
if k in IGNORE_IF_NONE and v is None:
118118
del res[k]
119+
if getattr(label, 'is_benchmark_reference'):
120+
res['isBenchmarkReferenceLabel'] = True
119121
yield res

libs/labelbox/tests/data/annotation_import/test_data_types.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def validate_iso_format(date_string: str):
198198
assert parsed_t.minute is not None
199199
assert parsed_t.second is not None
200200

201+
201202
@pytest.mark.order(1)
202203
@pytest.mark.parametrize(
203204
"data_type_class",
@@ -333,6 +334,22 @@ def test_import_label_annotations(
333334
data_row.delete()
334335

335336

337+
@pytest.mark.parametrize("_, data_class, annotations", test_params)
338+
def test_import_label_annotations_with_is_benchmark_reference_flag(
339+
data_class, annotations, _):
340+
labels = [
341+
lb_types.Label(data=data_class(uid=str(uuid.uuid4()),
342+
url="http://test.com"),
343+
annotations=annotations,
344+
is_benchmark_reference=True)
345+
]
346+
serialized_annotations = get_annotation_comparison_dicts_from_labels(labels)
347+
348+
assert len(serialized_annotations) == len(annotations)
349+
for serialized_annotation in serialized_annotations:
350+
assert serialized_annotation["isBenchmarkReferenceLabel"]
351+
352+
336353
@pytest.mark.parametrize("data_type, data_class, annotations", test_params)
337354
@pytest.fixture
338355
def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type):
@@ -423,4 +440,3 @@ def test_import_mal_annotations_global_key(client,
423440

424441
assert import_annotations.errors == []
425442
# MAL Labels cannot be exported and compared to input labels
426-

0 commit comments

Comments
 (0)