diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 60dbedb4d..e2b0a838c 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -8,12 +8,13 @@ import time import requests -from labelbox import parser, MediaType +from labelbox import parser, MediaType, OntologyKind from labelbox import Client, Dataset from typing import Tuple, Type from labelbox.schema.annotation_import import LabelImport, AnnotationImportState from pytest import FixtureRequest +from contextlib import suppress """ The main fixtures of this library are configured_project and configured_project_by_global_key. Both fixtures generate data rows with a parametrize media type. They create the amount of data rows equal to the DATA_ROW_COUNT variable below. The data rows are generated with a factory fixture that returns a function that allows you to pass a global key. The ontologies are generated normalized and based on the MediaType given (i.e. only features supported by MediaType are created). This ontology is later used to obtain the correct annotations with the prediction_id_mapping and corresponding inferences. Each data row will have all possible annotations attached supported for the MediaType. @@ -385,6 +386,60 @@ def normalized_ontology_by_media_type(): }, ], } + + prompt_text = { + "instructions": "prompt-text", + "name": "prompt-text", + "options": [], + "required": True, + "maxCharacters": 50, + "minCharacters": 1, + "schemaNodeId": None, + "type": "prompt" + } + + response_radio = { + "instructions": "radio-response", + "name": "radio-response", + "options": [{ + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [] + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [] + }], + "required": True, + "type": "response-radio" + } + + response_checklist = { + "instructions": "checklist-response", + "name": "checklist-response", + "options": [{ + "label": "first_checklist_answer", + "value": "first_checklist_answer", + "options": [] + }, + { + "label": "second_checklist_answer", + "value": "second_checklist_answer", + "options": [] + }], + "required": True, + "type": "response-checklist" + } + + response_text = { + "instructions": "response-text", + "maxCharacters": 20, + "minCharacters": 1, + "name": "response-text", + "required": True, + "type": "response-text" + } return { MediaType.Image: { @@ -489,6 +544,29 @@ def normalized_ontology_by_media_type(): free_form_text_index ] }, + MediaType.LLMPromptResponseCreation: { + "tools": [], + "classifications": [ + prompt_text, + response_text, + response_radio, + response_checklist + ] + }, + MediaType.LLMPromptCreation: { + "tools": [], + "classifications": [ + prompt_text + ] + }, + OntologyKind.ResponseCreation: { + "tools": [], + "classifications": [ + response_text, + response_radio, + response_checklist + ] + }, "all": { "tools":[ bbox_tool, @@ -561,14 +639,67 @@ def get_global_key(): ##### Integration test strategies ##### +def _create_response_creation_project(client: Client, rand_gen, data_row_json_by_media_type, ontology_kind, normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + "For response creation projects" + + dataset = client.create_dataset(name=rand_gen(str)) + + project = client.create_response_creation_project(name=f"{ontology_kind}-{rand_gen(str)}") + + ontology = client.create_ontology(name=f"{ontology_kind}-{rand_gen(str)}", + normalized=normalized_ontology_by_media_type[ontology_kind], + media_type=MediaType.Text, + ontology_kind=ontology_kind) + + project.connect_ontology(ontology) + + data_row_data = [] + + for _ in range(DATA_ROW_COUNT): + data_row_data.append(data_row_json_by_media_type[MediaType.Text](rand_gen(str))) + + task = dataset.create_data_rows(data_row_data) + task.wait_till_done() + global_keys = [row['global_key'] for row in task.result] + data_row_ids = [row['id'] for row in task.result] + + project.create_batch( + rand_gen(str), + data_row_ids, # sample of data row objects + 5, # priority between 1(Highest) - 5(lowest) + ) + project.data_row_ids = data_row_ids + project.global_keys = global_keys + + return project, ontology, dataset + +def _create_prompt_response_project(client: Client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) -> Tuple[Project, Ontology]: + """For prompt response data row auto gen projects""" + + prompt_response_project = client.create_prompt_response_generation_project(name=f"{media_type.value}-{rand_gen(str)}", + dataset_name=rand_gen(str), + data_row_count=1, + media_type=media_type) + + ontology = client.create_ontology(name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], media_type=media_type) + + prompt_response_project.connect_ontology(ontology) + + # We have to export to get data row ids + result = export_v2_test_helpers.run_project_export_v2_task(prompt_response_project) + + data_row_ids = [dr["data_row"]["id"] for dr in result] + global_keys = [dr["data_row"]["global_key"] for dr in result] + + prompt_response_project.data_row_ids = data_row_ids + prompt_response_project.global_keys = global_keys + + return prompt_response_project, ontology + def _create_project(client: Client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: """ Shared function to configure project for integration tests """ dataset = client.create_dataset(name=rand_gen(str)) - - # LLMPromptResponseCreation is not support for project or ontology creation needs to be conversational - if media_type == MediaType.LLMPromptResponseCreation: - media_type = MediaType.Conversational project = client.create_project(name=f"{media_type}-{rand_gen(str)}", media_type=media_type) @@ -599,34 +730,50 @@ def _create_project(client: Client, rand_gen, data_row_json_by_media_type, media @pytest.fixture -def configured_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): +def configured_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type, export_v2_test_helpers): """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. The project will have 10 data rows.""" media_type = getattr(request, "param", MediaType.Image) - - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + dataset = None + + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) + elif media_type == OntologyKind.ResponseCreation: + project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) yield project project.delete() - dataset.delete() + + if dataset: + dataset.delete() + client.delete_unused_ontology(ontology.uid) @pytest.fixture() -def configured_project_by_global_key(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): +def configured_project_by_global_key(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type, export_v2_test_helpers): """Does the same thing as configured project but with global keys focus.""" - dataset = client.create_dataset(name=rand_gen(str)) - media_type = getattr(request, "param", MediaType.Image) - - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + dataset = None + + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) + elif media_type == OntologyKind.ResponseCreation: + project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) yield project project.delete() - dataset.delete() + + if dataset: + dataset.delete() + client.delete_unused_ontology(ontology.uid) @@ -636,13 +783,23 @@ def module_project(client: Client, rand_gen, data_row_json_by_media_type, reques """Generates a image project that scopes to the test module(file). Used to reduce api calls.""" media_type = getattr(request, "param", MediaType.Image) + media_type = getattr(request, "param", MediaType.Image) + dataset = None - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type) + elif media_type == OntologyKind.ResponseCreation: + project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) yield project project.delete() - dataset.delete() + + if dataset: + dataset.delete() + client.delete_unused_ontology(ontology.uid) @@ -1051,6 +1208,61 @@ def checklist_inference_index(prediction_id_mapping): checklists.append(checklist) return checklists +@pytest.fixture +def prompt_text_inference(prediction_id_mapping): + prompt_texts = [] + for feature in prediction_id_mapping: + if "prompt" not in feature: + continue + text = feature["prompt"].copy() + text.update({"answer": "free form text..."}) + del text["tool"] + prompt_texts.append(text) + return prompt_texts + +@pytest.fixture +def radio_response_inference(prediction_id_mapping): + response_radios = [] + for feature in prediction_id_mapping: + if "response-radio" not in feature: + continue + response_radio = feature["response-radio"].copy() + response_radio.update({ + "answer": {"name": "first_radio_answer"}, + }) + del response_radio["tool"] + response_radios.append(response_radio) + return response_radios + +@pytest.fixture +def checklist_response_inference(prediction_id_mapping): + response_checklists = [] + for feature in prediction_id_mapping: + if "response-checklist" not in feature: + continue + response_checklist = feature["response-checklist"].copy() + response_checklist.update({ + "answer": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"} + ] + }) + del response_checklist["tool"] + response_checklists.append(response_checklist) + return response_checklists + +@pytest.fixture +def text_response_inference(prediction_id_mapping): + response_texts = [] + for feature in prediction_id_mapping: + if "response-text" not in feature: + continue + text = feature["response-text"].copy() + text.update({"answer": "free form text..."}) + del text["tool"] + response_texts.append(text) + return response_texts + @pytest.fixture def text_inference(prediction_id_mapping): @@ -1133,6 +1345,10 @@ def annotations_by_media_type( checklist_inference, text_inference, video_checklist_inference, + prompt_text_inference, + checklist_response_inference, + radio_response_inference, + text_response_inference ): return { MediaType.Audio: [checklist_inference, text_inference], @@ -1158,6 +1374,9 @@ def annotations_by_media_type( ], MediaType.Text: [checklist_inference, text_inference, entity_inference], MediaType.Video: [video_checklist_inference], + MediaType.LLMPromptResponseCreation: [prompt_text_inference, text_response_inference, checklist_response_inference, radio_response_inference], + MediaType.LLMPromptCreation: [prompt_text_inference], + OntologyKind.ResponseCreation: [text_response_inference, checklist_response_inference, radio_response_inference] } @@ -1473,8 +1692,8 @@ def expected_export_v2_audio(): }, }, ], - "timestamp": {}, "segments": {}, + "timestamp": {} } return expected_annotations @@ -1773,65 +1992,49 @@ def expected_export_v2_document(): } return expected_annotations - @pytest.fixture() -def expected_export_v2_llm_prompt_creation(): +def expected_export_v2_llm_prompt_response_creation(): expected_annotations = { "objects": [], "classifications": [ - { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], - }, - { - "name": "text", - "value": "text", + "name": "prompt-text", + "value": "prompt-text", "text_answer": { "content": "free form text..." }, - }, + }, + {'name': 'response-text', + 'text_answer': {'content': 'free form text...'}, + 'value': 'response-text'}, + {'checklist_answers': [ + {'classifications': [], + 'name': 'first_checklist_answer', + 'value': 'first_checklist_answer'}, + {'classifications': [], + 'name': 'second_checklist_answer', + 'value': 'second_checklist_answer'}], + 'name': 'checklist-response', + 'value': 'checklist-response'}, + {'name': 'radio-response', + 'radio_answer': {'classifications': [], + 'name': 'first_radio_answer', + 'value': 'first_radio_answer'}, + 'name': 'radio-response', + 'value': 'radio-response'}, ], "relationships": [], } return expected_annotations - @pytest.fixture() -def expected_export_v2_llm_prompt_response_creation(): +def expected_export_v2_llm_prompt_creation(): expected_annotations = { "objects": [], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, - { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], - }, - { - "name": "text", - "value": "text", + "name": "prompt-text", + "value": "prompt-text", "text_answer": { "content": "free form text..." }, @@ -1841,6 +2044,33 @@ def expected_export_v2_llm_prompt_response_creation(): } return expected_annotations +@pytest.fixture() +def expected_export_v2_llm_response_creation(): + expected_annotations = { + 'objects': [], + 'relationships': [], + "classifications": [ + {'name': 'response-text', + 'text_answer': {'content': 'free form text...'}, + 'value': 'response-text'}, + {'checklist_answers': [ + {'classifications': [], + 'name': 'first_checklist_answer', + 'value': 'first_checklist_answer'}, + {'classifications': [], + 'name': 'second_checklist_answer', + 'value': 'second_checklist_answer'}], + 'name': 'checklist-response', + 'value': 'checklist-response'}, + {'name': 'radio-response', + 'radio_answer': {'classifications': [], + 'name': 'first_radio_answer', + 'value': 'first_radio_answer'}, + 'name': 'radio-response', + 'value': 'radio-response'}, + ], + } + return expected_annotations @pytest.fixture def exports_v2_by_media_type( @@ -1852,6 +2082,9 @@ def exports_v2_by_media_type( expected_export_v2_conversation, expected_export_v2_dicom, expected_export_v2_document, + expected_export_v2_llm_prompt_response_creation, + expected_export_v2_llm_prompt_creation, + expected_export_v2_llm_response_creation ): return { MediaType.Image: @@ -1870,6 +2103,12 @@ def exports_v2_by_media_type( expected_export_v2_dicom, MediaType.Document: expected_export_v2_document, + MediaType.LLMPromptResponseCreation: + expected_export_v2_llm_prompt_response_creation, + MediaType.LLMPromptCreation: + expected_export_v2_llm_prompt_creation, + OntologyKind.ResponseCreation: + expected_export_v2_llm_response_creation } diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index ea6b5876b..56b6ba67f 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -8,7 +8,7 @@ import labelbox as lb from labelbox.schema.media_type import MediaType from labelbox.schema.annotation_import import AnnotationImportState -from labelbox import Project, Client +from labelbox import Project, Client, OntologyKind import itertools """ @@ -33,6 +33,9 @@ def validate_iso_format(date_string: str): (MediaType.Video, GenericDataRowData), (MediaType.Conversational, GenericDataRowData), (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData) ], ) def test_generic_data_row_type_by_data_row_id( @@ -63,6 +66,9 @@ def test_generic_data_row_type_by_data_row_id( (MediaType.Video, GenericDataRowData), (MediaType.Conversational, GenericDataRowData), (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData) ], ) def test_generic_data_row_type_by_global_key( @@ -83,20 +89,22 @@ def test_generic_data_row_type_by_global_key( assert label.annotations == data_label.annotations -# TODO: add MediaType.LLMPromptResponseCreation(data gen) once supported and llm human preference once media type is added @pytest.mark.parametrize( - "configured_project", + "configured_project, media_type", [ - MediaType.Audio, - MediaType.Html, - MediaType.Image, - MediaType.Text, - MediaType.Video, - MediaType.Conversational, - MediaType.Document, - MediaType.Dicom, + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) ], - indirect=True + indirect=["configured_project"] ) def test_import_media_types( client: Client, @@ -105,11 +113,12 @@ def test_import_media_types( exports_v2_by_media_type, export_v2_test_helpers, helpers, + media_type, ): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project.media_type])) + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) label_import = lb.LabelImport.create_from_objects( - client, configured_project.uid, f"test-import-{configured_project.media_type}", annotations_ndjson) + client, configured_project.uid, f"test-import-{media_type}", annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -133,7 +142,7 @@ def test_import_media_types( exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] - expected_data = exports_v2_by_media_type[configured_project.media_type] + expected_data = exports_v2_by_media_type[media_type] helpers.remove_keys_recursive(exported_annotations, ["feature_id", "feature_schema_id"]) helpers.rename_cuid_key_recursive(exported_annotations) @@ -141,20 +150,20 @@ def test_import_media_types( assert exported_annotations == expected_data -@pytest.mark.order(1) @pytest.mark.parametrize( - "configured_project_by_global_key", + "configured_project_by_global_key, media_type", [ - MediaType.Audio, - MediaType.Html, - MediaType.Image, - MediaType.Text, - MediaType.Video, - MediaType.Conversational, - MediaType.Document, - MediaType.Dicom, + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) ], - indirect=True + indirect=["configured_project_by_global_key"] ) def test_import_media_types_by_global_key( client, @@ -163,11 +172,12 @@ def test_import_media_types_by_global_key( exports_v2_by_media_type, export_v2_test_helpers, helpers, -): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project_by_global_key.media_type])) + media_type + ): + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) label_import = lb.LabelImport.create_from_objects( - client, configured_project_by_global_key.uid, f"test-import-{configured_project_by_global_key.media_type}", annotations_ndjson) + client, configured_project_by_global_key.uid, f"test-import-{media_type}", annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -191,7 +201,7 @@ def test_import_media_types_by_global_key( exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] - expected_data = exports_v2_by_media_type[configured_project_by_global_key.media_type] + expected_data = exports_v2_by_media_type[media_type] helpers.remove_keys_recursive(exported_annotations, ["feature_id", "feature_schema_id"]) helpers.rename_cuid_key_recursive(exported_annotations) @@ -200,25 +210,29 @@ def test_import_media_types_by_global_key( @pytest.mark.parametrize( - "configured_project", + "configured_project, media_type", [ - MediaType.Audio, - MediaType.Html, - MediaType.Image, - MediaType.Text, - MediaType.Video, - MediaType.Conversational, - MediaType.Document, - MediaType.Dicom, + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) ], - indirect=True + indirect=["configured_project"] ) def test_import_mal_annotations( client, configured_project: Project, annotations_by_media_type, + media_type ): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project.media_type])) + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, @@ -233,24 +247,26 @@ def test_import_mal_annotations( @pytest.mark.parametrize( - "configured_project_by_global_key", + "configured_project_by_global_key, media_type", [ - MediaType.Audio, - MediaType.Html, - MediaType.Image, - MediaType.Text, - MediaType.Video, - MediaType.Conversational, - MediaType.Document, - MediaType.Dicom, + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) ], - indirect=True + indirect=["configured_project_by_global_key"] ) def test_import_mal_annotations_global_key(client, configured_project_by_global_key: Project, - annotations_by_media_type): + annotations_by_media_type, + media_type): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project_by_global_key.media_type])) + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client,