MODEL-1489: Allow marking Label with "is_benchmark_reference" flag #1718

Merged: 17 commits, Jul 31, 2024
Changes from 14 commits

28 changes: 14 additions & 14 deletions libs/labelbox/src/labelbox/client.py
@@ -874,26 +874,26 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project:
kwargs.pop("data_row_count", None)

return self._create_project(**kwargs)


def create_prompt_response_generation_project(self,
dataset_id: Optional[str] = None,
dataset_name: Optional[str] = None,
data_row_count: int = 100,
**kwargs) -> Project:
"""
Use this method exclusively to create a prompt and response generation project.

Args:
dataset_name: When creating a new dataset, pass the name
dataset_id: When using an existing dataset, pass the id
data_row_count: The number of data row assets to use for the project
**kwargs: Additional parameters to pass; see the create_project method
Returns:
Project: The created project

NOTE: Only a dataset_name or dataset_id should be included

Examples:
>>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", project_kind=MediaType.LLMPromptResponseCreation)
>>> This creates a new dataset with a default number of rows (100), creates a new project, and assigns a batch of the newly created data rows to the project.
@@ -912,12 +912,12 @@ def create_prompt_response_generation_project(self,
raise ValueError(
"dataset_name or dataset_id must be present and not be an empty string."
)

if dataset_id and dataset_name:
raise ValueError(
"Only provide a dataset_name or dataset_id, not both."
)

if data_row_count <= 0:
raise ValueError("data_row_count must be a positive integer.")

@@ -927,7 +927,7 @@ def create_prompt_response_generation_project(self,
else:
append_to_existing_dataset = False
dataset_name_or_id = dataset_name

if "media_type" in kwargs and kwargs.get("media_type") not in [MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation]:
raise ValueError(
"media_type must be either LLMPromptCreation or LLMPromptResponseCreation"
@@ -936,11 +936,11 @@ def create_prompt_response_generation_project(self,
kwargs["dataset_name_or_id"] = dataset_name_or_id
kwargs["append_to_existing_dataset"] = append_to_existing_dataset
kwargs["data_row_count"] = data_row_count

kwargs.pop("editor_task_type", None)

return self._create_project(**kwargs)

def create_response_creation_project(self, **kwargs) -> Project:
"""
Creates a project for response creation.
@@ -1280,7 +1280,7 @@ def create_ontology_from_feature_schemas(
leave as None otherwise.
Returns:
The created Ontology

NOTE for chat evaluation, we currently force media_type to Conversational and for response creation, we force media_type to Text.
"""
tools, classifications = [], []
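For reference, a minimal usage sketch of the new method based on the signature and validation in this diff; the API key, project name, and dataset name below are placeholders, and media_type is passed here because that is the kwarg the validation checks:

from labelbox import Client
from labelbox.schema.media_type import MediaType

# Placeholder credentials and names; not part of the PR.
client = Client(api_key="<YOUR_API_KEY>")

# Creates a new dataset with the default 100 data rows and attaches a batch
# of them to the newly created prompt/response generation project.
project = client.create_prompt_response_generation_project(
    name="prompt-response-demo",
    dataset_name="new data set",
    media_type=MediaType.LLMPromptResponseCreation,
)
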
13 changes: 6 additions & 7 deletions libs/labelbox/src/labelbox/data/annotation_types/label.py
@@ -50,10 +50,10 @@ class Label(pydantic_compat.BaseModel):
data: DataType
annotations: List[Union[ClassificationAnnotation, ObjectAnnotation,
VideoMaskAnnotation, ScalarMetric,
ConfusionMatrixMetric, RelationshipAnnotation,
PromptClassificationAnnotation]] = []
extra: Dict[str, Any] = {}
is_benchmark_reference: Optional[bool] = False

@pydantic_compat.root_validator(pre=True)
def validate_data(cls, label):
@@ -219,9 +219,8 @@ def validate_union(cls, value):
)
# Validates only one prompt annotation is included
if isinstance(v, PromptClassificationAnnotation):
prompt_count += 1
if prompt_count > 1:
raise TypeError(
f"Only one prompt annotation is allowed per label")
return value
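
To illustrate the new field on the annotation type, a Label can now carry the flag directly; a minimal sketch mirroring the tests added below (the global key and answer text are placeholders):

import labelbox.types as lb_types

label = lb_types.Label(
    data=lb_types.ConversationData(global_key="my_global_key"),
    annotations=[
        lb_types.ClassificationAnnotation(
            name="free_text",
            message_id="0",
            value=lb_types.Text(answer="sample text"),
        )
    ],
    # New in this PR; defaults to False when omitted.
    is_benchmark_reference=True,
)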
@@ -106,14 +106,16 @@ def serialize(
if not isinstance(annotation, RelationshipAnnotation):
uuid_safe_annotations.append(annotation)
label.annotations = uuid_safe_annotations
for annotation in NDLabel.from_common([label]):
annotation_uuid = getattr(annotation, "uuid", None)

res = annotation.dict(
by_alias=True,
exclude={"uuid"} if annotation_uuid == "None" else None,
)
for k, v in list(res.items()):
if k in IGNORE_IF_NONE and v is None:
del res[k]
if getattr(label, 'is_benchmark_reference'):
Contributor:

Did you try declaring class Label(_CamelCaseMixin)? I do not think you would need this line.

Contributor Author:

I tend to avoid juggling with mixins because they have caused MRO problems in the past. Do you think we can leave it as is?

Contributor:

The way I have it above should probably work; can you try? If it does not work (i.e., a test fails), we could use your approach, but I would rather not have custom serialization if we can avoid it.

Contributor Author:

Actually, the logic is the following: if the is_benchmark_reference attribute is set on the label, then all child annotations should have isBenchmarkReferenceLabel set to True. I don't see how the code above would help avoid this logic.

res['isBenchmarkReferenceLabel'] = True
yield res
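
To make the discussion above concrete, here is a rough sketch of one yielded dict once the flag is propagated, assuming a single free-text classification on the label; the keys other than isBenchmarkReferenceLabel are illustrative, not an exact payload:

# Rough, illustrative shape of one serialized row.
serialized_row = {
    "name": "free_text",
    "answer": "sample text",
    "dataRow": {"globalKey": "my_global_key"},
    "isBenchmarkReferenceLabel": True,
}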
@@ -1,12 +1,8 @@
import datetime
from labelbox.schema.label import Label
import pytest
import uuid

from labelbox.data.annotation_types.data import (
AudioData,
ConversationData,
DicomData,
DocumentData,
Contributor:

Why was this removed?

Contributor Author:

The imports were unused.

HTMLData,
ImageData,
@@ -15,11 +11,8 @@
from labelbox.data.serialization import NDJsonConverter
from labelbox.data.annotation_types.data.video import VideoData

import labelbox as lb
import labelbox.types as lb_types
from labelbox.schema.media_type import MediaType
from labelbox.schema.annotation_import import AnnotationImportState
from labelbox import Project, Client

# Unit test for label based on data type.
# TODO: Dicom removed; it is unstable when you deserialize and serialize on label import. If we intend to keep this library, we need to add generic data type tests that work with this data type.
@@ -83,4 +76,4 @@ def test_data_row_type_by_global_key(
annotations=label.annotations)

assert data_label.data.global_key == label.data.global_key
assert label.annotations == data_label.annotations
@@ -101,3 +101,18 @@ def test_conversation_entity_import(filename: str):
res = list(NDJsonConverter.deserialize(data))
res = list(NDJsonConverter.serialize(res))
assert res == data


def test_benchmark_reference_label_flag():
label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'),
annotations=[
lb_types.ClassificationAnnotation(
name='free_text',
message_id="0",
value=lb_types.Text(answer="sample text"))
],
is_benchmark_reference=True
)

res = list(NDJsonConverter.serialize([label]))
assert res[0]["isBenchmarkReferenceLabel"]
20 changes: 20 additions & 0 deletions libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py
@@ -85,3 +85,23 @@ def test_rectangle_mixed_start_end_points():

res = list(NDJsonConverter.deserialize(res))
assert res == [label]


def test_benchmark_reference_label_flag():
bbox = lb_types.ObjectAnnotation(
name="bbox",
value=lb_types.Rectangle(
start=lb_types.Point(x=81, y=28),
end=lb_types.Point(x=38, y=69),
),
extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}
)

label = lb_types.Label(
data={"uid":DATAROW_ID},
annotations=[bbox],
is_benchmark_reference=True
)

res = list(NDJsonConverter.serialize([label]))
assert res[0]["isBenchmarkReferenceLabel"]