diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 76d6636e3..6e05721dc 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -874,7 +874,7 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: kwargs.pop("data_row_count", None) return self._create_project(**kwargs) - + def create_prompt_response_generation_project(self, dataset_id: Optional[str] = None, @@ -882,8 +882,8 @@ def create_prompt_response_generation_project(self, data_row_count: int = 100, **kwargs) -> Project: """ - Use this method exclusively to create a prompt and response generation project. - + Use this method exclusively to create a prompt and response generation project. + Args: dataset_name: When creating a new dataset, pass the name dataset_id: When using an existing dataset, pass the id @@ -891,9 +891,9 @@ def create_prompt_response_generation_project(self, **kwargs: Additional parameters to pass see the create_project method Returns: Project: The created project - - NOTE: Only a dataset_name or dataset_id should be included - + + NOTE: Only a dataset_name or dataset_id should be included + Examples: >>> client.create_prompt_response_generation_project(name=project_name, dataset_name="new data set", media_type=MediaType.LLMPromptResponseCreation) >>> This creates a new dataset with a default number of rows (100), creates new prompt and response creation project and assigns a batch of the newly created data rows to the project. @@ -912,12 +912,12 @@ def create_prompt_response_generation_project(self, raise ValueError( "dataset_name or dataset_id must be present and not be an empty string." ) - + if dataset_id and dataset_name: raise ValueError( "Only provide a dataset_name or dataset_id, not both." - ) - + ) + if data_row_count <= 0: raise ValueError("data_row_count must be a positive integer.") @@ -927,7 +927,7 @@ def create_prompt_response_generation_project(self, else: append_to_existing_dataset = False dataset_name_or_id = dataset_name - + if "media_type" in kwargs and kwargs.get("media_type") not in [MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation]: raise ValueError( "media_type must be either LLMPromptCreation or LLMPromptResponseCreation" @@ -936,11 +936,11 @@ def create_prompt_response_generation_project(self, kwargs["dataset_name_or_id"] = dataset_name_or_id kwargs["append_to_existing_dataset"] = append_to_existing_dataset kwargs["data_row_count"] = data_row_count - + kwargs.pop("editor_task_type", None) - + return self._create_project(**kwargs) - + def create_response_creation_project(self, **kwargs) -> Project: """ Creates a project for response creation. @@ -1280,7 +1280,7 @@ def create_ontology_from_feature_schemas( leave as None otherwise. Returns: The created Ontology - + NOTE for chat evaluation, we currently force media_type to Conversational and for response creation, we force media_type to Text. """ tools, classifications = [], [] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index cd209a493..c7a0cb7b8 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -50,10 +50,10 @@ class Label(pydantic_compat.BaseModel): data: DataType annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, VideoMaskAnnotation, ScalarMetric, - ConfusionMatrixMetric, - RelationshipAnnotation, + ConfusionMatrixMetric, RelationshipAnnotation, PromptClassificationAnnotation]] = [] extra: Dict[str, Any] = {} + is_benchmark_reference: Optional[bool] = False @pydantic_compat.root_validator(pre=True) def validate_data(cls, label): @@ -219,9 +219,8 @@ def validate_union(cls, value): ) # Validates only one prompt annotation is included if isinstance(v, PromptClassificationAnnotation): - prompt_count+=1 - if prompt_count > 1: - raise TypeError( - f"Only one prompt annotation is allowed per label" - ) + prompt_count += 1 + if prompt_count > 1: + raise TypeError( + f"Only one prompt annotation is allowed per label") return value diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index a7c54b109..2ffeb9727 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -106,14 +106,16 @@ def serialize( if not isinstance(annotation, RelationshipAnnotation): uuid_safe_annotations.append(annotation) label.annotations = uuid_safe_annotations - for example in NDLabel.from_common([label]): - annotation_uuid = getattr(example, "uuid", None) + for annotation in NDLabel.from_common([label]): + annotation_uuid = getattr(annotation, "uuid", None) - res = example.dict( + res = annotation.dict( by_alias=True, exclude={"uuid"} if annotation_uuid == "None" else None, ) for k, v in list(res.items()): if k in IGNORE_IF_NONE and v is None: del res[k] + if getattr(label, 'is_benchmark_reference'): + res['isBenchmarkReferenceLabel'] = True yield res diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index 74695470d..d7b3ef825 100644 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -1,12 +1,8 @@ -import datetime -from labelbox.schema.label import Label import pytest -import uuid from labelbox.data.annotation_types.data import ( AudioData, ConversationData, - DicomData, DocumentData, HTMLData, ImageData, @@ -15,11 +11,8 @@ from labelbox.data.serialization import NDJsonConverter from labelbox.data.annotation_types.data.video import VideoData -import labelbox as lb import labelbox.types as lb_types from labelbox.schema.media_type import MediaType -from labelbox.schema.annotation_import import AnnotationImportState -from labelbox import Project, Client # Unit test for label based on data type. # TODO: Dicom removed it is unstable when you deserialize and serialize on label import. If we intend to keep this library this needs add generic data types tests work with this data type. @@ -83,4 +76,4 @@ def test_data_row_type_by_global_key( annotations=label.annotations) assert data_label.data.global_key == label.data.global_key - assert label.annotations == data_label.annotations \ No newline at end of file + assert label.annotations == data_label.annotations diff --git a/libs/labelbox/tests/data/annotation_types/test_metrics.py b/libs/labelbox/tests/data/annotation_types/test_metrics.py index c68324842..db771f806 100644 --- a/libs/labelbox/tests/data/annotation_types/test_metrics.py +++ b/libs/labelbox/tests/data/annotation_types/test_metrics.py @@ -31,7 +31,8 @@ def test_legacy_scalar_metric(): 'extra': {}, }], 'extra': {}, - 'uid': None + 'uid': None, + 'is_benchmark_reference': False } assert label.dict() == expected @@ -92,7 +93,8 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value): 'extra': {} }], 'extra': {}, - 'uid': None + 'uid': None, + 'is_benchmark_reference': False } assert label.dict() == expected @@ -149,7 +151,8 @@ def test_custom_confusison_matrix_metric(feature_name, subclass_name, 'extra': {} }], 'extra': {}, - 'uid': None + 'uid': None, + 'is_benchmark_reference': False } assert label.dict() == expected diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py index acf21cc21..33804ee32 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py @@ -101,3 +101,33 @@ def test_conversation_entity_import(filename: str): res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) assert res == data + + +def test_benchmark_reference_label_flag_enabled(): + label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), + annotations=[ + lb_types.ClassificationAnnotation( + name='free_text', + message_id="0", + value=lb_types.Text(answer="sample text")) + ], + is_benchmark_reference=True + ) + + res = list(NDJsonConverter.serialize([label])) + assert res[0]["isBenchmarkReferenceLabel"] + + +def test_benchmark_reference_label_flag_disabled(): + label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), + annotations=[ + lb_types.ClassificationAnnotation( + name='free_text', + message_id="0", + value=lb_types.Text(answer="sample text")) + ], + is_benchmark_reference=False + ) + + res = list(NDJsonConverter.serialize([label])) + assert not res[0].get("isBenchmarkReferenceLabel") diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index 73099c12f..c07dcc66d 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -85,3 +85,43 @@ def test_rectangle_mixed_start_end_points(): res = list(NDJsonConverter.deserialize(res)) assert res == [label] + + +def test_benchmark_reference_label_flag_enabled(): + bbox = lb_types.ObjectAnnotation( + name="bbox", + value=lb_types.Rectangle( + start=lb_types.Point(x=81, y=28), + end=lb_types.Point(x=38, y=69), + ), + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + ) + + label = lb_types.Label( + data={"uid":DATAROW_ID}, + annotations=[bbox], + is_benchmark_reference=True + ) + + res = list(NDJsonConverter.serialize([label])) + assert res[0]["isBenchmarkReferenceLabel"] + + +def test_benchmark_reference_label_flag_disabled(): + bbox = lb_types.ObjectAnnotation( + name="bbox", + value=lb_types.Rectangle( + start=lb_types.Point(x=81, y=28), + end=lb_types.Point(x=38, y=69), + ), + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + ) + + label = lb_types.Label( + data={"uid":DATAROW_ID}, + annotations=[bbox], + is_benchmark_reference=False + ) + + res = list(NDJsonConverter.serialize([label])) + assert not res[0].get("isBenchmarkReferenceLabel")