diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index e2b0a838c..cdc8e793b 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -15,7 +15,6 @@ from labelbox.schema.annotation_import import LabelImport, AnnotationImportState from pytest import FixtureRequest from contextlib import suppress - """ The main fixtures of this library are configured_project and configured_project_by_global_key. Both fixtures generate data rows with a parametrize media type. They create the amount of data rows equal to the DATA_ROW_COUNT variable below. The data rows are generated with a factory fixture that returns a function that allows you to pass a global key. The ontologies are generated normalized and based on the MediaType given (i.e. only features supported by MediaType are created). This ontology is later used to obtain the correct annotations with the prediction_id_mapping and corresponding inferences. Each data row will have all possible annotations attached supported for the MediaType. """ @@ -24,8 +23,10 @@ DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40 DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 7 + @pytest.fixture(scope="module", autouse=True) def video_data_row_factory(): + def video_data_row(global_key): return { "row_data": @@ -35,10 +36,13 @@ def video_data_row(global_key): "media_type": "VIDEO", } + return video_data_row + @pytest.fixture(scope="module", autouse=True) def audio_data_row_factory(): + def audio_data_row(global_key): return { "row_data": @@ -48,10 +52,13 @@ def audio_data_row(global_key): "media_type": "AUDIO", } + return audio_data_row + @pytest.fixture(scope="module", autouse=True) def conversational_data_row_factory(): + def conversational_data_row(global_key): return { "row_data": @@ -59,10 +66,13 @@ def conversational_data_row(global_key): "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", } + return conversational_data_row + @pytest.fixture(scope="module", autouse=True) def dicom_data_row_factory(): + def dicom_data_row(global_key): return { "row_data": @@ -72,10 +82,13 @@ def dicom_data_row(global_key): "media_type": "DICOM", } + return dicom_data_row + @pytest.fixture(scope="module", autouse=True) def geospatial_data_row_factory(): + def geospatial_data_row(global_key): return { "row_data": { @@ -97,11 +110,13 @@ def geospatial_data_row(global_key): "media_type": "TMS_GEO", } + return geospatial_data_row @pytest.fixture(scope="module", autouse=True) def html_data_row_factory(): + def html_data_row(global_key): return { "row_data": @@ -109,11 +124,13 @@ def html_data_row(global_key): "global_key": f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", } + return html_data_row @pytest.fixture(scope="module", autouse=True) def image_data_row_factory(): + def image_data_row(global_key): return { "row_data": @@ -123,11 +140,13 @@ def image_data_row(global_key): "media_type": "IMAGE", } + return image_data_row @pytest.fixture(scope="module", autouse=True) def document_data_row_factory(): + def document_data_row(global_key): return { "row_data": { @@ -141,11 +160,13 @@ def document_data_row(global_key): "media_type": "PDF", } + return document_data_row @pytest.fixture(scope="module", autouse=True) def text_data_row_factory(): + def text_data_row(global_key): return { "row_data": @@ -155,15 +176,21 @@ def text_data_row(global_key): "media_type": "TEXT", } + return text_data_row + @pytest.fixture(scope="module", autouse=True) def llm_human_preference_data_row_factory(): - def llm_human_preference_data_row(global_key): + + def llm_human_preference_data_row(global_key): return { - "row_data": "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", - "global_key": global_key, + "row_data": + "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", + "global_key": + global_key, } + return llm_human_preference_data_row @@ -190,12 +217,12 @@ def data_row_json_by_media_type( MediaType.Text: text_data_row_factory, MediaType.Video: video_data_row_factory, } - + @pytest.fixture(scope="module", autouse=True) def normalized_ontology_by_media_type(): """Returns NDJSON of ontology based on media type""" - + bbox_tool_with_nested_text = { "required": False, @@ -254,14 +281,10 @@ def normalized_ontology_by_media_type(): } bbox_tool = { - "required": - False, - "name": - "bbox", - "tool": - "rectangle", - "color": - "#a23030", + "required": False, + "name": "bbox", + "tool": "rectangle", + "color": "#a23030", "classifications": [], } @@ -386,7 +409,7 @@ def normalized_ontology_by_media_type(): }, ], } - + prompt_text = { "instructions": "prompt-text", "name": "prompt-text", @@ -397,41 +420,39 @@ def normalized_ontology_by_media_type(): "schemaNodeId": None, "type": "prompt" } - + response_radio = { "instructions": "radio-response", "name": "radio-response", "options": [{ - "label": "first_radio_answer", - "value": "first_radio_answer", - "options": [] - }, - { - "label": "second_radio_answer", - "value": "second_radio_answer", - "options": [] - }], + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [] + }, { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [] + }], "required": True, "type": "response-radio" } - + response_checklist = { "instructions": "checklist-response", "name": "checklist-response", "options": [{ - "label": "first_checklist_answer", - "value": "first_checklist_answer", - "options": [] - }, - { - "label": "second_checklist_answer", - "value": "second_checklist_answer", - "options": [] - }], + "label": "first_checklist_answer", + "value": "first_checklist_answer", + "options": [] + }, { + "label": "second_checklist_answer", + "value": "second_checklist_answer", + "options": [] + }], "required": True, "type": "response-checklist" } - + response_text = { "instructions": "response-text", "maxCharacters": 20, @@ -442,150 +463,126 @@ def normalized_ontology_by_media_type(): } return { - MediaType.Image: { - "tools": [ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, - raster_segmentation_tool, + MediaType.Image: { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + raster_segmentation_tool, ], - "classifications": [ - checklist, - free_form_text, - radio, + "classifications": [ + checklist, + free_form_text, + radio, ] - }, - MediaType.Text: { - "tools": [ - entity_tool - ], - "classifications": [ - checklist, - free_form_text, - radio, + }, + MediaType.Text: { + "tools": [entity_tool], + "classifications": [ + checklist, + free_form_text, + radio, ] }, - MediaType.Video: { - "tools": [ - bbox_tool, - bbox_tool_with_nested_text, - polyline_tool, - point_tool, - raster_segmentation_tool, + MediaType.Video: { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polyline_tool, + point_tool, + raster_segmentation_tool, ], - "classifications": [ - checklist, - free_form_text, - radio, - checklist_index, - free_form_text_index + "classifications": [ + checklist, free_form_text, radio, checklist_index, + free_form_text_index ] }, - MediaType.Geospatial_Tile: { - "tools": [ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, + MediaType.Geospatial_Tile: { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, ], - "classifications": [ - checklist, - free_form_text, - radio, + "classifications": [ + checklist, + free_form_text, + radio, ] }, - MediaType.Document: { - "tools": [ - entity_tool, - bbox_tool, - bbox_tool_with_nested_text - ], - "classifications": [ - checklist, - free_form_text, - radio, + MediaType.Document: { + "tools": [entity_tool, bbox_tool, bbox_tool_with_nested_text], + "classifications": [ + checklist, + free_form_text, + radio, ] }, - MediaType.Audio: { - "tools":[], - "classifications": [ - checklist, - free_form_text, - radio, + MediaType.Audio: { + "tools": [], + "classifications": [ + checklist, + free_form_text, + radio, ] }, - MediaType.Html: { - "tools": [], - "classifications": [ - checklist, - free_form_text, - radio, + MediaType.Html: { + "tools": [], + "classifications": [ + checklist, + free_form_text, + radio, ] }, - MediaType.Dicom: { - "tools": [ - raster_segmentation_tool, - polyline_tool - ], - "classifications": [] + MediaType.Dicom: { + "tools": [raster_segmentation_tool, polyline_tool], + "classifications": [] }, - MediaType.Conversational: { - "tools": [ - entity_tool - ], - "classifications": [ - checklist, - free_form_text, - radio, - checklist_index, - free_form_text_index + MediaType.Conversational: { + "tools": [entity_tool], + "classifications": [ + checklist, free_form_text, radio, checklist_index, + free_form_text_index ] }, - MediaType.LLMPromptResponseCreation: { - "tools": [], - "classifications": [ - prompt_text, - response_text, - response_radio, - response_checklist - ] - }, - MediaType.LLMPromptCreation: { - "tools": [], - "classifications": [ - prompt_text - ] - }, - OntologyKind.ResponseCreation: { - "tools": [], - "classifications": [ - response_text, - response_radio, - response_checklist - ] - }, - "all": { - "tools":[ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, - entity_tool, - segmentation_tool, - raster_segmentation_tool, - ], - "classifications": [ - checklist, - checklist_index, - free_form_text, - free_form_text_index, - radio, - ] - } + MediaType.LLMPromptResponseCreation: { + "tools": [], + "classifications": [ + prompt_text, response_text, response_radio, response_checklist + ] + }, + MediaType.LLMPromptCreation: { + "tools": [], + "classifications": [prompt_text] + }, + OntologyKind.ResponseCreation: { + "tools": [], + "classifications": [ + response_text, response_radio, response_checklist + ] + }, + "all": { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + entity_tool, + segmentation_tool, + raster_segmentation_tool, + ], + "classifications": [ + checklist, + checklist_index, + free_form_text, + free_form_text_index, + radio, + ] + } } @@ -617,6 +614,7 @@ def func(project): ##### Unit test strategies ##### + @pytest.fixture def hardcoded_datarow_id(): data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' @@ -639,25 +637,31 @@ def get_global_key(): ##### Integration test strategies ##### -def _create_response_creation_project(client: Client, rand_gen, data_row_json_by_media_type, ontology_kind, normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + +def _create_response_creation_project( + client: Client, rand_gen, data_row_json_by_media_type, ontology_kind, + normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: "For response creation projects" - + dataset = client.create_dataset(name=rand_gen(str)) - - project = client.create_response_creation_project(name=f"{ontology_kind}-{rand_gen(str)}") - - ontology = client.create_ontology(name=f"{ontology_kind}-{rand_gen(str)}", - normalized=normalized_ontology_by_media_type[ontology_kind], - media_type=MediaType.Text, - ontology_kind=ontology_kind) + + project = client.create_response_creation_project( + name=f"{ontology_kind}-{rand_gen(str)}") + + ontology = client.create_ontology( + name=f"{ontology_kind}-{rand_gen(str)}", + normalized=normalized_ontology_by_media_type[ontology_kind], + media_type=MediaType.Text, + ontology_kind=ontology_kind) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[MediaType.Text](rand_gen(str))) - + data_row_data.append(data_row_json_by_media_type[MediaType.Text]( + rand_gen(str))) + task = dataset.create_data_rows(data_row_data) task.wait_till_done() global_keys = [row['global_key'] for row in task.result] @@ -670,49 +674,74 @@ def _create_response_creation_project(client: Client, rand_gen, data_row_json_by ) project.data_row_ids = data_row_ids project.global_keys = global_keys - + return project, ontology, dataset -def _create_prompt_response_project(client: Client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) -> Tuple[Project, Ontology]: + +def _create_prompt_response_project( + client: Client, rand_gen, media_type, normalized_ontology_by_media_type, + export_v2_test_helpers) -> Tuple[Project, Ontology]: """For prompt response data row auto gen projects""" - - prompt_response_project = client.create_prompt_response_generation_project(name=f"{media_type.value}-{rand_gen(str)}", - dataset_name=rand_gen(str), - data_row_count=1, - media_type=media_type) - - ontology = client.create_ontology(name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], media_type=media_type) - + + prompt_response_project = client.create_prompt_response_generation_project( + name=f"{media_type.value}-{rand_gen(str)}", + dataset_name=rand_gen(str), + data_row_count=1, + media_type=media_type) + + ontology = client.create_ontology( + name=f"{media_type}-{rand_gen(str)}", + normalized=normalized_ontology_by_media_type[media_type], + media_type=media_type) + prompt_response_project.connect_ontology(ontology) - + # We have to export to get data row ids - result = export_v2_test_helpers.run_project_export_v2_task(prompt_response_project) - - data_row_ids = [dr["data_row"]["id"] for dr in result] - global_keys = [dr["data_row"]["global_key"] for dr in result] - + data_row_ids = [] + global_keys = [] + timeout = 0 + while len(data_row_ids) < 1 and timeout < 5: + result = export_v2_test_helpers.run_project_export_v2_task( + prompt_response_project) + + data_row_ids.extend([dr["data_row"]["id"] for dr in result]) + global_keys.extend([dr["data_row"]["global_key"] for dr in result]) + + time.sleep(5) + timeout += 1 + + if len(data_row_ids) < 1: + raise Exception("Failed to get data row ids") + prompt_response_project.data_row_ids = data_row_ids prompt_response_project.global_keys = global_keys - + return prompt_response_project, ontology -def _create_project(client: Client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + +def _create_project( + client: Client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: """ Shared function to configure project for integration tests """ - + dataset = client.create_dataset(name=rand_gen(str)) - + project = client.create_project(name=f"{media_type}-{rand_gen(str)}", media_type=media_type) - - ontology = client.create_ontology(name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], media_type=media_type) + + ontology = client.create_ontology( + name=f"{media_type}-{rand_gen(str)}", + normalized=normalized_ontology_by_media_type[media_type], + media_type=media_type) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[media_type](rand_gen(str))) - + data_row_data.append(data_row_json_by_media_type[media_type]( + rand_gen(str))) + task = dataset.create_data_rows(data_row_data) task.wait_till_done() global_keys = [row['global_key'] for row in task.result] @@ -725,81 +754,105 @@ def _create_project(client: Client, rand_gen, data_row_json_by_media_type, media ) project.data_row_ids = data_row_ids project.global_keys = global_keys - + return project, ontology, dataset @pytest.fixture -def configured_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type, export_v2_test_helpers): +def configured_project(client: Client, rand_gen, data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers): """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. The project will have 10 data rows.""" - + media_type = getattr(request, "param", MediaType.Image) dataset = None - + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: - project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) + project, ontology = _create_prompt_response_project( + client, rand_gen, media_type, normalized_ontology_by_media_type, + export_v2_test_helpers) elif media_type == OntologyKind.ResponseCreation: - project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - else: - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + project, ontology, dataset = _create_response_creation_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) yield project - + project.delete() - + if dataset: dataset.delete() - + client.delete_unused_ontology(ontology.uid) @pytest.fixture() -def configured_project_by_global_key(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type, export_v2_test_helpers): +def configured_project_by_global_key(client: Client, rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers): """Does the same thing as configured project but with global keys focus.""" - + media_type = getattr(request, "param", MediaType.Image) dataset = None - + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: - project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type, export_v2_test_helpers) + project, ontology = _create_prompt_response_project( + client, rand_gen, media_type, normalized_ontology_by_media_type, + export_v2_test_helpers) elif media_type == OntologyKind.ResponseCreation: - project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - else: - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + project, ontology, dataset = _create_response_creation_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) yield project - + project.delete() - + if dataset: dataset.delete() - - client.delete_unused_ontology(ontology.uid) + client.delete_unused_ontology(ontology.uid) @pytest.fixture(scope="module") -def module_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): +def module_project(client: Client, rand_gen, data_row_json_by_media_type, + request: FixtureRequest, normalized_ontology_by_media_type): """Generates a image project that scopes to the test module(file). Used to reduce api calls.""" - + media_type = getattr(request, "param", MediaType.Image) media_type = getattr(request, "param", MediaType.Image) dataset = None - + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: - project, ontology = _create_prompt_response_project(client, rand_gen, media_type, normalized_ontology_by_media_type) + project, ontology = _create_prompt_response_project( + client, rand_gen, media_type, normalized_ontology_by_media_type) elif media_type == OntologyKind.ResponseCreation: - project, ontology, dataset = _create_response_creation_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - else: - project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) + project, ontology, dataset = _create_response_creation_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) + else: + project, ontology, dataset = _create_project( + client, rand_gen, data_row_json_by_media_type, media_type, + normalized_ontology_by_media_type) yield project - + project.delete() - + if dataset: dataset.delete() - + client.delete_unused_ontology(ontology.uid) @@ -821,64 +874,74 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): Data row identifiers (ids the annotation uses) Ontology: normalized ontology """ - + if "configured_project" in request.fixturenames: project = request.getfixturevalue("configured_project") - data_row_identifiers = [{"id": data_row_id} for data_row_id in project.data_row_ids] + data_row_identifiers = [{ + "id": data_row_id + } for data_row_id in project.data_row_ids] ontology = project.ontology().normalized - + elif "configured_project_by_global_key" in request.fixturenames: project = request.getfixturevalue("configured_project_by_global_key") - data_row_identifiers = [{"globalKey": global_key} for global_key in project.global_keys] + data_row_identifiers = [{ + "globalKey": global_key + } for global_key in project.global_keys] ontology = project.ontology().normalized - + elif "module_project" in request.fixturenames: project = request.getfixturevalue("module_project") - data_row_identifiers = [{"id": data_row_id} for data_row_id in project.data_row_ids] + data_row_identifiers = [{ + "id": data_row_id + } for data_row_id in project.data_row_ids] ontology = project.ontology().normalized - + elif "hardcoded_datarow_id" in request.fixturenames: if "media_type" not in request.fixturenames: raise Exception("Please include a 'media_type' fixture") project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{"id": request.getfixturevalue("hardcoded_datarow_id")()}] - + data_row_identifiers = [{ + "id": request.getfixturevalue("hardcoded_datarow_id")() + }] + elif "hardcoded_global_key" in request.fixturenames: if "media_type" not in request.fixturenames: raise Exception("Please include a 'media_type' fixture") project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{"globalKey": request.getfixturevalue("hardcoded_global_key")()}] + data_row_identifiers = [{ + "globalKey": request.getfixturevalue("hardcoded_global_key")() + }] # Used for tests that need access to every ontology else: project = None media_type = None ontology = normalized_ontology_by_media_type["all"] - data_row_identifiers = [{"id":"ck8q9q9qj00003g5z3q1q9q9q"}] - + data_row_identifiers = [{"id": "ck8q9q9qj00003g5z3q1q9q9q"}] + base_annotations = [] for data_row_identifier in data_row_identifiers: base_annotation = {} for feature in (ontology["tools"] + ontology["classifications"]): if "tool" in feature: - feature_type = (feature["tool"] if feature["classifications"] == [] else - f"{feature['tool']}_nested" - ) # tool vs nested classification tool + feature_type = (feature["tool"] if feature["classifications"] + == [] else f"{feature['tool']}_nested" + ) # tool vs nested classification tool else: feature_type = (feature["type"] if "scope" not in feature else - f"{feature['type']}_{feature['scope']}" - ) # checklist vs indexed checklist + f"{feature['type']}_{feature['scope']}" + ) # checklist vs indexed checklist base_annotation[feature_type] = { "uuid": str(uuid.uuid4()), "name": feature["name"], "tool": feature, "dataRow": data_row_identifier - } + } base_annotations.append(base_annotation) return base_annotations @@ -952,16 +1015,15 @@ def rectangle_inference_with_confidence(prediction_id_mapping): "width": 12 }, "classifications": [{ - "name": - rectangle["tool"]["classifications"][0]["name"], + "name": rectangle["tool"]["classifications"][0]["name"], "answer": { "name": rectangle["tool"]["classifications"][0]["options"][0] ["value"], "classifications": [{ "name": - rectangle["tool"]["classifications"][0]["options"][0] - ["options"][1]["name"], + rectangle["tool"]["classifications"][0]["options"] + [0]["options"][1]["name"], "answer": "nested answer", }], @@ -996,14 +1058,15 @@ def line_inference(prediction_id_mapping): if "line" not in feature: continue line = feature["line"].copy() - line.update( - {"line": [{ + line.update({ + "line": [{ "x": 147.692, "y": 118.154 }, { "x": 150.692, "y": 160.154 - }]}) + }] + }) del line["tool"] lines.append(line) return lines @@ -1180,10 +1243,11 @@ def checklist_inference(prediction_id_mapping): continue checklist = feature["checklist"].copy() checklist.update({ - "answers": [ - {"name": "first_checklist_answer"}, - {"name": "second_checklist_answer"} - ] + "answers": [{ + "name": "first_checklist_answer" + }, { + "name": "second_checklist_answer" + }] }) del checklist["tool"] checklists.append(checklist) @@ -1198,16 +1262,18 @@ def checklist_inference_index(prediction_id_mapping): return None checklist = feature["checklist_index"].copy() checklist.update({ - "answers": [ - {"name": "first_checklist_answer"}, - {"name": "second_checklist_answer"} - ], + "answers": [{ + "name": "first_checklist_answer" + }, { + "name": "second_checklist_answer" + }], "messageId": "0", }) del checklist["tool"] checklists.append(checklist) return checklists + @pytest.fixture def prompt_text_inference(prediction_id_mapping): prompt_texts = [] @@ -1220,6 +1286,7 @@ def prompt_text_inference(prediction_id_mapping): prompt_texts.append(text) return prompt_texts + @pytest.fixture def radio_response_inference(prediction_id_mapping): response_radios = [] @@ -1228,12 +1295,15 @@ def radio_response_inference(prediction_id_mapping): continue response_radio = feature["response-radio"].copy() response_radio.update({ - "answer": {"name": "first_radio_answer"}, + "answer": { + "name": "first_radio_answer" + }, }) del response_radio["tool"] response_radios.append(response_radio) return response_radios + @pytest.fixture def checklist_response_inference(prediction_id_mapping): response_checklists = [] @@ -1242,15 +1312,17 @@ def checklist_response_inference(prediction_id_mapping): continue response_checklist = feature["response-checklist"].copy() response_checklist.update({ - "answer": [ - {"name": "first_checklist_answer"}, - {"name": "second_checklist_answer"} - ] + "answer": [{ + "name": "first_checklist_answer" + }, { + "name": "second_checklist_answer" + }] }) del response_checklist["tool"] response_checklists.append(response_checklist) return response_checklists + @pytest.fixture def text_response_inference(prediction_id_mapping): response_texts = [] @@ -1308,10 +1380,11 @@ def video_checklist_inference(prediction_id_mapping): continue checklist = feature["checklist"].copy() checklist.update({ - "answers": [ - {"name": "first_checklist_answer"}, - {"name": "second_checklist_answer"} - ] + "answers": [{ + "name": "first_checklist_answer" + }, { + "name": "second_checklist_answer" + }] }) checklist.update( @@ -1332,24 +1405,13 @@ def video_checklist_inference(prediction_id_mapping): @pytest.fixture def annotations_by_media_type( - polygon_inference, - rectangle_inference, - rectangle_inference_document, - line_inference_v2, - line_inference, - entity_inference, - entity_inference_index, - entity_inference_document, - checklist_inference_index, - text_inference_index, - checklist_inference, - text_inference, - video_checklist_inference, - prompt_text_inference, - checklist_response_inference, - radio_response_inference, - text_response_inference -): + polygon_inference, rectangle_inference, rectangle_inference_document, + line_inference_v2, line_inference, entity_inference, + entity_inference_index, entity_inference_document, + checklist_inference_index, text_inference_index, checklist_inference, + text_inference, video_checklist_inference, prompt_text_inference, + checklist_response_inference, radio_response_inference, + text_response_inference): return { MediaType.Audio: [checklist_inference, text_inference], MediaType.Conversational: [ @@ -1374,9 +1436,15 @@ def annotations_by_media_type( ], MediaType.Text: [checklist_inference, text_inference, entity_inference], MediaType.Video: [video_checklist_inference], - MediaType.LLMPromptResponseCreation: [prompt_text_inference, text_response_inference, checklist_response_inference, radio_response_inference], + MediaType.LLMPromptResponseCreation: [ + prompt_text_inference, text_response_inference, + checklist_response_inference, radio_response_inference + ], MediaType.LLMPromptCreation: [prompt_text_inference], - OntologyKind.ResponseCreation: [text_response_inference, checklist_response_inference, radio_response_inference] + OntologyKind.ResponseCreation: [ + text_response_inference, checklist_response_inference, + radio_response_inference + ] } @@ -1395,13 +1463,8 @@ def object_predictions( entity_inference, segmentation_inference, ): - return ( - polygon_inference + - rectangle_inference + - line_inference + - entity_inference + - segmentation_inference - ) + return (polygon_inference + rectangle_inference + line_inference + + entity_inference + segmentation_inference) @pytest.fixture @@ -1409,19 +1472,15 @@ def object_predictions_for_annotation_import(polygon_inference, rectangle_inference, line_inference, segmentation_inference): - return ( - polygon_inference + - rectangle_inference + - line_inference + - segmentation_inference - ) - + return (polygon_inference + rectangle_inference + line_inference + + segmentation_inference) @pytest.fixture def classification_predictions(checklist_inference, text_inference): return checklist_inference + text_inference + @pytest.fixture def predictions(object_predictions, classification_predictions): return object_predictions + classification_predictions @@ -1502,14 +1561,12 @@ def model_run_with_data_rows( @pytest.fixture -def model_run_with_all_project_labels( - client, - configured_project, - model_run_predictions, - model_run: ModelRun, - wait_for_label_processing -): - use_data_row_ids = list(set([p["dataRow"]["id"] for p in model_run_predictions])) +def model_run_with_all_project_labels(client, configured_project, + model_run_predictions, + model_run: ModelRun, + wait_for_label_processing): + use_data_row_ids = list( + set([p["dataRow"]["id"] for p in model_run_predictions])) model_run.upsert_data_rows(use_data_row_ids) @@ -1643,11 +1700,10 @@ def expected_export_v2_image(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, { @@ -1677,11 +1733,10 @@ def expected_export_v2_audio(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, { @@ -1719,11 +1774,10 @@ def expected_export_v2_html(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, ], @@ -1756,11 +1810,10 @@ def expected_export_v2_text(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, { @@ -1789,16 +1842,15 @@ def expected_export_v2_video(): "checklist", "value": "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, - { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "checklist_answers": [{ + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [] + }, { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] + }], }], } return expected_annotations @@ -1832,11 +1884,10 @@ def expected_export_v2_conversation(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, { @@ -1973,11 +2024,10 @@ def expected_export_v2_document(): "name": "first_checklist_answer", "value": "first_checklist_answer", "classifications": [] - }, - { + }, { "name": "second_checklist_answer", "value": "second_checklist_answer", - "classifications": [] + "classifications": [] }], }, { @@ -1992,100 +2042,122 @@ def expected_export_v2_document(): } return expected_annotations + @pytest.fixture() def expected_export_v2_llm_prompt_response_creation(): expected_annotations = { "objects": [], "classifications": [ - { + { "name": "prompt-text", "value": "prompt-text", "text_answer": { "content": "free form text..." }, + }, + { + 'name': 'response-text', + 'text_answer': { + 'content': 'free form text...' }, - {'name': 'response-text', - 'text_answer': {'content': 'free form text...'}, - 'value': 'response-text'}, - {'checklist_answers': [ - {'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer'}, - {'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer'}], + 'value': 'response-text' + }, + { + 'checklist_answers': [{ + 'classifications': [], + 'name': 'first_checklist_answer', + 'value': 'first_checklist_answer' + }, { + 'classifications': [], + 'name': 'second_checklist_answer', + 'value': 'second_checklist_answer' + }], 'name': 'checklist-response', - 'value': 'checklist-response'}, - {'name': 'radio-response', - 'radio_answer': {'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer'}, + 'value': 'checklist-response' + }, + { + 'name': 'radio-response', + 'radio_answer': { + 'classifications': [], + 'name': 'first_radio_answer', + 'value': 'first_radio_answer' + }, 'name': 'radio-response', - 'value': 'radio-response'}, + 'value': 'radio-response' + }, ], "relationships": [], } return expected_annotations + @pytest.fixture() def expected_export_v2_llm_prompt_creation(): expected_annotations = { "objects": [], - "classifications": [ - { - "name": "prompt-text", - "value": "prompt-text", - "text_answer": { - "content": "free form text..." - }, + "classifications": [{ + "name": "prompt-text", + "value": "prompt-text", + "text_answer": { + "content": "free form text..." }, - ], + },], "relationships": [], } return expected_annotations + @pytest.fixture() def expected_export_v2_llm_response_creation(): expected_annotations = { 'objects': [], 'relationships': [], "classifications": [ - {'name': 'response-text', - 'text_answer': {'content': 'free form text...'}, - 'value': 'response-text'}, - {'checklist_answers': [ - {'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer'}, - {'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer'}], - 'name': 'checklist-response', - 'value': 'checklist-response'}, - {'name': 'radio-response', - 'radio_answer': {'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer'}, - 'name': 'radio-response', - 'value': 'radio-response'}, + { + 'name': 'response-text', + 'text_answer': { + 'content': 'free form text...' + }, + 'value': 'response-text' + }, + { + 'checklist_answers': [{ + 'classifications': [], + 'name': 'first_checklist_answer', + 'value': 'first_checklist_answer' + }, { + 'classifications': [], + 'name': 'second_checklist_answer', + 'value': 'second_checklist_answer' + }], + 'name': 'checklist-response', + 'value': 'checklist-response' + }, + { + 'name': 'radio-response', + 'radio_answer': { + 'classifications': [], + 'name': 'first_radio_answer', + 'value': 'first_radio_answer' + }, + 'name': 'radio-response', + 'value': 'radio-response' + }, ], } return expected_annotations + @pytest.fixture -def exports_v2_by_media_type( - expected_export_v2_image, - expected_export_v2_audio, - expected_export_v2_html, - expected_export_v2_text, - expected_export_v2_video, - expected_export_v2_conversation, - expected_export_v2_dicom, - expected_export_v2_document, - expected_export_v2_llm_prompt_response_creation, - expected_export_v2_llm_prompt_creation, - expected_export_v2_llm_response_creation -): +def exports_v2_by_media_type(expected_export_v2_image, expected_export_v2_audio, + expected_export_v2_html, expected_export_v2_text, + expected_export_v2_video, + expected_export_v2_conversation, + expected_export_v2_dicom, + expected_export_v2_document, + expected_export_v2_llm_prompt_response_creation, + expected_export_v2_llm_prompt_creation, + expected_export_v2_llm_response_creation): return { MediaType.Image: expected_export_v2_image, @@ -2110,8 +2182,8 @@ def exports_v2_by_media_type( OntologyKind.ResponseCreation: expected_export_v2_llm_response_creation } - - + + class Helpers: @staticmethod diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index 8ec2d9214..131f0d59d 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -10,33 +10,32 @@ from labelbox.schema.annotation_import import AnnotationImportState from labelbox import Project, Client, OntologyKind import itertools - """ - integration test for importing mal labels and ground truths with each supported MediaType. - NDJSON is used to generate annotations. """ + def validate_iso_format(date_string: str): parsed_t = datetime.datetime.fromisoformat( date_string) # this will blow up if the string is not in iso format assert parsed_t.hour is not None assert parsed_t.minute is not None assert parsed_t.second is not None - + + @pytest.mark.parametrize( "media_type, data_type_class", - [ - (MediaType.Audio, GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - (MediaType.LLMPromptResponseCreation, GenericDataRowData), - (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData) - ], + [(MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData)], ) def test_generic_data_row_type_by_data_row_id( media_type, @@ -44,12 +43,12 @@ def test_generic_data_row_type_by_data_row_id( annotations_by_media_type, hardcoded_datarow_id, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = Label(data=data_type_class(uid = hardcoded_datarow_id()), + + data_label = Label(data=data_type_class(uid=hardcoded_datarow_id()), annotations=label.annotations) assert data_label.data.uid == label.data.uid @@ -58,18 +57,16 @@ def test_generic_data_row_type_by_data_row_id( @pytest.mark.parametrize( "media_type, data_type_class", - [ - (MediaType.Audio, GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - # (MediaType.LLMPromptResponseCreation, GenericDataRowData), - # (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData) - ], + [(MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData)], ) def test_generic_data_row_type_by_global_key( media_type, @@ -77,12 +74,12 @@ def test_generic_data_row_type_by_global_key( annotations_by_media_type, hardcoded_global_key, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = Label(data=data_type_class(global_key = hardcoded_global_key()), + + data_label = Label(data=data_type_class(global_key=hardcoded_global_key()), annotations=label.annotations) assert data_label.data.global_key == label.data.global_key @@ -91,21 +88,16 @@ def test_generic_data_row_type_by_global_key( @pytest.mark.parametrize( "configured_project, media_type", - [ - (MediaType.Audio, MediaType.Audio), - (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), - (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - # (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), - # (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) - ], - indirect=["configured_project"] -) + [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], + indirect=["configured_project"]) def test_import_media_types( client: Client, configured_project: Project, @@ -115,16 +107,19 @@ def test_import_media_types( helpers, media_type, ): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) + annotations_ndjson = list( + itertools.chain.from_iterable(annotations_by_media_type[media_type])) label_import = lb.LabelImport.create_from_objects( - client, configured_project.uid, f"test-import-{media_type}", annotations_ndjson) + client, configured_project.uid, f"test-import-{media_type}", + annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 - result = export_v2_test_helpers.run_project_export_v2_task(configured_project) + result = export_v2_test_helpers.run_project_export_v2_task( + configured_project) assert result @@ -132,58 +127,51 @@ def test_import_media_types( # timestamp fields are in iso format validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid]["labels"][0] - ["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid]["labels"][0] - ["label_details"]["updated_at"]) + validate_iso_format(exported_data["projects"][configured_project.uid] + ["labels"][0]["label_details"]["created_at"]) + validate_iso_format(exported_data["projects"][configured_project.uid] + ["labels"][0]["label_details"]["updated_at"]) - assert exported_data["data_row"]["id"] in configured_project.data_row_ids + assert exported_data["data_row"][ + "id"] in configured_project.data_row_ids exported_project = exported_data["projects"][configured_project.uid] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + ["feature_id", "feature_schema_id"]) helpers.rename_cuid_key_recursive(exported_annotations) - assert exported_annotations == expected_data + assert exported_annotations == expected_data @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [ - (MediaType.Audio, MediaType.Audio), - (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), - (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) - ], - indirect=["configured_project_by_global_key"] -) + [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], + indirect=["configured_project_by_global_key"]) def test_import_media_types_by_global_key( - client, - configured_project_by_global_key, - annotations_by_media_type, - exports_v2_by_media_type, - export_v2_test_helpers, - helpers, - media_type - ): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) + client, configured_project_by_global_key, annotations_by_media_type, + exports_v2_by_media_type, export_v2_test_helpers, helpers, media_type): + annotations_ndjson = list( + itertools.chain.from_iterable(annotations_by_media_type[media_type])) label_import = lb.LabelImport.create_from_objects( - client, configured_project_by_global_key.uid, f"test-import-{media_type}", annotations_ndjson) + client, configured_project_by_global_key.uid, + f"test-import-{media_type}", annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 - result = export_v2_test_helpers.run_project_export_v2_task(configured_project_by_global_key) + result = export_v2_test_helpers.run_project_export_v2_task( + configured_project_by_global_key) assert result @@ -191,48 +179,44 @@ def test_import_media_types_by_global_key( # timestamp fields are in iso format validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][configured_project_by_global_key.uid]["labels"][0] - ["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][configured_project_by_global_key.uid]["labels"][0] - ["label_details"]["updated_at"]) - - assert exported_data["data_row"]["id"] in configured_project_by_global_key.data_row_ids - exported_project = exported_data["projects"][configured_project_by_global_key.uid] + validate_iso_format( + exported_data["projects"][configured_project_by_global_key.uid] + ["labels"][0]["label_details"]["created_at"]) + validate_iso_format( + exported_data["projects"][configured_project_by_global_key.uid] + ["labels"][0]["label_details"]["updated_at"]) + + assert exported_data["data_row"][ + "id"] in configured_project_by_global_key.data_row_ids + exported_project = exported_data["projects"][ + configured_project_by_global_key.uid] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + ["feature_id", "feature_schema_id"]) helpers.rename_cuid_key_recursive(exported_annotations) - assert exported_annotations == expected_data + assert exported_annotations == expected_data @pytest.mark.parametrize( "configured_project, media_type", - [ - (MediaType.Audio, MediaType.Audio), - (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), - (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - # (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), - # (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) - ], - indirect=["configured_project"] -) -def test_import_mal_annotations( - client, - configured_project: Project, - annotations_by_media_type, - media_type -): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) + [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], + indirect=["configured_project"]) +def test_import_mal_annotations(client, configured_project: Project, + annotations_by_media_type, media_type): + annotations_ndjson = list( + itertools.chain.from_iterable(annotations_by_media_type[media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, @@ -244,29 +228,24 @@ def test_import_mal_annotations( assert import_annotations.errors == [] # MAL Labels cannot be exported and compared to input labels - + @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [ - (MediaType.Audio, MediaType.Audio), - (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), - (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) - ], - indirect=["configured_project_by_global_key"] -) -def test_import_mal_annotations_global_key(client, - configured_project_by_global_key: Project, - annotations_by_media_type, - media_type): - - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[media_type])) + [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], + indirect=["configured_project_by_global_key"]) +def test_import_mal_annotations_global_key( + client, configured_project_by_global_key: Project, + annotations_by_media_type, media_type): + + annotations_ndjson = list( + itertools.chain.from_iterable(annotations_by_media_type[media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client,