diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 9ef66e61e..c7817f2e2 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -1,307 +1,200 @@ +import itertools import uuid +from labelbox.schema.model_run import ModelRun +from labelbox.schema.ontology import Ontology +from labelbox.schema.project import Project import pytest import time import requests from labelbox import parser, MediaType +from labelbox import Client, Dataset -from typing import Type -from labelbox.schema.labeling_frontend import LabelingFrontend +from typing import Tuple, Type from labelbox.schema.annotation_import import LabelImport, AnnotationImportState -from labelbox.schema.project import Project -from labelbox.schema.queue_mode import QueueMode +from pytest import FixtureRequest + +""" +The main fixtures of this library are configured_project and configured_project_by_global_key. Both fixtures generate data rows with a parametrize media type. They create the amount of data rows equal to the DATA_ROW_COUNT variable below. The data rows are generated with a factory fixture that returns a function that allows you to pass a global key. The ontologies are generated normalized and based on the MediaType given (i.e. only features supported by MediaType are created). This ontology is later used to obtain the correct annotations with the prediction_id_mapping and corresponding inferences. Each data row will have all possible annotations attached supported for the MediaType. +""" +DATA_ROW_COUNT = 3 DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40 DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 7 - -@pytest.fixture() -def audio_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{rand_gen(str)}", - "media_type": - "AUDIO", - } - - -@pytest.fixture() -def conversation_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", - } - - -@pytest.fixture() -def dicom_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{rand_gen(str)}", - "media_type": - "DICOM", - } - - -@pytest.fixture() -def geospatial_data_row(rand_gen): - return { - "row_data": { - "tile_layer_url": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", - "bounds": [ - [19.405662413477728, -99.21052827588443], - [19.400498983095076, -99.20534818927473], - ], - "min_zoom": - 12, - "max_zoom": - 20, - "epsg": - "EPSG4326", - }, - "global_key": - f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{rand_gen(str)}", - "media_type": - "TMS_GEO", - } - - -@pytest.fixture() -def html_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", - "global_key": - 
f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{rand_gen(str)}", - } - - -@pytest.fixture() -def image_data_row(rand_gen): - return { - "row_data": - "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", - "global_key": - f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{rand_gen(str)}", - "media_type": - "IMAGE", - } - - -@pytest.fixture() -def document_data_row(rand_gen): - return { - "row_data": { - "pdf_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", - "text_layer_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", - }, - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{rand_gen(str)}", - "media_type": - "PDF", - } - - -@pytest.fixture() -def text_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", - "global_key": - f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{rand_gen(str)}", - "media_type": - "TEXT", - } - - -@pytest.fixture() -def llm_prompt_creation_data_row(rand_gen): - return { - "row_data": { - "type": "application/llm.prompt-creation", - "version": 1 - }, - "global_key": rand_gen(str), - } - - -@pytest.fixture() -def llm_prompt_response_data_row(rand_gen): - return { - "row_data": { - "type": "application/llm.prompt-response-creation", - "version": 1 - }, - "global_key": rand_gen(str), - } - - -@pytest.fixture -def data_row_json_by_data_type( - audio_data_row, - conversation_data_row, - dicom_data_row, - geospatial_data_row, - html_data_row, - image_data_row, - document_data_row, - text_data_row, - video_data_row, - llm_prompt_creation_data_row, - llm_prompt_response_data_row, -): - return { - "audio": audio_data_row, - "conversation": conversation_data_row, - "dicom": dicom_data_row, - "geospatial": geospatial_data_row, - "html": html_data_row, - "image": image_data_row, - "document": document_data_row, - "text": text_data_row, - "video": video_data_row, - "llmpromptcreation": llm_prompt_creation_data_row, - "llmpromptresponsecreation": llm_prompt_response_data_row, - "llmresponsecreation": text_data_row, - } - - -@pytest.fixture -def exports_v2_by_data_type( - expected_export_v2_image, - expected_export_v2_audio, - expected_export_v2_html, - expected_export_v2_text, - expected_export_v2_video, - expected_export_v2_conversation, - expected_export_v2_dicom, - expected_export_v2_document, - expected_export_v2_llm_prompt_creation, - expected_export_v2_llm_prompt_response_creation, - expected_export_v2_llm_response_creation, -): - return { - "image": - expected_export_v2_image, - "audio": - expected_export_v2_audio, - "html": - expected_export_v2_html, - "text": - expected_export_v2_text, - "video": - expected_export_v2_video, - "conversation": - expected_export_v2_conversation, - "dicom": - expected_export_v2_dicom, - "document": - expected_export_v2_document, - "llmpromptcreation": - expected_export_v2_llm_prompt_creation, - "llmpromptresponsecreation": - expected_export_v2_llm_prompt_response_creation, - "llmresponsecreation": - expected_export_v2_llm_response_creation, - } - - -@pytest.fixture -def annotations_by_data_type( - polygon_inference, - rectangle_inference, - rectangle_inference_document, - line_inference, - 
entity_inference, - entity_inference_document, - checklist_inference, - text_inference, - video_checklist_inference, -): - return { - "audio": [checklist_inference, text_inference], - "conversation": [checklist_inference, text_inference, entity_inference], - "dicom": [line_inference], - "document": [ - entity_inference_document, - checklist_inference, - text_inference, - rectangle_inference_document, - ], - "html": [text_inference, checklist_inference], - "image": [ - polygon_inference, - rectangle_inference, - line_inference, - checklist_inference, - text_inference, - ], - "text": [entity_inference, checklist_inference, text_inference], - "video": [video_checklist_inference], - "llmpromptcreation": [checklist_inference, text_inference], - "llmpromptresponsecreation": [checklist_inference, text_inference], - "llmresponsecreation": [checklist_inference, text_inference], - } +@pytest.fixture(scope="module", autouse=True) +def video_data_row_factory(): + def video_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": + f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{global_key}", + "media_type": + "VIDEO", + } + return video_data_row + +@pytest.fixture(scope="module", autouse=True) +def audio_data_row_factory(): + def audio_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", + "global_key": + f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{global_key}", + "media_type": + "AUDIO", + } + return audio_data_row + +@pytest.fixture(scope="module", autouse=True) +def conversational_data_row_factory(): + def conversational_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": + f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", + } + return conversational_data_row + +@pytest.fixture(scope="module", autouse=True) +def dicom_data_row_factory(): + def dicom_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", + "global_key": + f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{global_key}", + "media_type": + "DICOM", + } + return dicom_data_row + +@pytest.fixture(scope="module", autouse=True) +def geospatial_data_row_factory(): + def geospatial_data_row(global_key): + return { + "row_data": { + "tile_layer_url": + "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", + "bounds": [ + [19.405662413477728, -99.21052827588443], + [19.400498983095076, -99.20534818927473], + ], + "min_zoom": + 12, + "max_zoom": + 20, + "epsg": + "EPSG4326", + }, + "global_key": + f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{global_key}", + "media_type": + "TMS_GEO", + } + return geospatial_data_row -@pytest.fixture -def annotations_by_data_type_v2( - polygon_inference, - rectangle_inference, - rectangle_inference_document, - line_inference_v2, - line_inference, - entity_inference, - entity_inference_index, - entity_inference_document, - checklist_inference_index, - text_inference_index, - checklist_inference, - text_inference, - video_checklist_inference, 
+@pytest.fixture(scope="module", autouse=True) +def html_data_row_factory(): + def html_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", + "global_key": + f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", + } + return html_data_row + + +@pytest.fixture(scope="module", autouse=True) +def image_data_row_factory(): + def image_data_row(global_key): + return { + "row_data": + "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", + "global_key": + f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{global_key}", + "media_type": + "IMAGE", + } + return image_data_row + + +@pytest.fixture(scope="module", autouse=True) +def document_data_row_factory(): + def document_data_row(global_key): + return { + "row_data": { + "pdf_url": + "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", + "text_layer_url": + "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", + }, + "global_key": + f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{global_key}", + "media_type": + "PDF", + } + return document_data_row + + +@pytest.fixture(scope="module", autouse=True) +def text_data_row_factory(): + def text_data_row(global_key): + return { + "row_data": + "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", + "global_key": + f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{global_key}", + "media_type": + "TEXT", + } + return text_data_row + +@pytest.fixture(scope="module", autouse=True) +def llm_human_preference_data_row_factory(): + def llm_human_preference_data_row(global_key): + return { + "row_data": "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", + "global_key": global_key, + } + return llm_human_preference_data_row + + +@pytest.fixture(scope="module", autouse=True) +def data_row_json_by_media_type( + audio_data_row_factory, + conversational_data_row_factory, + dicom_data_row_factory, + geospatial_data_row_factory, + html_data_row_factory, + image_data_row_factory, + document_data_row_factory, + text_data_row_factory, + video_data_row_factory, ): return { - "audio": [checklist_inference, text_inference], - "conversation": [ - checklist_inference_index, - text_inference_index, - entity_inference_index, - ], - "dicom": [line_inference_v2], - "document": [ - entity_inference_document, - checklist_inference, - text_inference, - rectangle_inference_document, - ], - "html": [text_inference, checklist_inference], - "image": [ - polygon_inference, - rectangle_inference, - line_inference, - checklist_inference, - text_inference, - ], - "text": [entity_inference, checklist_inference, text_inference], - "video": [video_checklist_inference], - "llmpromptcreation": [checklist_inference, text_inference], - "llmpromptresponsecreation": [checklist_inference, text_inference], - "llmresponsecreation": [checklist_inference, text_inference], + MediaType.Audio: audio_data_row_factory, + MediaType.Conversational: conversational_data_row_factory, + MediaType.Dicom: dicom_data_row_factory, + MediaType.Geospatial_Tile: geospatial_data_row_factory, + MediaType.Html: html_data_row_factory, + MediaType.Image: image_data_row_factory, + 
MediaType.Document: document_data_row_factory, + MediaType.Text: text_data_row_factory, + MediaType.Video: video_data_row_factory, } + - -@pytest.fixture(scope="session") -def ontology(): +@pytest.fixture(scope="module", autouse=True) +def normalized_ontology_by_media_type(): + """Returns NDJSON of ontology based on media type""" + bbox_tool_with_nested_text = { "required": False, @@ -322,7 +215,7 @@ def ontology(): "radio", "options": [{ "label": - "radio_option_1", + "radio_value_1", "value": "radio_value_1", "options": [ @@ -338,12 +231,12 @@ def ontology(): "options": [ { "label": "nested_checkbox_option_1", - "value": "nested_checkbox_value_1", + "value": "nested_checkbox_option_1", "options": [], }, { "label": "nested_checkbox_option_2", - "value": "nested_checkbox_value_2", + "value": "nested_checkbox_option_2", }, ], }, @@ -368,43 +261,7 @@ def ontology(): "rectangle", "color": "#a23030", - "classifications": [{ - "required": - False, - "instructions": - "nested", - "name": - "nested", - "type": - "radio", - "options": [{ - "label": - "radio_option_1", - "value": - "radio_value_1", - "options": [{ - "required": - False, - "instructions": - "nested_checkbox", - "name": - "nested_checkbox", - "type": - "checklist", - "options": [ - { - "label": "nested_checkbox_option_1", - "value": "nested_checkbox_value_1", - "options": [], - }, - { - "label": "nested_checkbox_option_2", - "value": "nested_checkbox_value_2", - }, - ], - }], - },], - }], + "classifications": [], } polygon_tool = { @@ -430,18 +287,11 @@ def ontology(): } entity_tool = { "required": False, - "name": "entity--", + "name": "named-entity", "tool": "named-entity", "color": "#006FA6", "classifications": [], } - segmentation_tool = { - "required": False, - "name": "segmentation--", - "tool": "superpixel", - "color": "#A30059", - "classifications": [], - } raster_segmentation_tool = { "required": False, "name": "segmentation_mask", @@ -449,6 +299,13 @@ def ontology(): "color": "#ff0000", "classifications": [], } + segmentation_tool = { + "required": False, + "name": "segmentation--", + "tool": "superpixel", + "color": "#A30059", + "classifications": [], + } checklist = { "required": False, @@ -460,16 +317,12 @@ def ontology(): "checklist", "options": [ { - "label": "option1", - "value": "option1" - }, - { - "label": "option2", - "value": "option2" + "label": "first_checklist_answer", + "value": "first_checklist_answer" }, { - "label": "optionN", - "value": "optionn" + "label": "second_checklist_answer", + "value": "second_checklist_answer" }, ], } @@ -486,16 +339,12 @@ def ontology(): "index", "options": [ { - "label": "option1_index", - "value": "option1_index" - }, - { - "label": "option2_index", - "value": "option2_index" + "label": "first_checklist_answer", + "value": "first_checklist_answer" }, { - "label": "optionN_index", - "value": "optionn_index" + "label": "second_checklist_answer", + "value": "second_checklist_answer" }, ], } @@ -536,33 +385,130 @@ def ontology(): }, ], } - named_entity = { - "tool": "named-entity", - "name": "named-entity", - "required": False, - "color": "#A30059", - "classifications": [], - } - tools = [ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, - entity_tool, - segmentation_tool, - raster_segmentation_tool, - named_entity, - ] - classifications = [ - checklist, - checklist_index, - free_form_text, - free_form_text_index, - radio, - ] - return {"tools": tools, "classifications": classifications} + return { + MediaType.Image: { + "tools": [ + 
bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + raster_segmentation_tool, + ], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Text: { + "tools": [ + entity_tool + ], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Video: { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polyline_tool, + point_tool, + raster_segmentation_tool, + ], + "classifications": [ + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index + ] + }, + MediaType.Geospatial_Tile: { + "tools": [ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + ], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Document: { + "tools": [ + entity_tool, + bbox_tool, + bbox_tool_with_nested_text + ], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Audio: { + "tools":[], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Html: { + "tools": [], + "classifications": [ + checklist, + free_form_text, + radio, + ] + }, + MediaType.Dicom: { + "tools": [ + raster_segmentation_tool, + polyline_tool + ], + "classifications": [] + }, + MediaType.Conversational: { + "tools": [ + entity_tool + ], + "classifications": [ + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index + ] + }, + "all": { + "tools":[ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + entity_tool, + segmentation_tool, + raster_segmentation_tool, + ], + "classifications": [ + checklist, + checklist_index, + free_form_text, + free_form_text_index, + radio, + ] + } + } @pytest.fixture @@ -591,48 +537,55 @@ def func(project): return func +##### Unit test strategies ##### + @pytest.fixture -def configured_project_datarow_id(configured_project): +def hardcoded_datarow_id(): + data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' - def get_data_row_id(indx=0): - return configured_project.data_row_ids[indx] + def get_data_row_id(): + return data_row_id yield get_data_row_id @pytest.fixture -def configured_project_one_datarow_id(configured_project_with_one_data_row): +def hardcoded_global_key(): + global_key = str(uuid.uuid4()) - def get_data_row_id(indx=0): - return configured_project_with_one_data_row.data_row_ids[0] + def get_global_key(): + return global_key - yield get_data_row_id + yield get_global_key -#TODO: Switch to connect_ontology, setup might get removed in later releases -@pytest.fixture -def configured_project(client, initial_dataset, ontology, rand_gen, image_url): - dataset = initial_dataset - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - data_row_ids = [] - - ontologies = ontology["tools"] + ontology["classifications"] +##### Integration test strategies ##### + +def _create_project(client: Client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + """ Shared function to configure project for integration tests """ + + dataset = client.create_dataset(name=rand_gen(str)) + + # LLMPromptResponseCreation is not support for project or ontology creation needs to be conversational + if media_type == MediaType.LLMPromptResponseCreation: + media_type = MediaType.Conversational + + project = 
client.create_project(name=f"{media_type}-{rand_gen(str)}", + media_type=media_type) + + ontology = client.create_ontology(name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], media_type=media_type) + + project.connect_ontology(ontology) + data_row_data = [] - for ind in range(len(ontologies)): - data_row_data.append({ - "row_data": image_url, - "global_key": f"gk_{ontologies[ind]['name']}_{rand_gen(str)}" - }) + + for _ in range(DATA_ROW_COUNT): + data_row_data.append(data_row_json_by_media_type[media_type](rand_gen(str))) + task = dataset.create_data_rows(data_row_data) task.wait_till_done() + global_keys = [row['global_key'] for row in task.result] data_row_ids = [row['id'] for row in task.result] - project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids, - sleep_interval=3) project.create_batch( rand_gen(str), @@ -640,520 +593,580 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url): 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids + project.global_keys = global_keys + + return project, ontology, dataset - yield project - - project.delete() - -#TODO: Switch to connect_ontology, setup might get removed in later releases @pytest.fixture -def project_with_ontology(client, configured_project, ontology, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - - yield project, ontology - - project.delete() +def configured_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): + """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. 
The project will have 10 data rows.""" + + media_type = getattr(request, "param", MediaType.Image) + + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - -#TODO: Switch to connect_ontology, setup might get removed in later releases -@pytest.fixture -def configured_project_pdf(client, ontology, rand_gen, pdf_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Pdf) - dataset = client.create_dataset(name=rand_gen(str)) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - data_row = dataset.create_data_row(pdf_url) - data_row_ids = [data_row.uid] - project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5, # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids yield project + project.delete() dataset.delete() + client.delete_unused_ontology(ontology.uid) -@pytest.fixture -def dataset_pdf_entity(client, rand_gen, document_data_row): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(document_data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -@pytest.fixture -def dataset_conversation_entity(client, rand_gen, conversation_entity_data_row, - wait_for_data_row_processing): +@pytest.fixture() +def configured_project_by_global_key(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): + """Does the same thing as configured project but with global keys focus.""" + dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(conversation_entity_data_row) - data_row = wait_for_data_row_processing(client, data_row) + + media_type = getattr(request, "param", MediaType.Image) + + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids + yield project + + project.delete() dataset.delete() + client.delete_unused_ontology(ontology.uid) -@pytest.fixture -def configured_project_with_one_data_row(client, ontology, rand_gen, - initial_dataset, image_url): - project = client.create_project(name=rand_gen(str), - description=rand_gen(str), - queue_mode=QueueMode.Batch) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - - data_row = initial_dataset.create_data_row(row_data=image_url) - data_row_ids = [data_row.uid] - project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids, - sleep_interval=3) - - batch = project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5, # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids - yield project +@pytest.fixture(scope="module") +def module_project(client: Client, rand_gen, data_row_json_by_media_type, request: FixtureRequest, normalized_ontology_by_media_type): + """Generates a image project that scopes to the test module(file). 
Used to reduce api calls.""" + + media_type = getattr(request, "param", MediaType.Image) + + project, ontology, dataset = _create_project(client, rand_gen, data_row_json_by_media_type, media_type, normalized_ontology_by_media_type) - batch.delete() + yield project + project.delete() - - -# This function allows to convert an ontology feature to actual annotation -# At the moment it expects only one feature per tool type and this creates unnecessary coupling between differet tests -# In an example of a 'rectangle' we have extended to support multiple instances of the same tool type -# TODO: we will support this approach in the future for all tools -# -""" -Please note that this fixture now offers the flexibility to configure three different strategies for generating data row ids for predictions: -Default(configured_project fixture): - configured_project that generates a data row for each member of ontology. - This makes sure each prediction has its own data row id. This is applicable to prediction upload cases when last label overwrites existing ones - -Optimized Strategy (configured_project_with_one_data_row fixture): - This fixture has only one data row and all predictions will be mapped to it - -Custom Data Row IDs Strategy: - Individuals can supply hard-coded data row ids when a creation of data row is not required. - This particular fixture, termed "hardcoded_datarow_id," should be defined locally within a test file. - In the future, we can use this approach to inject correct number of rows instead of using configured_project fixture - that creates a data row for each member of ontology (14 in total) for each run. -""" + dataset.delete() + client.delete_unused_ontology(ontology.uid) @pytest.fixture -def prediction_id_mapping(ontology, request): - # Maps tool types to feature schema ids +def prediction_id_mapping(request, normalized_ontology_by_media_type): + """Creates the base of annotation based on tools inside project ontology. We would want only annotations supported for the MediaType of the ontology and project. Annotations are generated for each data row created later be combined inside the test file. This serves as the base fixture for all the interference (annotations) fixture. This fixtures supports a few strategies: + + Integration test: + configured_project: generates data rows with data row id focus. + configured_project_by_global_key: generates data rows with global key focus. + module_configured_project: configured project but scoped to test module. + + Unit tests + Individuals can supply hard-coded data row ids or global keys without configured a project must include a media type fixture to get the appropriate annotations. + + Each strategy provides a few items. 
+ + Labelbox Project (unit testing strategies do not make api calls so will have None for project) + Data row identifiers (ids the annotation uses) + Ontology: normalized ontology + """ + if "configured_project" in request.fixturenames: - data_row_id_factory = request.getfixturevalue( - "configured_project_datarow_id") project = request.getfixturevalue("configured_project") + data_row_identifiers = [{"id": data_row_id} for data_row_id in project.data_row_ids] + ontology = project.ontology().normalized + + elif "configured_project_by_global_key" in request.fixturenames: + project = request.getfixturevalue("configured_project_by_global_key") + data_row_identifiers = [{"globalKey": global_key} for global_key in project.global_keys] + ontology = project.ontology().normalized + + elif "module_project" in request.fixturenames: + project = request.getfixturevalue("module_project") + data_row_identifiers = [{"id": data_row_id} for data_row_id in project.data_row_ids] + ontology = project.ontology().normalized + elif "hardcoded_datarow_id" in request.fixturenames: - data_row_id_factory = request.getfixturevalue("hardcoded_datarow_id") - project = request.getfixturevalue("configured_project_with_ontology") + if "media_type" not in request.fixturenames: + raise Exception("Please include a 'media_type' fixture") + project = None + media_type = request.getfixturevalue("media_type") + ontology = normalized_ontology_by_media_type[media_type] + data_row_identifiers = [{"id": request.getfixturevalue("hardcoded_datarow_id")()}] + + elif "hardcoded_global_key" in request.fixturenames: + if "media_type" not in request.fixturenames: + raise Exception("Please include a 'media_type' fixture") + project = None + media_type = request.getfixturevalue("media_type") + ontology = normalized_ontology_by_media_type[media_type] + data_row_identifiers = [{"globalKey": request.getfixturevalue("hardcoded_global_key")()}] + + # Used for tests that need access to every ontology else: - data_row_id_factory = request.getfixturevalue( - "configured_project_one_datarow_id") - project = request.getfixturevalue( - "configured_project_with_one_data_row") - - ontology = project.ontology().normalized - - result = {} - - for idx, tool in enumerate(ontology["tools"] + ontology["classifications"]): - if "tool" in tool: - tool_type = tool["tool"] - else: - tool_type = (tool["type"] if "scope" not in tool else - f"{tool['type']}_{tool['scope']}" - ) # so 'checklist' of 'checklist_index' - - # TODO: remove this once we have a better way to associate multiple tools instances with a single tool type - if tool_type == "rectangle": - value = { - "uuid": str(uuid.uuid4()), - "schemaId": tool["featureSchemaId"], - "name": tool["name"], - "dataRow": { - "id": data_row_id_factory(idx), - }, - "tool": tool, - } - if tool_type not in result: - result[tool_type] = [] - result[tool_type].append(value) - else: - result[tool_type] = { + project = None + media_type = None + ontology = normalized_ontology_by_media_type["all"] + data_row_identifiers = [{"id":"ck8q9q9qj00003g5z3q1q9q9q"}] + + base_annotations = [] + for data_row_identifier in data_row_identifiers: + base_annotation = {} + for feature in (ontology["tools"] + ontology["classifications"]): + if "tool" in feature: + feature_type = (feature["tool"] if feature["classifications"] == [] else + f"{feature['tool']}_nested" + ) # tool vs nested classification tool + else: + feature_type = (feature["type"] if "scope" not in feature else + f"{feature['type']}_{feature['scope']}" + ) # checklist vs 
indexed checklist + + base_annotation[feature_type] = { "uuid": str(uuid.uuid4()), - "schemaId": tool["featureSchemaId"], - "name": tool["name"], - "dataRow": { - "id": data_row_id_factory(idx), - }, - "tool": tool, - } - return result + "name": feature["name"], + "tool": feature, + "dataRow": data_row_identifier + } + base_annotations.append(base_annotation) + return base_annotations + +# Each inference represents a feature type that adds to the base annotation created with prediction_id_mapping @pytest.fixture def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping["polygon"].copy() - polygon.update({ - "polygon": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 142.769, - "y": 104.923 - }, - { - "x": 57.846, - "y": 118.769 - }, - { - "x": 28.308, - "y": 169.846 - }, - ] - }) - del polygon["tool"] - return polygon - - -def find_tool_by_name(tool_instances, name): - for tool in tool_instances: - if tool["name"] == name: - return tool - return None + polygons = [] + for feature in prediction_id_mapping: + if "polygon" not in feature: + continue + polygon = feature["polygon"].copy() + polygon.update({ + "polygon": [ + { + "x": 147.692, + "y": 118.154 + }, + { + "x": 142.769, + "y": 104.923 + }, + { + "x": 57.846, + "y": 118.769 + }, + { + "x": 28.308, + "y": 169.846 + }, + ] + }) + del polygon["tool"] + polygons.append(polygon) + return polygons @pytest.fixture def rectangle_inference(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping["rectangle"], - "bbox") - rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - "classifications": [{ - "schemaId": - rectangle["tool"]["classifications"][0]["featureSchemaId"], - "name": - rectangle["tool"]["classifications"][0]["name"], - "answer": { - "schemaId": - rectangle["tool"]["classifications"][0]["options"][0] - ["featureSchemaId"], - "name": - rectangle["tool"]["classifications"][0]["options"][0] - ["value"], - "customMetrics": [{ - "name": "customMetric1", - "value": 0.4 - }], + rectangles = [] + for feature in prediction_id_mapping: + if "rectangle" not in feature: + continue + rectangle = feature["rectangle"].copy() + rectangle.update({ + "bbox": { + "top": 48, + "left": 58, + "height": 65, + "width": 12 }, - }], - }) - del rectangle["tool"] - return rectangle + }) + del rectangle["tool"] + rectangles.append(rectangle) + return rectangles @pytest.fixture def rectangle_inference_with_confidence(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping["rectangle"], - "bbox_tool_with_nested_text") - rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - "classifications": [{ - "schemaId": - rectangle["tool"]["classifications"][0]["featureSchemaId"], - "name": - rectangle["tool"]["classifications"][0]["name"], - "answer": { - "schemaId": - rectangle["tool"]["classifications"][0]["options"][0] - ["featureSchemaId"], + rectangles = [] + for feature in prediction_id_mapping: + if "rectangle_nested" not in feature: + continue + rectangle = feature["rectangle_nested"].copy() + print(rectangle) + rectangle.update({ + "bbox": { + "top": 48, + "left": 58, + "height": 65, + "width": 12 + }, + "classifications": [{ "name": - rectangle["tool"]["classifications"][0]["options"][0] - ["value"], - "classifications": [{ - "schemaId": - rectangle["tool"]["classifications"][0]["options"][0] - ["options"][1]["featureSchemaId"], + 
rectangle["tool"]["classifications"][0]["name"], + "answer": { "name": rectangle["tool"]["classifications"][0]["options"][0] - ["options"][1]["name"], - "answer": - "nested answer", - }], - }, - }], - }) + ["value"], + "classifications": [{ + "name": + rectangle["tool"]["classifications"][0]["options"][0] + ["options"][1]["name"], + "answer": + "nested answer", + }], + }, + }], + }) - rectangle.update({"confidence": 0.9}) - rectangle["classifications"][0]["answer"]["confidence"] = 0.8 - rectangle["classifications"][0]["answer"]["classifications"][0][ - "confidence"] = 0.7 + rectangle.update({"confidence": 0.9}) + rectangle["classifications"][0]["answer"]["confidence"] = 0.8 + rectangle["classifications"][0]["answer"]["classifications"][0][ + "confidence"] = 0.7 - del rectangle["tool"] - return rectangle + del rectangle["tool"] + rectangles.append(rectangle) + return rectangles @pytest.fixture def rectangle_inference_document(rectangle_inference): - rectangle = rectangle_inference.copy() - rectangle.update({"page": 1, "unit": "POINTS"}) - return rectangle + rectangles = [] + for feature in rectangle_inference: + rectangle = feature.copy() + rectangle.update({"page": 1, "unit": "POINTS"}) + rectangles.append(rectangle) + return rectangles @pytest.fixture def line_inference(prediction_id_mapping): - line = prediction_id_mapping["line"].copy() - line.update( - {"line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }]}) - del line["tool"] - return line + lines = [] + for feature in prediction_id_mapping: + if "line" not in feature: + continue + line = feature["line"].copy() + line.update( + {"line": [{ + "x": 147.692, + "y": 118.154 + }, { + "x": 150.692, + "y": 160.154 + }]}) + del line["tool"] + lines.append(line) + return lines @pytest.fixture def line_inference_v2(prediction_id_mapping): - line = prediction_id_mapping["line"].copy() - line_data = { - "groupKey": - "axial", - "segments": [{ - "keyframes": [{ - "frame": - 1, - "line": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 150.692, - "y": 160.154 - }, - ], - }] - },], - } - line.update(line_data) - del line["tool"] - return line + lines = [] + for feature in prediction_id_mapping: + if "line" not in feature: + continue + line = feature["line"].copy() + line_data = { + "groupKey": + "axial", + "segments": [{ + "keyframes": [{ + "frame": + 1, + "line": [ + { + "x": 147.692, + "y": 118.154 + }, + { + "x": 150.692, + "y": 160.154 + }, + ], + }] + },], + } + line.update(line_data) + del line["tool"] + lines.append(line) + return lines @pytest.fixture def point_inference(prediction_id_mapping): - point = prediction_id_mapping["point"].copy() - point.update({"point": {"x": 147.692, "y": 118.154}}) - del point["tool"] - return point + points = [] + for feature in prediction_id_mapping: + if "point" not in feature: + continue + point = feature["point"].copy() + point.update({"point": {"x": 147.692, "y": 118.154}}) + del point["tool"] + points.append(point) + return points @pytest.fixture def entity_inference(prediction_id_mapping): - entity = prediction_id_mapping["named-entity"].copy() - entity.update({"location": {"start": 112, "end": 128}}) - del entity["tool"] - return entity + named_entities = [] + for feature in prediction_id_mapping: + if "named-entity" not in feature: + continue + entity = feature["named-entity"].copy() + entity.update({"location": {"start": 112, "end": 128}}) + del entity["tool"] + named_entities.append(entity) + return named_entities @pytest.fixture def 
entity_inference_index(prediction_id_mapping): - entity = prediction_id_mapping["named-entity"].copy() - entity.update({ - "location": { - "start": 0, - "end": 8 - }, - "messageId": "0", - }) - - del entity["tool"] - return entity + named_entities = [] + for feature in prediction_id_mapping: + if "named-entity" not in feature: + continue + entity = feature["named-entity"].copy() + entity.update({ + "location": { + "start": 0, + "end": 8 + }, + "messageId": "0", + }) + del entity["tool"] + named_entities.append(entity) + return named_entities @pytest.fixture def entity_inference_document(prediction_id_mapping): - entity = prediction_id_mapping["named-entity"].copy() - document_selections = { - "textSelections": [{ - "tokenIds": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c", - ], - "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page": 1, - }] - } - entity.update(document_selections) - del entity["tool"] - return entity + named_entities = [] + for feature in prediction_id_mapping: + if "named-entity" not in feature: + continue + entity = feature["named-entity"].copy() + document_selections = { + "textSelections": [{ + "tokenIds": [ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + "page": 1, + }] + } + entity.update(document_selections) + del entity["tool"] + named_entities.append(entity) + return named_entities @pytest.fixture def segmentation_inference(prediction_id_mapping): - segmentation = prediction_id_mapping["superpixel"].copy() - segmentation.update({ - "mask": { - "instanceURI": - "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", - "colorRGB": (255, 255, 255), - } - }) - del segmentation["tool"] - return segmentation + superpixel_masks = [] + for feature in prediction_id_mapping: + if "superpixel" not in feature: + continue + segmentation = feature["superpixel"].copy() + segmentation.update({ + "mask": { + "instanceURI": + "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", + "colorRGB": (255, 255, 255), + } + }) + del segmentation["tool"] + superpixel_masks.append(segmentation) + return superpixel_masks @pytest.fixture def segmentation_inference_rle(prediction_id_mapping): - segmentation = prediction_id_mapping["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "size": [10, 10], - "counts": [1, 0, 10, 100] - }, - }) - del segmentation["tool"] - return segmentation + superpixel_masks = [] + for feature in prediction_id_mapping: + if "superpixel" not in feature: + continue + segmentation = feature["superpixel"].copy() + segmentation.update({ + "uuid": str(uuid.uuid4()), + "mask": { + "size": [10, 10], + "counts": [1, 0, 10, 100] + }, + }) + del segmentation["tool"] + superpixel_masks.append(segmentation) + return superpixel_masks @pytest.fixture def segmentation_inference_png(prediction_id_mapping): - segmentation = 
prediction_id_mapping["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "png": "somedata", - }, - }) - del segmentation["tool"] - return segmentation + superpixel_masks = [] + for feature in prediction_id_mapping: + if "superpixel" not in feature: + continue + segmentation = feature["superpixel"].copy() + segmentation.update({ + "uuid": str(uuid.uuid4()), + "mask": { + "png": "somedata", + }, + }) + del segmentation["tool"] + superpixel_masks.append(segmentation) + return superpixel_masks @pytest.fixture def checklist_inference(prediction_id_mapping): - checklist = prediction_id_mapping["checklist"].copy() - checklist.update({ - "answers": [{ - "schemaId": checklist["tool"]["options"][0]["featureSchemaId"] - }] - }) - del checklist["tool"] - return checklist + checklists = [] + for feature in prediction_id_mapping: + if "checklist" not in feature: + continue + checklist = feature["checklist"].copy() + checklist.update({ + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"} + ] + }) + del checklist["tool"] + checklists.append(checklist) + return checklists @pytest.fixture def checklist_inference_index(prediction_id_mapping): - checklist = prediction_id_mapping["checklist_index"].copy() - checklist.update({ - "answers": [{ - "schemaId": checklist["tool"]["options"][0]["featureSchemaId"] - }], - "messageId": "0", - }) - del checklist["tool"] - return checklist + checklists = [] + for feature in prediction_id_mapping: + if "checklist_index" not in feature: + return None + checklist = feature["checklist_index"].copy() + checklist.update({ + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"} + ], + "messageId": "0", + }) + del checklist["tool"] + checklists.append(checklist) + return checklists @pytest.fixture def text_inference(prediction_id_mapping): - text = prediction_id_mapping["text"].copy() - text.update({"answer": "free form text..."}) - del text["tool"] - return text + texts = [] + for feature in prediction_id_mapping: + if "text" not in feature: + continue + text = feature["text"].copy() + text.update({"answer": "free form text..."}) + del text["tool"] + texts.append(text) + return texts @pytest.fixture def text_inference_with_confidence(text_inference): - text = text_inference.copy() - text.update({"confidence": 0.9}) - return text + texts = [] + for feature in text_inference: + text = feature.copy() + text.update({"confidence": 0.9}) + texts.append(text) + return texts @pytest.fixture def text_inference_index(prediction_id_mapping): - text = prediction_id_mapping["text_index"].copy() - text.update({"answer": "free form text...", "messageId": "0"}) - del text["tool"] - return text + texts = [] + for feature in prediction_id_mapping: + if "text_index" not in feature: + continue + text = feature["text_index"].copy() + text.update({"answer": "free form text...", "messageId": "0"}) + del text["tool"] + texts.append(text) + return texts @pytest.fixture def video_checklist_inference(prediction_id_mapping): - checklist = prediction_id_mapping["checklist"].copy() - checklist.update({ - "answers": [{ - "schemaId": checklist["tool"]["options"][0]["featureSchemaId"] - }] - }) - - checklist.update( - {"frames": [ - { - "start": 7, - "end": 13, - }, - { - "start": 18, - "end": 19, - }, - ]}) - del checklist["tool"] - return checklist + checklists = [] + for feature in prediction_id_mapping: + if "checklist" not in feature: + continue + checklist = feature["checklist"].copy() + 
checklist.update({ + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"} + ] + }) + + checklist.update( + {"frames": [ + { + "start": 7, + "end": 13, + }, + { + "start": 18, + "end": 19, + }, + ]}) + del checklist["tool"] + checklists.append(checklist) + return checklists + + +@pytest.fixture +def annotations_by_media_type( + polygon_inference, + rectangle_inference, + rectangle_inference_document, + line_inference_v2, + line_inference, + entity_inference, + entity_inference_index, + entity_inference_document, + checklist_inference_index, + text_inference_index, + checklist_inference, + text_inference, + video_checklist_inference, +): + return { + MediaType.Audio: [checklist_inference, text_inference], + MediaType.Conversational: [ + checklist_inference_index, + text_inference_index, + entity_inference_index, + ], + MediaType.Dicom: [line_inference_v2], + MediaType.Document: [ + entity_inference_document, + checklist_inference, + text_inference, + rectangle_inference_document, + ], + MediaType.Html: [text_inference, checklist_inference], + MediaType.Image: [ + polygon_inference, + rectangle_inference, + line_inference, + checklist_inference, + text_inference, + ], + MediaType.Text: [checklist_inference, text_inference, entity_inference], + MediaType.Video: [video_checklist_inference], + } @pytest.fixture def model_run_predictions(polygon_inference, rectangle_inference, line_inference): # Not supporting mask since there isn't a signed url representing a seg mask to upload - return [polygon_inference, rectangle_inference, line_inference] + return (polygon_inference + rectangle_inference + line_inference) @pytest.fixture @@ -1164,13 +1177,13 @@ def object_predictions( entity_inference, segmentation_inference, ): - return [ - polygon_inference, - rectangle_inference, - line_inference, - entity_inference, - segmentation_inference, - ] + return ( + polygon_inference + + rectangle_inference + + line_inference + + entity_inference + + segmentation_inference + ) @pytest.fixture @@ -1178,28 +1191,28 @@ def object_predictions_for_annotation_import(polygon_inference, rectangle_inference, line_inference, segmentation_inference): - return [ - polygon_inference, - rectangle_inference, - line_inference, - segmentation_inference, - ] + return ( + polygon_inference + + rectangle_inference + + line_inference + + segmentation_inference + ) + @pytest.fixture def classification_predictions(checklist_inference, text_inference): - return [checklist_inference, text_inference] - + return checklist_inference + text_inference @pytest.fixture def predictions(object_predictions, classification_predictions): return object_predictions + classification_predictions +# Can only have confidence predictions supported by media type of project @pytest.fixture -def predictions_with_confidence(text_inference_with_confidence, - rectangle_inference_with_confidence): - return [text_inference_with_confidence, rectangle_inference_with_confidence] +def predictions_with_confidence(rectangle_inference_with_confidence): + return rectangle_inference_with_confidence @pytest.fixture @@ -1248,7 +1261,6 @@ def model_run_with_data_rows( model_run, wait_for_label_processing, ): - configured_project.enable_model_assisted_labeling() use_data_row_ids = [p["dataRow"]["id"] for p in model_run_predictions] model_run.upsert_data_rows(use_data_row_ids) @@ -1269,7 +1281,6 @@ def model_run_with_data_rows( model_run.upsert_labels(label_ids) yield model_run model_run.delete() - # TODO: Delete resources when that is 
possible .. @pytest.fixture @@ -1277,11 +1288,11 @@ def model_run_with_all_project_labels( client, configured_project, model_run_predictions, - model_run, - wait_for_label_processing, + model_run: ModelRun, + wait_for_label_processing ): - configured_project.enable_model_assisted_labeling() - use_data_row_ids = [p["dataRow"]["id"] for p in model_run_predictions] + use_data_row_ids = list(set([p["dataRow"]["id"] for p in model_run_predictions])) + model_run.upsert_data_rows(use_data_row_ids) upload_task = LabelImport.create_from_objects( @@ -1296,11 +1307,11 @@ def model_run_with_all_project_labels( assert ( len(upload_task.errors) == 0 ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - wait_for_label_processing(configured_project) - model_run.upsert_labels(project_id=configured_project.uid) + labels = wait_for_label_processing(configured_project) + label_ids = [label.uid for label in labels] + model_run.upsert_labels(label_ids) yield model_run model_run.delete() - # TODO: Delete resources when that is possible .. class AnnotationImportTestHelpers: @@ -1379,15 +1390,7 @@ def expected_export_v2_image(): "name": "bbox", "value": "bbox", "annotation_kind": "ImageBoundingBox", - "classifications": [{ - "name": "nested", - "value": "nested", - "radio_answer": { - "name": "radio_option_1", - "value": "radio_value_1", - "classifications": [], - }, - }], + "classifications": [], "bounding_box": { "top": 48.0, "left": 58.0, @@ -1419,9 +1422,14 @@ def expected_export_v2_image(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1449,9 +1457,14 @@ def expected_export_v2_audio(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1485,9 +1498,14 @@ def expected_export_v2_html(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, ], @@ -1517,9 +1535,14 @@ def expected_export_v2_text(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1548,11 +1571,16 @@ def expected_export_v2_video(): "checklist", "value": "checklist", - "checklist_answers": [{ - "name": "option1", - "value": "option1", - "classifications": [] - }], + "checklist_answers": [{ + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] + }], }], } return expected_annotations @@ -1583,9 +1611,14 @@ def expected_export_v2_conversation(): "message_id": "0", "conversational_checklist_answers": [{ - "name": "option1_index", - "value": "option1_index", - "classifications": [], + 
"name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1702,15 +1735,7 @@ def expected_export_v2_document(): "name": "bbox", "value": "bbox", "annotation_kind": "DocumentBoundingBox", - "classifications": [{ - "name": "nested", - "value": "nested", - "radio_answer": { - "name": "radio_option_1", - "value": "radio_value_1", - "classifications": [], - }, - }], + "classifications": [], "page_number": 1, "bounding_box": { "top": 48.0, @@ -1727,9 +1752,14 @@ def expected_export_v2_document(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1756,9 +1786,14 @@ def expected_export_v2_llm_prompt_creation(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1785,38 +1820,14 @@ def expected_export_v2_llm_prompt_response_creation(): "value": "checklist", "checklist_answers": [{ - "name": "option1", - "value": "option1", + "name": "first_checklist_answer", + "value": "first_checklist_answer", "classifications": [] - }], - }, - { - "name": "text", - "value": "text", - "text_answer": { - "content": "free form text..." }, - }, - ], - "relationships": [], - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_llm_response_creation(): - expected_annotations = { - "objects": [], - "classifications": [ - { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "option1", - "value": "option1", - "classifications": [] + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [] }], }, { @@ -1832,66 +1843,37 @@ def expected_export_v2_llm_response_creation(): return expected_annotations -import pytest -from labelbox.data.annotation_types.classification.classification import ( - Checklist, - ClassificationAnnotation, - ClassificationAnswer, - Radio, -) -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle - -from labelbox.data.annotation_types.video import VideoObjectAnnotation - - @pytest.fixture -def bbox_video_annotation_objects(): - bbox_annotation = [ - VideoObjectAnnotation( - name="bbox", - keyframe=True, - frame=13, - segment_index=0, - value=Rectangle( - start=Point(x=146.0, y=98.0), # Top left - end=Point(x=382.0, y=341.0), # Bottom right - ), - classifications=[ - ClassificationAnnotation( - name="nested", - value=Radio(answer=ClassificationAnswer( - name="radio_option_1", - classifications=[ - ClassificationAnnotation( - name="nested_checkbox", - value=Checklist(answer=[ - ClassificationAnswer( - name="nested_checkbox_option_1"), - ClassificationAnswer( - name="nested_checkbox_option_2"), - ]), - ) - ], - )), - ) - ], - ), - VideoObjectAnnotation( - name="bbox", - keyframe=True, - frame=19, - segment_index=0, - value=Rectangle( - start=Point(x=186.0, y=98.0), # Top left - end=Point(x=490.0, y=341.0), # Bottom right - ), - ), - ] - - return bbox_annotation - - 
+def exports_v2_by_media_type( + expected_export_v2_image, + expected_export_v2_audio, + expected_export_v2_html, + expected_export_v2_text, + expected_export_v2_video, + expected_export_v2_conversation, + expected_export_v2_dicom, + expected_export_v2_document, +): + return { + MediaType.Image: + expected_export_v2_image, + MediaType.Audio: + expected_export_v2_audio, + MediaType.Html: + expected_export_v2_html, + MediaType.Text: + expected_export_v2_text, + MediaType.Video: + expected_export_v2_video, + MediaType.Conversational: + expected_export_v2_conversation, + MediaType.Dicom: + expected_export_v2_dicom, + MediaType.Document: + expected_export_v2_document, + } + + class Helpers: @staticmethod @@ -1931,7 +1913,7 @@ def to_pascal_case(name: str) -> str: data_type_string = data_type_class.__name__[:-4].lower() media_type = to_pascal_case(data_type_string) - if media_type == "Conversation": + if media_type == "Conversational": media_type = "Conversational" elif media_type == "Llmpromptcreation": media_type = "LLMPromptCreation" diff --git a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py index e428ed4bf..e85af5f5b 100644 --- a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py +++ b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py @@ -1,6 +1,7 @@ from unittest.mock import patch import uuid from labelbox import parser, Project +from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData import pytest import random from labelbox.data.annotation_types.annotation import ObjectAnnotation @@ -22,19 +23,19 @@ """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised - """ +#TODO: remove library once bulk import requests are removed @pytest.mark.order(1) -def test_create_from_url(project): +def test_create_from_url(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - bulk_import_request = project.upload_annotations(name=name, - annotations=url, - validate=False) + bulk_import_request = module_project.upload_annotations(name=name, + annotations=url, + validate=False) - assert bulk_import_request.project() == project + assert bulk_import_request.project() == module_project assert bulk_import_request.name == name assert bulk_import_request.input_file_url == url assert bulk_import_request.error_file_url is None @@ -42,24 +43,24 @@ def test_create_from_url(project): assert bulk_import_request.state == BulkImportRequestState.RUNNING -def test_validate_file(project_with_empty_ontology): +def test_validate_file(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" with pytest.raises(MALValidationError): - project_with_empty_ontology.upload_annotations(name=name, - annotations=url, - validate=True) + module_project.upload_annotations(name=name, + annotations=url, + validate=True) #Schema ids shouldn't match -def test_create_from_objects(configured_project_with_one_data_row, predictions, +def test_create_from_objects(module_project: Project, predictions, annotation_import_test_helpers): name = str(uuid.uuid4()) - bulk_import_request = configured_project_with_one_data_row.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=predictions) - assert 
bulk_import_request.project() == configured_project_with_one_data_row + assert bulk_import_request.project() == module_project assert bulk_import_request.name == name assert bulk_import_request.error_file_url is None assert bulk_import_request.status_file_url is None @@ -68,15 +69,15 @@ def test_create_from_objects(configured_project_with_one_data_row, predictions, bulk_import_request.input_file_url, predictions) -def test_create_from_label_objects(configured_project, predictions, +def test_create_from_label_objects(module_project, predictions, annotation_import_test_helpers): name = str(uuid.uuid4()) labels = list(NDJsonConverter.deserialize(predictions)) - bulk_import_request = configured_project.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=labels) - assert bulk_import_request.project() == configured_project + assert bulk_import_request.project() == module_project assert bulk_import_request.name == name assert bulk_import_request.error_file_url is None assert bulk_import_request.status_file_url is None @@ -86,7 +87,7 @@ def test_create_from_label_objects(configured_project, predictions, bulk_import_request.input_file_url, normalized_predictions) -def test_create_from_local_file(tmp_path, predictions, configured_project, +def test_create_from_local_file(tmp_path, predictions, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -94,10 +95,10 @@ def test_create_from_local_file(tmp_path, predictions, configured_project, with file_path.open("w") as f: parser.dump(predictions, f) - bulk_import_request = configured_project.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=str(file_path), validate=False) - assert bulk_import_request.project() == configured_project + assert bulk_import_request.project() == module_project assert bulk_import_request.name == name assert bulk_import_request.error_file_url is None assert bulk_import_request.status_file_url is None @@ -106,17 +107,17 @@ def test_create_from_local_file(tmp_path, predictions, configured_project, bulk_import_request.input_file_url, predictions) -def test_get(client, configured_project_with_one_data_row): +def test_get(client, module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - configured_project_with_one_data_row.upload_annotations(name=name, - annotations=url, - validate=False) + module_project.upload_annotations(name=name, + annotations=url, + validate=False) bulk_import_request = BulkImportRequest.from_name( - client, project_id=configured_project_with_one_data_row.uid, name=name) + client, project_id=module_project.uid, name=name) - assert bulk_import_request.project() == configured_project_with_one_data_row + assert bulk_import_request.project() == module_project assert bulk_import_request.name == name assert bulk_import_request.input_file_url == url assert bulk_import_request.error_file_url is None @@ -124,18 +125,18 @@ def test_get(client, configured_project_with_one_data_row): assert bulk_import_request.state == BulkImportRequestState.RUNNING -def test_validate_ndjson(tmp_path, configured_project_with_one_data_row): +def test_validate_ndjson(tmp_path, module_project): file_name = f"broken.ndjson" file_path = tmp_path / file_name with file_path.open("w") as f: f.write("test") with pytest.raises(ValueError): - configured_project_with_one_data_row.upload_annotations( + 
module_project.upload_annotations( name="name", validate=True, annotations=str(file_path)) -def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): +def test_validate_ndjson_uuid(tmp_path, module_project, predictions): file_name = f"repeat_uuid.ndjson" file_path = tmp_path / file_name repeat_uuid = predictions.copy() @@ -147,23 +148,23 @@ def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): parser.dump(repeat_uuid, f) with pytest.raises(UuidError): - configured_project.upload_annotations(name="name", + module_project.upload_annotations(name="name", validate=True, annotations=str(file_path)) with pytest.raises(UuidError): - configured_project.upload_annotations(name="name", - validate=True, - annotations=repeat_uuid) + module_project.upload_annotations(name="name", + validate=True, + annotations=repeat_uuid) -@pytest.mark.slow +@pytest.mark.skip("Slow test and uses a deprecated api endpoint for annotation imports") def test_wait_till_done(rectangle_inference, - configured_project_with_one_data_row): + project): name = str(uuid.uuid4()) - url = configured_project_with_one_data_row.client.upload_data( - content=parser.dumps([rectangle_inference]), sign=True) - bulk_import_request = configured_project_with_one_data_row.upload_annotations( + url = project.client.upload_data( + content=parser.dumps(rectangle_inference), sign=True) + bulk_import_request = project.upload_annotations( name=name, annotations=url, validate=False) assert len(bulk_import_request.inputs) == 1 @@ -180,203 +181,50 @@ def test_wait_till_done(rectangle_inference, 'uuid'] -def test_project_bulk_import_requests(configured_project, predictions): - result = configured_project.bulk_import_requests() +def test_project_bulk_import_requests(module_project, predictions): + result = module_project.bulk_import_requests() assert len(list(result)) == 0 name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=predictions) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=predictions) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( + bulk_import_request = module_project.upload_annotations( name=name, annotations=predictions) bulk_import_request.wait_until_done() - result = configured_project.bulk_import_requests() + result = module_project.bulk_import_requests() assert len(list(result)) == 3 -def test_delete(configured_project, predictions): +def test_delete(module_project, predictions): name = str(uuid.uuid4()) - - bulk_import_request = configured_project.upload_annotations( + + bulk_import_requests = module_project.bulk_import_requests() + [bulk_import_request.delete() for bulk_import_request in bulk_import_requests] + + bulk_import_request = module_project.upload_annotations( name=name, annotations=predictions) bulk_import_request.wait_until_done() - all_import_requests = configured_project.bulk_import_requests() + all_import_requests = module_project.bulk_import_requests() assert len(list(all_import_requests)) == 1 bulk_import_request.delete() - all_import_requests = configured_project.bulk_import_requests() + all_import_requests = module_project.bulk_import_requests() assert len(list(all_import_requests)) == 0 -def test_pdf_mal_bbox(client, 
configured_project_pdf:Project): - """ - tests pdf mal against only a bbox annotation - """ - annotations = [] - num_annotations = 1 - - for data_row_id in configured_project_pdf.data_row_ids: - for _ in range(num_annotations): - annotations.append({ - "uuid": str(uuid.uuid4()), - "name": "bbox", - "dataRow": { - "id": data_row_id - }, - "bbox": { - "top": round(random.uniform(0, 300), 2), - "left": round(random.uniform(0, 300), 2), - "height": round(random.uniform(200, 500), 2), - "width": round(random.uniform(0, 200), 2) - }, - "page": random.randint(0, 1), - "unit": "POINTS" - }) - annotations.extend([ - { #annotations intended to test classifications - 'name': 'text', - 'answer': 'the answer to the text question', - 'uuid': 'fc1913c6-b735-4dea-bd25-c18152a4715f', - "dataRow": { - "id": data_row_id - } - }, - { - 'name': 'checklist', - 'uuid': '9d7b2e57-d68f-4388-867a-af2a9b233719', - "dataRow": { - "id": data_row_id - }, - 'answer': [{ - 'name': 'option1' - }, { - 'name': 'optionN' - }] - }, - { - 'name': 'radio', - 'answer': { - 'name': 'second_radio_answer' - }, - 'uuid': 'ad60897f-ea1a-47de-b923-459339764921', - "dataRow": { - "id": data_row_id - } - }, - { #adding this with the intention to ensure we allow page: 0 - "uuid": str(uuid.uuid4()), - "name": "bbox", - "dataRow": { - "id": data_row_id - }, - "bbox": { - "top": round(random.uniform(0, 300), 2), - "left": round(random.uniform(0, 300), 2), - "height": round(random.uniform(200, 500), 2), - "width": round(random.uniform(0, 200), 2) - }, - "page": 0, - "unit": "POINTS" - } - ]) - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_pdf.uid, - name=f"import {str(uuid.uuid4())}", - predictions=annotations) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -def test_pdf_document_entity(client, configured_project_with_one_data_row, - dataset_pdf_entity, rand_gen): - # for content "Metal-insulator (MI) transitions have been one of the" in OCR JSON extract tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json - document_text_selection = DocumentTextSelection( - group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - token_ids=[ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c" - ], - page=1) - - entities_annotation_document_entity = DocumentEntity( - text_selections=[document_text_selection]) - entities_annotation = ObjectAnnotation( - name="named-entity", value=entities_annotation_document_entity) - - labels = [] - _, data_row_uids = dataset_pdf_entity - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=TextData(uid=data_row_uid), - annotations=[ - entities_annotation, - ])) - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -def test_nested_video_object_annotations(client, - configured_project_with_one_data_row, - video_data, - 
bbox_video_annotation_objects, - rand_gen): - labels = [] - _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects)) - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - def _create_label(row_index, data_row_uids, label_name_ids=['bbox']): label_name = label_name_ids[row_index % len(label_name_ids)] data_row_uid = data_row_uids[row_index % len(data_row_uids)] - return Label(data=VideoData(uid=data_row_uid), + return Label(data=GenericDataRowData(uid=data_row_uid), annotations=[ VideoObjectAnnotation(name=label_name, keyframe=True, @@ -389,12 +237,12 @@ def _create_label(row_index, data_row_uids, label_name_ids=['bbox']): ]) +@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect = True) @patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) def test_below_annotation_limit_on_single_data_row( - client, configured_project_with_one_data_row, video_data, rand_gen): + client, configured_project, video_data, rand_gen): _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( + configured_project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects 5 # priority between 1(Highest) - 5(lowest) @@ -402,7 +250,7 @@ def test_below_annotation_limit_on_single_data_row( labels = [_create_label(index, data_row_uids) for index in range(19)] import_annotations = MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, name=f"import {str(uuid.uuid4())}", predictions=labels) import_annotations.wait_until_done() @@ -410,13 +258,13 @@ def test_below_annotation_limit_on_single_data_row( assert import_annotations.errors == [] +@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect = True) @patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) def test_above_annotation_limit_on_single_label_on_single_data_row( - client, configured_project_with_one_data_row, video_data, rand_gen): + client, configured_project, video_data, rand_gen): _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( + configured_project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects 5 # priority between 1(Highest) - 5(lowest) @@ -425,20 +273,19 @@ def test_above_annotation_limit_on_single_label_on_single_data_row( with pytest.raises(ValueError): import_annotations = MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, name=f"import {str(uuid.uuid4())}", predictions=labels) import_annotations.wait_until_done() - +@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect = True) @patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 
20) def test_above_annotation_limit_divided_among_different_rows( - client, configured_project_with_one_data_row, video_data_100_rows, + client, configured_project, video_data_100_rows, rand_gen): _, data_row_uids = video_data_100_rows - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( + configured_project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects 5 # priority between 1(Highest) - 5(lowest) @@ -447,20 +294,20 @@ def test_above_annotation_limit_divided_among_different_rows( import_annotations = MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, name=f"import {str(uuid.uuid4())}", predictions=labels) assert import_annotations.errors == [] +@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect = True) @patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) def test_above_annotation_limit_divided_among_labels_on_one_row( - client, configured_project_with_one_data_row, video_data, rand_gen): + client, configured_project, video_data, rand_gen): _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( + configured_project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects 5 # priority between 1(Highest) - 5(lowest) @@ -474,7 +321,7 @@ def test_above_annotation_limit_divided_among_labels_on_one_row( import_annotations = MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, name=f"import {str(uuid.uuid4())}", predictions=labels) diff --git a/libs/labelbox/tests/data/annotation_import/test_conversation_import.py b/libs/labelbox/tests/data/annotation_import/test_conversation_import.py deleted file mode 100644 index 4332bfd03..000000000 --- a/libs/labelbox/tests/data/annotation_import/test_conversation_import.py +++ /dev/null @@ -1,45 +0,0 @@ -import uuid -import pytest -from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import ConversationEntity - -from labelbox.schema.annotation_import import MALPredictionImport - -@pytest.mark.order(1) -def test_conversation_entity(client, configured_project_with_one_data_row, - dataset_conversation_entity, rand_gen): - - conversation_entity_annotation = ConversationEntity(start=0, - end=8, - message_id="4") - - entities_annotation = ObjectAnnotation(name="named-entity", - value=conversation_entity_annotation) - - labels = [] - _, data_row_uids = dataset_conversation_entity - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=TextData(uid=data_row_uid), - annotations=[ - entities_annotation, - ])) - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - - import_annotations.wait_until_done() - - assert import_annotations.errors == [] diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py 
b/libs/labelbox/tests/data/annotation_import/test_data_types.py index d607c4a3c..74695470d 100644 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -1,12 +1,8 @@ import datetime -import itertools +from labelbox.schema.label import Label import pytest import uuid -import labelbox as lb -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.schema.media_type import MediaType -import labelbox.types as lb_types from labelbox.data.annotation_types.data import ( AudioData, ConversationData, @@ -15,412 +11,76 @@ HTMLData, ImageData, TextData, - LlmPromptCreationData, - LlmPromptResponseCreationData, - LlmResponseCreationData, ) from labelbox.data.serialization import NDJsonConverter +from labelbox.data.annotation_types.data.video import VideoData + +import labelbox as lb +import labelbox.types as lb_types +from labelbox.schema.media_type import MediaType from labelbox.schema.annotation_import import AnnotationImportState +from labelbox import Project, Client -DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40 -DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 7 +# Unit test for label based on data type. +# TODO: Dicom removed it is unstable when you deserialize and serialize on label import. If we intend to keep this library this needs add generic data types tests work with this data type. +# TODO: add MediaType.LLMPromptResponseCreation(data gen) once supported and llm human preference once media type is added -radio_annotation = lb_types.ClassificationAnnotation( - name="radio", - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="second_radio_answer")), -) -checklist_annotation = lb_types.ClassificationAnnotation( - name="checklist", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer(name="option1"), - lb_types.ClassificationAnswer(name="option2"), - ]), -) -text_annotation = lb_types.ClassificationAnnotation( - name="text", value=lb_types.Text(answer="sample text")) - -video_mask_annotation = lb_types.VideoMaskAnnotation( - frames=[ - lb_types.MaskFrame( - index=10, - instance_uri= - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/mask_example.png", - ) - ], - instances=[ - lb_types.MaskInstance(color_rgb=(255, 255, 255), - name="segmentation_mask") - ], -) -test_params = [ - [ - "html", - lb_types.HTMLData, - [radio_annotation, checklist_annotation, text_annotation], - ], +@pytest.mark.parametrize( + "media_type, data_type_class", [ - "audio", - lb_types.AudioData, - [radio_annotation, checklist_annotation, text_annotation], + (MediaType.Audio, AudioData), + (MediaType.Html, HTMLData), + (MediaType.Image, ImageData), + (MediaType.Text, TextData), + (MediaType.Video, VideoData), + (MediaType.Conversational, ConversationData), + (MediaType.Document, DocumentData), ], - ["video", lb_types.VideoData, [video_mask_annotation]], -] - - -def get_annotation_comparison_dicts_from_labels(labels): - labels_ndjson = list(NDJsonConverter.serialize(labels)) - for annotation in labels_ndjson: - annotation.pop("uuid", None) - annotation.pop("dataRow") - - if "masks" in annotation: - for frame in annotation["masks"]["frames"]: - frame.pop("instanceURI") - frame.pop("imBytes") - for instance in annotation["masks"]["instances"]: - instance.pop("colorRGB") - return labels_ndjson - - -def get_annotation_comparison_dicts_from_export(export_result, data_row_id, - project_id): - exported_data_row = [ - dr for dr in export_result if dr["data_row"]["id"] == 
data_row_id - ][0] - exported_label = exported_data_row["projects"][project_id]["labels"][0] - exported_annotations = exported_label["annotations"] - converted_annotations = [] - if exported_label["label_kind"] == "Video": - frames = [] - instances = [] - for frame_id, frame in exported_annotations["frames"].items(): - frames.append({"index": int(frame_id)}) - for object in frame["objects"].values(): - instances.append({"name": object["name"]}) - converted_annotations.append( - {"masks": { - "frames": frames, - "instances": instances, - }}) - else: - exported_annotations = list( - itertools.chain(*exported_annotations.values())) - for annotation in exported_annotations: - if annotation["name"] == "radio": - converted_annotations.append({ - "name": annotation["name"], - "answer": { - "name": annotation["radio_answer"]["name"] - }, - }) - elif annotation["name"] == "checklist": - converted_annotations.append({ - "name": - annotation["name"], - "answer": [{ - "name": answer["name"] - } for answer in annotation["checklist_answers"]], - }) - elif annotation["name"] == "text": - converted_annotations.append({ - "name": annotation["name"], - "answer": annotation["text_answer"]["content"], - }) - return converted_annotations - - -def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name): - data_row = dataset.create_data_row(data_row_ndjson) - - project.create_batch( - batch_name, - [data_row.uid], # sample of data row objects - 5, # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids.append(data_row.uid) - - return data_row - - -@pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_import_data_types_by_global_key( - client, - configured_project, - initial_dataset, - rand_gen, - data_row_json_by_data_type, - annotations_by_data_type, - helpers, +) +def test_data_row_type_by_data_row_id( + media_type, + data_type_class, + annotations_by_media_type, + hardcoded_datarow_id, ): - project = configured_project - project_id = project.uid - dataset = initial_dataset - data_type_class = ImageData - helpers.set_project_media_type_from_data_type(project, data_type_class) - - data_row_ndjson = data_row_json_by_data_type["image"] - data_row_ndjson["global_key"] = str(uuid.uuid4()) - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - - annotations_ndjson = annotations_by_data_type["image"] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label( - data=data_type_class(global_key=data_row.global_key), - annotations=annotations, - ) for annotations in annotations_list - ] + annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] + + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] + + data_label = lb_types.Label(data=data_type_class(uid = hardcoded_datarow_id()), + annotations=label.annotations) - label_import = lb.LabelImport.create_from_objects(client, project_id, - f"test-import-image", - labels) - label_import.wait_until_done() + assert data_label.data.uid == label.data.uid + assert label.annotations == data_label.annotations - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - exported_labels = project.export_labels(download=True) - objects = exported_labels[0]["Label"]["objects"] - classifications = exported_labels[0]["Label"]["classifications"] - assert len(objects) + 
len(classifications) == len(labels) - data_row.delete() - -def validate_iso_format(date_string: str): - parsed_t = datetime.datetime.fromisoformat( - date_string) # this will blow up if the string is not in iso format - assert parsed_t.hour is not None - assert parsed_t.minute is not None - assert parsed_t.second is not None - -@pytest.mark.order(1) @pytest.mark.parametrize( - "data_type_class", + "media_type, data_type_class", [ - AudioData, - HTMLData, - ImageData, - TextData, - VideoData, - ConversationData, - DocumentData, - DicomData, - LlmResponseCreationData, + (MediaType.Audio, AudioData), + (MediaType.Html, HTMLData), + (MediaType.Image, ImageData), + (MediaType.Text, TextData), + (MediaType.Video, VideoData), + (MediaType.Conversational, ConversationData), + (MediaType.Document, DocumentData), ], ) -def test_import_data_types_v2( - client, - configured_project, - initial_dataset, - data_row_json_by_data_type, - annotations_by_data_type_v2, +def test_data_row_type_by_global_key( + media_type, data_type_class, - exports_v2_by_data_type, - export_v2_test_helpers, - rand_gen, - helpers, + annotations_by_media_type, + hardcoded_global_key, ): - project = configured_project - dataset = initial_dataset - project_id = project.uid - - helpers.set_project_media_type_from_data_type(project, data_type_class) - - data_type_string = data_type_class.__name__[:-4].lower() - data_row_ndjson = data_row_json_by_data_type[data_type_string] - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - annotations_ndjson = annotations_by_data_type_v2[data_type_string] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label(data=data_type_class(uid=data_row.uid), - annotations=annotations) - for annotations in annotations_list - ] - - label_import = lb.LabelImport.create_from_objects( - client, project_id, f"test-import-{data_type_string}", labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - - # TODO need to migrate project to the new BATCH mode and change this code - # to be similar to tests/integration/test_task_queue.py - - result = export_v2_test_helpers.run_project_export_v2_task(project) - find_data_row = lambda dr: dr['data_row']['id'] == data_row.uid - exported_data = list(filter(find_data_row, result))[0] - assert exported_data - - # timestamp fields are in iso format - validate_iso_format(exported_data["data_row"]["details"]["created_at"]) - validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][project_id]["labels"][0] - ["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][project_id]["labels"][0] - ["label_details"]["updated_at"]) - - assert exported_data["data_row"]["id"] == data_row.uid - exported_project = exported_data["projects"][project_id] - exported_project_labels = exported_project["labels"][0] - exported_annotations = exported_project_labels["annotations"] - - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) - helpers.rename_cuid_key_recursive(exported_annotations) - assert exported_annotations == exports_v2_by_data_type[data_type_string] - - data_row = client.get_data_row(data_row.uid) - data_row.delete() - - -@pytest.mark.parametrize("data_type, data_class, annotations", test_params) -def test_import_label_annotations( - client, - 
configured_project_with_one_data_row, - initial_dataset, - data_row_json_by_data_type, - data_type, - data_class, - annotations, - rand_gen, - helpers, -): - project = configured_project_with_one_data_row - dataset = initial_dataset - helpers.set_project_media_type_from_data_type(project, data_class) - - data_row_json = data_row_json_by_data_type[data_type] - data_row = create_data_row_for_project(project, dataset, data_row_json, - rand_gen(str)) - - labels = [ - lb_types.Label(data=data_class(uid=data_row.uid), - annotations=annotations) - ] - - label_import = lb.LabelImport.create_from_objects(client, project.uid, - f"test-import-html", - labels) - label_import.wait_until_done() - - assert label_import.state == lb.AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - export_params = { - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - } - export_task = project.export_v2(params=export_params) - export_task.wait_till_done() - assert export_task.errors is None - expected_annotations = get_annotation_comparison_dicts_from_labels(labels) - actual_annotations = get_annotation_comparison_dicts_from_export( - export_task.result, data_row.uid, - configured_project_with_one_data_row.uid) - assert actual_annotations == expected_annotations - data_row.delete() - - -@pytest.mark.parametrize("data_type, data_class, annotations", test_params) -@pytest.fixture -def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type[data_type] - data_row = dataset.create_data_row(data_row_json) - - yield data_row - - dataset.delete() - - -@pytest.fixture -def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type["video"] - data_row = dataset.create_data_row(data_row_json) - - yield data_row - - dataset.delete() - - -@pytest.mark.parametrize("data_type, data_class, annotations", test_params) -def test_import_mal_annotations( - client, - configured_project_with_one_data_row, - data_type, - data_class, - annotations, - rand_gen, - one_datarow, - helpers, -): - data_row = one_datarow - helpers.set_project_media_type_from_data_type( - configured_project_with_one_data_row, data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) - - labels = [ - lb_types.Label(data=data_class(uid=data_row.uid), - annotations=annotations) - ] - - import_annotations = lb.MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels, - ) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels - - -def test_import_mal_annotations_global_key(client, - configured_project_with_one_data_row, - rand_gen, one_datarow_global_key, - helpers): - data_class = lb_types.VideoData - data_row = one_datarow_global_key - annotations = [video_mask_annotation] - helpers.set_project_media_type_from_data_type( - configured_project_with_one_data_row, data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) - - labels = [ - lb_types.Label(data=data_class(global_key=data_row.global_key), - annotations=annotations) - ] - - import_annotations = 
lb.MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels, - ) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels - \ No newline at end of file + annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] + + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] + + data_label = lb_types.Label(data=data_type_class(global_key = hardcoded_global_key()), + annotations=label.annotations) + + assert data_label.data.global_key == label.data.global_key + assert label.annotations == data_label.annotations \ No newline at end of file diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index 1ea155fa1..ea6b5876b 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -1,315 +1,262 @@ import datetime +from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.data.annotation_types import Label import pytest import uuid import labelbox as lb -from labelbox.data.annotation_types.data.video import VideoData from labelbox.schema.media_type import MediaType -import labelbox.types as lb_types -from labelbox.data.annotation_types.data import ( - AudioData, - ConversationData, - DicomData, - DocumentData, - HTMLData, - ImageData, - TextData, - LlmPromptCreationData, - LlmPromptResponseCreationData, - LlmResponseCreationData, -) -from labelbox.data.serialization import NDJsonConverter from labelbox.schema.annotation_import import AnnotationImportState +from labelbox import Project, Client +import itertools -radio_annotation = lb_types.ClassificationAnnotation( - name="radio", - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="second_radio_answer")), -) -checklist_annotation = lb_types.ClassificationAnnotation( - name="checklist", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer(name="option1"), - lb_types.ClassificationAnswer(name="option2"), - ]), -) -text_annotation = lb_types.ClassificationAnnotation( - name="text", value=lb_types.Text(answer="sample text")) - -video_mask_annotation = lb_types.VideoMaskAnnotation( - frames=[ - lb_types.MaskFrame( - index=10, - instance_uri= - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/mask_example.png", - ) - ], - instances=[ - lb_types.MaskInstance(color_rgb=(255, 255, 255), - name="segmentation_mask") - ], -) +""" + - integration test for importing mal labels and ground truths with each supported MediaType. + - NDJSON is used to generate annotations. 
+""" -test_params = [ - [ - "html", - lb_types.HTMLData, - [radio_annotation, checklist_annotation, text_annotation], - ], +def validate_iso_format(date_string: str): + parsed_t = datetime.datetime.fromisoformat( + date_string) # this will blow up if the string is not in iso format + assert parsed_t.hour is not None + assert parsed_t.minute is not None + assert parsed_t.second is not None + +@pytest.mark.parametrize( + "media_type, data_type_class", [ - "audio", - lb_types.AudioData, - [radio_annotation, checklist_annotation, text_annotation], + (MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), ], - ["video", lb_types.VideoData, [video_mask_annotation]], -] +) +def test_generic_data_row_type_by_data_row_id( + media_type, + data_type_class, + annotations_by_media_type, + hardcoded_datarow_id, +): + annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] + + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] + + data_label = Label(data=data_type_class(uid = hardcoded_datarow_id()), + annotations=label.annotations) + assert data_label.data.uid == label.data.uid + assert label.annotations == data_label.annotations -def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name): - data_row = dataset.create_data_row(data_row_ndjson) - project.create_batch( - batch_name, - [data_row.uid], # sample of data row objects - 5, # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids.append(data_row.uid) +@pytest.mark.parametrize( + "media_type, data_type_class", + [ + (MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), + ], +) +def test_generic_data_row_type_by_global_key( + media_type, + data_type_class, + annotations_by_media_type, + hardcoded_global_key, +): + annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] + + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] + + data_label = Label(data=data_type_class(global_key = hardcoded_global_key()), + annotations=label.annotations) - return data_row + assert data_label.data.global_key == label.data.global_key + assert label.annotations == data_label.annotations -@pytest.mark.order(1) -def test_import_data_types_by_global_key( - client, - configured_project, - initial_dataset, - rand_gen, - data_row_json_by_data_type, - annotations_by_data_type, + +# TODO: add MediaType.LLMPromptResponseCreation(data gen) once supported and llm human preference once media type is added +@pytest.mark.parametrize( + "configured_project", + [ + MediaType.Audio, + MediaType.Html, + MediaType.Image, + MediaType.Text, + MediaType.Video, + MediaType.Conversational, + MediaType.Document, + MediaType.Dicom, + ], + indirect=True +) +def test_import_media_types( + client: Client, + configured_project: Project, + annotations_by_media_type, + exports_v2_by_media_type, export_v2_test_helpers, helpers, ): - project = configured_project - project_id = project.uid - dataset = initial_dataset - 
data_type_class = ImageData - helpers.set_project_media_type_from_data_type(project, data_type_class) - - data_row_ndjson = data_row_json_by_data_type["image"] - data_row_ndjson["global_key"] = str(uuid.uuid4()) - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - - annotations_ndjson = annotations_by_data_type["image"] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label( - data={'global_key': data_row.global_key}, - annotations=annotations, - ) for annotations in annotations_list - ] - - def find_data_row(dr): - return dr['data_row']['id'] == data_row.uid - - label_import = lb.LabelImport.create_from_objects(client, project_id, - f"test-import-image", - labels) + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project.media_type])) + + label_import = lb.LabelImport.create_from_objects( + client, configured_project.uid, f"test-import-{configured_project.media_type}", annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 - result = export_v2_test_helpers.run_project_export_v2_task(project) - exported_data = list(filter(find_data_row, result))[0] - assert exported_data + result = export_v2_test_helpers.run_project_export_v2_task(configured_project) - label = exported_data['projects'][project.uid]['labels'][0] - annotations = label['annotations'] - objects = annotations['objects'] - classifications = annotations['classifications'] - assert len(objects) + len(classifications) == len(labels) + assert result - data_row.delete() + for exported_data in result: + # timestamp fields are in iso format + validate_iso_format(exported_data["data_row"]["details"]["created_at"]) + validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) + validate_iso_format(exported_data["projects"][configured_project.uid]["labels"][0] + ["label_details"]["created_at"]) + validate_iso_format(exported_data["projects"][configured_project.uid]["labels"][0] + ["label_details"]["updated_at"]) + assert exported_data["data_row"]["id"] in configured_project.data_row_ids + exported_project = exported_data["projects"][configured_project.uid] + exported_project_labels = exported_project["labels"][0] + exported_annotations = exported_project_labels["annotations"] -def validate_iso_format(date_string: str): - parsed_t = datetime.datetime.fromisoformat( - date_string) # this will blow up if the string is not in iso format - assert parsed_t.hour is not None - assert parsed_t.minute is not None - assert parsed_t.second is not None + expected_data = exports_v2_by_media_type[configured_project.media_type] + helpers.remove_keys_recursive(exported_annotations, + ["feature_id", "feature_schema_id"]) + helpers.rename_cuid_key_recursive(exported_annotations) + assert exported_annotations == expected_data + +@pytest.mark.order(1) @pytest.mark.parametrize( - "data_type_class", + "configured_project_by_global_key", [ - AudioData, - HTMLData, - ImageData, - TextData, - VideoData, - ConversationData, - DocumentData, - DicomData, - LlmResponseCreationData, + MediaType.Audio, + MediaType.Html, + MediaType.Image, + MediaType.Text, + MediaType.Video, + MediaType.Conversational, + MediaType.Document, + MediaType.Dicom, ], + indirect=True ) -def test_import_data_types_v2( +def test_import_media_types_by_global_key( client, - configured_project, - initial_dataset, - 
data_row_json_by_data_type, - annotations_by_data_type_v2, - data_type_class, - exports_v2_by_data_type, + configured_project_by_global_key, + annotations_by_media_type, + exports_v2_by_media_type, export_v2_test_helpers, - rand_gen, helpers, ): - project = configured_project - dataset = initial_dataset - project_id = project.uid - - helpers.set_project_media_type_from_data_type(project, data_type_class) - - data_type_string = data_type_class.__name__[:-4].lower() - data_row_ndjson = data_row_json_by_data_type[data_type_string] - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - annotations_ndjson = annotations_by_data_type_v2[data_type_string] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label(data={'uid': data_row.uid}, annotations=annotations) - for annotations in annotations_list - ] + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project_by_global_key.media_type])) label_import = lb.LabelImport.create_from_objects( - client, project_id, f"test-import-{data_type_string}", labels) + client, configured_project_by_global_key.uid, f"test-import-{configured_project_by_global_key.media_type}", annotations_ndjson) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 - # TODO need to migrate project to the new BATCH mode and change this code - # to be similar to tests/integration/test_task_queue.py - - result = export_v2_test_helpers.run_project_export_v2_task(project) - - exported_data = next( - dr for dr in result if dr['data_row']['id'] == data_row.uid) - assert exported_data - - # timestamp fields are in iso format - validate_iso_format(exported_data["data_row"]["details"]["created_at"]) - validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][project_id]["labels"][0] - ["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][project_id]["labels"][0] - ["label_details"]["updated_at"]) - - assert exported_data["data_row"]["id"] == data_row.uid - exported_project = exported_data["projects"][project_id] - exported_project_labels = exported_project["labels"][0] - exported_annotations = exported_project_labels["annotations"] - - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) - helpers.rename_cuid_key_recursive(exported_annotations) - assert exported_annotations == exports_v2_by_data_type[data_type_string] - - data_row = client.get_data_row(data_row.uid) - data_row.delete() - + result = export_v2_test_helpers.run_project_export_v2_task(configured_project_by_global_key) -@pytest.mark.parametrize("data_type, data_class, annotations", test_params) -@pytest.fixture -def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type[data_type] - data_row = dataset.create_data_row(data_row_json) + assert result - yield data_row + for exported_data in result: + # timestamp fields are in iso format + validate_iso_format(exported_data["data_row"]["details"]["created_at"]) + validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) + validate_iso_format(exported_data["projects"][configured_project_by_global_key.uid]["labels"][0] + ["label_details"]["created_at"]) + 
validate_iso_format(exported_data["projects"][configured_project_by_global_key.uid]["labels"][0] + ["label_details"]["updated_at"]) - dataset.delete() + assert exported_data["data_row"]["id"] in configured_project_by_global_key.data_row_ids + exported_project = exported_data["projects"][configured_project_by_global_key.uid] + exported_project_labels = exported_project["labels"][0] + exported_annotations = exported_project_labels["annotations"] + expected_data = exports_v2_by_media_type[configured_project_by_global_key.media_type] + helpers.remove_keys_recursive(exported_annotations, + ["feature_id", "feature_schema_id"]) + helpers.rename_cuid_key_recursive(exported_annotations) -@pytest.fixture -def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type["video"] - data_row = dataset.create_data_row(data_row_json) + assert exported_annotations == expected_data - yield data_row - dataset.delete() - - -@pytest.mark.parametrize("data_type, data_class, annotations", test_params) +@pytest.mark.parametrize( + "configured_project", + [ + MediaType.Audio, + MediaType.Html, + MediaType.Image, + MediaType.Text, + MediaType.Video, + MediaType.Conversational, + MediaType.Document, + MediaType.Dicom, + ], + indirect=True +) def test_import_mal_annotations( client, - configured_project_with_one_data_row, - data_class, - annotations, - rand_gen, - one_datarow, - helpers, + configured_project: Project, + annotations_by_media_type, ): - data_row = one_datarow - helpers.set_project_media_type_from_data_type( - configured_project_with_one_data_row, data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) - - labels = [ - lb_types.Label(data={'uid': data_row.uid}, annotations=annotations) - ] + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project.media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, name=f"import {str(uuid.uuid4())}", - predictions=labels, + predictions=annotations_ndjson, ) import_annotations.wait_until_done() assert import_annotations.errors == [] # MAL Labels cannot be exported and compared to input labels + - +@pytest.mark.parametrize( + "configured_project_by_global_key", + [ + MediaType.Audio, + MediaType.Html, + MediaType.Image, + MediaType.Text, + MediaType.Video, + MediaType.Conversational, + MediaType.Document, + MediaType.Dicom, + ], + indirect=True +) def test_import_mal_annotations_global_key(client, - configured_project_with_one_data_row, - rand_gen, one_datarow_global_key, - helpers): - data_class = lb_types.VideoData - data_row = one_datarow_global_key - annotations = [video_mask_annotation] - helpers.set_project_media_type_from_data_type( - configured_project_with_one_data_row, data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) + configured_project_by_global_key: Project, + annotations_by_media_type): - labels = [ - lb_types.Label(data={'global_key': data_row.global_key}, - annotations=annotations) - ] + annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[configured_project_by_global_key.media_type])) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, - project_id=configured_project_with_one_data_row.uid, + 
project_id=configured_project_by_global_key.uid, name=f"import {str(uuid.uuid4())}", - predictions=labels, + predictions=annotations_ndjson, ) import_annotations.wait_until_done() diff --git a/libs/labelbox/tests/data/annotation_import/test_label_import.py b/libs/labelbox/tests/data/annotation_import/test_label_import.py index b0d50ac5d..50b701813 100644 --- a/libs/labelbox/tests/data/annotation_import/test_label_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_label_import.py @@ -10,68 +10,68 @@ """ -def test_create_with_url_arg(client, configured_project_with_one_data_row, +def test_create_with_url_arg(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create( client=client, - id=configured_project_with_one_data_row.uid, + id=module_project.uid, name=name, url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_from_url(client, configured_project_with_one_data_row, +def test_create_from_url(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=module_project.uid, name=name, url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, configured_project, object_predictions, +def test_create_with_labels_arg(client, module_project, object_predictions, annotation_import_test_helpers): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) label_import = LabelImport.create(client=client, - id=configured_project.uid, + id=module_project.uid, name=name, labels=object_predictions) - assert label_import.parent_id == configured_project.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( label_import.input_file_url, object_predictions) -def test_create_from_objects(client, configured_project, object_predictions, +def test_create_from_objects(client, module_project, object_predictions, annotation_import_test_helpers): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) label_import = LabelImport.create_from_objects( client=client, - project_id=configured_project.uid, + project_id=module_project.uid, name=name, labels=object_predictions) - assert label_import.parent_id == configured_project.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( label_import.input_file_url, object_predictions) -def test_create_with_path_arg(client, tmp_path, configured_project, object_predictions, +def test_create_with_path_arg(client, tmp_path, module_project, object_predictions, annotation_import_test_helpers): - project = configured_project + project = module_project name = str(uuid.uuid4()) 
file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -89,9 +89,9 @@ def test_create_with_path_arg(client, tmp_path, configured_project, object_predi label_import.input_file_url, object_predictions) -def test_create_from_local_file(client, tmp_path, configured_project, object_predictions, +def test_create_from_local_file(client, tmp_path, module_project, object_predictions, annotation_import_test_helpers): - project = configured_project + project = module_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -109,26 +109,26 @@ def test_create_from_local_file(client, tmp_path, configured_project, object_pre label_import.input_file_url, object_predictions) -def test_get(client, configured_project_with_one_data_row, +def test_get(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( client=client, - project_id=configured_project_with_one_data_row.uid, + project_id=module_project.uid, name=name, url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) @pytest.mark.slow -def test_wait_till_done(client, configured_project, predictions): +def test_wait_till_done(client, module_project, predictions): name = str(uuid.uuid4()) label_import = LabelImport.create_from_objects( client=client, - project_id=configured_project.uid, + project_id=module_project.uid, name=name, labels=predictions) diff --git a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py index c42d4b8ce..c50c82315 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py @@ -1,5 +1,4 @@ import uuid -import pytest from labelbox import parser from labelbox.schema.annotation_import import MALPredictionImport @@ -9,31 +8,31 @@ """ -@pytest.mark.order(1) -def test_create_with_url_arg(client, configured_project_with_one_data_row, + +def test_create_with_url_arg(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = MALPredictionImport.create( client=client, - id=configured_project_with_one_data_row.uid, + id=module_project.uid, name=name, url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, configured_project, object_predictions, +def test_create_with_labels_arg(client, module_project, object_predictions, annotation_import_test_helpers): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) label_import = MALPredictionImport.create(client=client, - id=configured_project.uid, + id=module_project.uid, name=name, labels=object_predictions) - assert label_import.parent_id == configured_project.uid + assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( label_import.input_file_url, 
object_predictions) diff --git a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py index 771744cf8..f2765fd3f 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py @@ -39,10 +39,10 @@ def test_create_from_objects_global_key(client, model_run_with_data_rows, polygon_inference, annotation_import_test_helpers): name = str(uuid.uuid4()) - dr = client.get_data_row(polygon_inference['dataRow']['id']) - del polygon_inference['dataRow']['id'] - polygon_inference['dataRow']['globalKey'] = dr.global_key - object_predictions = [polygon_inference] + dr = client.get_data_row(polygon_inference[0]['dataRow']['id']) + polygon_inference[0]['dataRow']['globalKey'] = dr.global_key + del polygon_inference[0]['dataRow']['id'] + object_predictions = [polygon_inference[0]] annotation_import = model_run_with_data_rows.add_predictions( name=name, predictions=object_predictions) @@ -63,13 +63,13 @@ def test_create_from_objects_with_confidence(predictions_with_confidence, annotation_import_test_helpers): name = str(uuid.uuid4()) - object_prediction_data_rows = [ + object_prediction_data_rows = set([ object_prediction["dataRow"]["id"] for object_prediction in predictions_with_confidence - ] + ]) # MUST have all data rows in the model run model_run_with_data_rows.upsert_data_rows( - data_row_ids=object_prediction_data_rows) + data_row_ids=list(object_prediction_data_rows)) annotation_import = model_run_with_data_rows.add_predictions( name=name, predictions=predictions_with_confidence) @@ -110,8 +110,8 @@ def test_create_from_objects_all_project_labels( def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, model_run_predictions): + model_run = model_run_with_all_project_labels - export_task = model_run.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() @@ -122,24 +122,25 @@ def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, data_row.json["experiments"][model_run.model_id]["runs"][model_run.uid]["labels"][0]) for data_row in stream] - labels_indexed_by_schema_id = {} + labels_indexed_by_name = {} + # making sure the labels in this model run are all labels uploaded to the project + # by comparing some 'immutable' attributes + # multiple data rows per prediction import for data_row_id, label in model_run_exported_labels: - # assuming exported array of label 'objects' has only one label per data row... 
as usually is when there are no label revisions - schema_id = label["annotations"]["objects"][0]["feature_schema_id"] - labels_indexed_by_schema_id[schema_id] = {"label": label, "data_row_id": data_row_id} - + for object in label["annotations"]["objects"]: + name = object["name"] + labels_indexed_by_name[f"{name}-{data_row_id}"] = {"label": label, "data_row_id": data_row_id} + assert (len( - labels_indexed_by_schema_id.keys())) == len(model_run_predictions) + labels_indexed_by_name.keys())) == len([prediction["dataRow"]["id"] for prediction in model_run_predictions]) + + expected_data_row_ids = set([prediction["dataRow"]["id"] for prediction in model_run_predictions]) + expected_objects = set([prediction["name"] for prediction in model_run_predictions]) + for data_row_id, actual_label in model_run_exported_labels: + assert data_row_id in expected_data_row_ids + assert len(expected_objects) == len(actual_label["annotations"]["objects"]) - # making sure the labels are in this model run are all labels uploaded to the project - # by comparing some 'immutable' attributes - for expected_label in model_run_predictions: - schema_id = expected_label["schemaId"] - actual_label = labels_indexed_by_schema_id[schema_id] - assert actual_label["label"]["annotations"]["objects"][0]["name"] == expected_label[ - 'name'] - assert actual_label["data_row_id"] == expected_label["dataRow"]["id"] def test_create_from_label_objects(model_run_with_data_rows, diff --git a/libs/labelbox/tests/data/annotation_import/test_model.py b/libs/labelbox/tests/data/annotation_import/test_model.py index 131ecd9d0..dcfe9ef2c 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model.py +++ b/libs/labelbox/tests/data/annotation_import/test_model.py @@ -4,14 +4,14 @@ from labelbox.exceptions import ResourceNotFoundError -def test_model(client, configured_project_with_one_data_row, rand_gen): +def test_model(client, configured_project, rand_gen): # Get all models = list(client.get_models()) for m in models: assert isinstance(m, Model) # Create - ontology = configured_project_with_one_data_row.ontology() + ontology = configured_project.ontology() data = {"name": rand_gen(str), "ontology_id": ontology.uid} model = client.create_model(data["name"], data["ontology_id"]) assert model.name == data["name"] diff --git a/libs/labelbox/tests/data/annotation_import/test_model_run.py b/libs/labelbox/tests/data/annotation_import/test_model_run.py index 6a7b0e0d5..bf30ed169 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model_run.py +++ b/libs/labelbox/tests/data/annotation_import/test_model_run.py @@ -88,11 +88,11 @@ def test_model_run_data_rows_delete(model_run_with_data_rows): def test_model_run_upsert_data_rows(dataset, model_run, - configured_project_with_one_data_row): + configured_project): n_model_run_data_rows = len(list(model_run.model_run_data_rows())) assert n_model_run_data_rows == 0 data_row = dataset.create_data_row(row_data="test row data") - configured_project_with_one_data_row._wait_until_data_rows_are_processed( + configured_project._wait_until_data_rows_are_processed( data_row_ids=[data_row.uid]) model_run.upsert_data_rows([data_row.uid]) n_model_run_data_rows = len(list(model_run.model_run_data_rows())) @@ -119,11 +119,6 @@ def test_model_run_upsert_data_rows_with_existing_labels( assert n_data_rows == len( list(model_run_with_data_rows.model_run_data_rows())) -@pytest.mark.export_v1("tests used export v1 method, v2 test -> test_import_data_types_v2 below") -def 
test_model_run_export_labels(model_run_with_data_rows): - labels = model_run_with_data_rows.export_labels(download=True) - assert len(labels) == 3 - @pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", reason="does not work for onprem") diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index fba161ef3..ac197a321 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -1,4 +1,5 @@ from labelbox.schema.media_type import MediaType +from labelbox.schema.project import Project import pytest from labelbox import parser @@ -10,50 +11,17 @@ NDRadio, NDRectangle, NDText, NDTextEntity, NDTool, _validate_ndjson) -from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.queue_mode import QueueMode - - -@pytest.fixture(scope="module", autouse=True) -def hardcoded_datarow_id(): - data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' - - def get_data_row_id(indx=0): - return data_row_id - - yield get_data_row_id - - -@pytest.fixture(scope="module", autouse=True) -def configured_project_with_ontology(client, ontology, rand_gen): - project = client.create_project( - name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image, - ) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - - yield project - - project.delete() - +""" +- These ND label classes are part of BulkImportRequest and should be removed once bulk import requests are removed +""" def test_classification_construction(checklist_inference, text_inference): - checklist = NDClassification.build(checklist_inference) + checklist = NDClassification.build(checklist_inference[0]) assert isinstance(checklist, NDChecklist) - text = NDClassification.build(text_inference) + text = NDClassification.build(text_inference[0]) assert isinstance(text, NDText) -def test_subclassification_construction(rectangle_inference): - tool = NDTool.build(rectangle_inference) - assert len(tool.classifications) == 1, "Subclass was not constructed" - assert isinstance(tool.classifications[0], NDRadio) - - @parametrize("inference, expected_type", [(fixture_ref('polygon_inference'), NDPolygon), (fixture_ref('rectangle_inference'), NDRectangle), @@ -63,198 +31,176 @@ def test_subclassification_construction(rectangle_inference): (fixture_ref('segmentation_inference_rle'), NDMask), (fixture_ref('segmentation_inference_png'), NDMask)]) def test_tool_construction(inference, expected_type): - assert isinstance(NDTool.build(inference), expected_type) - - -def test_incorrect_feature_schema(rectangle_inference, polygon_inference, - configured_project_with_ontology): - #Valid but incorrect feature schema - #Prob the error message says something about the config not anything useful. We might want to fix this. 
- pred = rectangle_inference.copy() - pred['schemaId'] = polygon_inference['schemaId'] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + assert isinstance(NDTool.build(inference[0]), expected_type) -def no_tool(text_inference, configured_project_with_ontology): - pred = text_inference.copy() +def no_tool(text_inference, module_project): + pred = text_inference[0].copy() #Missing key del pred['answer'] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) - + _validate_ndjson([pred], module_project) -def test_invalid_text(text_inference, configured_project_with_ontology): +@pytest.mark.parametrize( + "configured_project", + [MediaType.Text], + indirect=True +) +def test_invalid_text(text_inference, configured_project): #and if it is not a string - pred = text_inference.copy() + pred = text_inference[0].copy() #Extra and wrong key del pred['answer'] pred['answers'] = [] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], configured_project) del pred['answers'] #Invalid type pred['answer'] = [] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], configured_project) #Invalid type pred['answer'] = None with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], configured_project) def test_invalid_checklist_item(checklist_inference, - configured_project_with_ontology): + module_project): #Only two points - pred = checklist_inference.copy() + pred = checklist_inference[0].copy() pred['answers'] = [pred['answers'][0], pred['answers'][0]] #Duplicate schema ids with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) pred['answers'] = [{"name": "asdfg"}] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) pred['answers'] = [{"schemaId": "1232132132"}] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) pred['answers'] = [{}] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) pred['answers'] = [] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) del pred['answers'] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) -def test_invalid_polygon(polygon_inference, configured_project_with_ontology): +def test_invalid_polygon(polygon_inference, module_project): #Only two points - pred = polygon_inference.copy() + pred = polygon_inference[0].copy() pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) -def test_incorrect_entity(entity_inference, configured_project_with_ontology): - entity = entity_inference.copy() +@pytest.mark.parametrize( + "configured_project", + [MediaType.Text], + indirect=True +) +def test_incorrect_entity(entity_inference, configured_project): + entity = 
entity_inference[0].copy() #Location cannot be a list entity["location"] = [0, 10] with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project_with_ontology) + _validate_ndjson([entity], configured_project) entity["location"] = {"start": -1, "end": 5} with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project_with_ontology) + _validate_ndjson([entity], configured_project) entity["location"] = {"start": 15, "end": 5} with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project_with_ontology) - - -def test_incorrect_mask(segmentation_inference, - configured_project_with_ontology): - seg = segmentation_inference.copy() - seg['mask']['colorRGB'] = [-1, 0, 10] - with pytest.raises(MALValidationError): - _validate_ndjson([seg], configured_project_with_ontology) - - seg['mask']['colorRGB'] = [0, 0] - with pytest.raises(MALValidationError): - _validate_ndjson([seg], configured_project_with_ontology) - - seg['mask'] = {'counts': [0], 'size': [0, 1]} - with pytest.raises(MALValidationError): - _validate_ndjson([seg], configured_project_with_ontology) - - seg['mask'] = {'counts': [-1], 'size': [1, 1]} - with pytest.raises(MALValidationError): - _validate_ndjson([seg], configured_project_with_ontology) + _validate_ndjson([entity], configured_project) -def test_all_validate_json(configured_project_with_ontology, predictions): +@pytest.mark.skip("Test won't work / fails randomly since projects must have a media type and may be missing features from the prediction list") +def test_all_validate_json(module_project, predictions): #Predictions contains one of each type of prediction. #These should be properly formatted and pass. - _validate_ndjson(predictions, configured_project_with_ontology) + _validate_ndjson(predictions[0], module_project) -def test_incorrect_line(line_inference, configured_project_with_ontology): - line = line_inference.copy() +def test_incorrect_line(line_inference, module_project): + line = line_inference[0].copy() line["line"] = [line["line"][0]] #Just one point with pytest.raises(MALValidationError): - _validate_ndjson([line], configured_project_with_ontology) + _validate_ndjson([line], module_project) def test_incorrect_rectangle(rectangle_inference, - configured_project_with_ontology): - del rectangle_inference['bbox']['top'] + module_project): + del rectangle_inference[0]['bbox']['top'] with pytest.raises(MALValidationError): _validate_ndjson([rectangle_inference], - configured_project_with_ontology) + module_project) -def test_duplicate_tools(rectangle_inference, configured_project_with_ontology): - #Trying to upload a polygon and rectangle at the same time - pred = rectangle_inference.copy() +def test_duplicate_tools(rectangle_inference, module_project): + pred = rectangle_inference[0].copy() pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) -def test_invalid_feature_schema(configured_project_with_ontology, +def test_invalid_feature_schema(module_project, rectangle_inference): - #Trying to upload a polygon and rectangle at the same time - pred = rectangle_inference.copy() + pred = rectangle_inference[0].copy() pred['schemaId'] = "blahblah" with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) -def 
test_name_only_feature_schema(configured_project_with_ontology, +def test_name_only_feature_schema(module_project, rectangle_inference): - #Trying to upload a polygon and rectangle at the same time - pred = rectangle_inference.copy() - del pred['schemaId'] - _validate_ndjson([pred], configured_project_with_ontology) + pred = rectangle_inference[0].copy() + _validate_ndjson([pred], module_project) -def test_schema_id_only_feature_schema(configured_project_with_ontology, +def test_schema_id_only_feature_schema(module_project, rectangle_inference): - #Trying to upload a polygon and rectangle at the same time - pred = rectangle_inference.copy() + pred = rectangle_inference[0].copy() del pred['name'] - _validate_ndjson([pred], configured_project_with_ontology) + ontology = module_project.ontology().normalized["tools"] + for tool in ontology: + if tool["name"] == "bbox": + feature_schema_id = tool["featureSchemaId"] + pred["schemaId"] = feature_schema_id + _validate_ndjson([pred], module_project) -def test_missing_feature_schema(configured_project_with_ontology, +def test_missing_feature_schema(module_project, rectangle_inference): - #Trying to upload a polygon and rectangle at the same time - pred = rectangle_inference.copy() - del pred['schemaId'] + pred = rectangle_inference[0].copy() del pred['name'] with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project_with_ontology) + _validate_ndjson([pred], module_project) -def test_validate_ndjson(tmp_path, configured_project_with_ontology): +def test_validate_ndjson(tmp_path, configured_project): file_name = f"broken.ndjson" file_path = tmp_path / file_name with file_path.open("w") as f: f.write("test") with pytest.raises(ValueError): - configured_project_with_ontology.upload_annotations( + configured_project.upload_annotations( name="name", annotations=str(file_path), validate=True) -def test_validate_ndjson_uuid(tmp_path, configured_project_with_ontology, +def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): file_name = f"repeat_uuid.ndjson" file_path = tmp_path / file_name @@ -266,15 +212,16 @@ def test_validate_ndjson_uuid(tmp_path, configured_project_with_ontology, parser.dump(repeat_uuid, f) with pytest.raises(MALValidationError): - configured_project_with_ontology.upload_annotations( + configured_project.upload_annotations( name="name", validate=True, annotations=str(file_path)) with pytest.raises(MALValidationError): - configured_project_with_ontology.upload_annotations( + configured_project.upload_annotations( name="name", validate=True, annotations=repeat_uuid) +@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect=True) def test_video_upload(video_checklist_inference, - configured_project_with_ontology): - pred = video_checklist_inference.copy() - _validate_ndjson([pred], configured_project_with_ontology) + configured_project): + pred = video_checklist_inference[0].copy() + _validate_ndjson([pred], configured_project) diff --git a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py index b96eb31b6..1f8b84742 100644 --- a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py +++ b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py @@ -8,7 +8,7 @@ def test_send_to_annotate_from_model(client, configured_project, model_run_predictions, model_run_with_data_rows, project): model_run = model_run_with_data_rows - data_row_ids = 
[p['dataRow']['id'] for p in model_run_predictions] + data_row_ids = list(set([p['dataRow']['id'] for p in model_run_predictions])) assert len(data_row_ids) > 0 destination_project = project diff --git a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py index 55f227315..59c894c65 100644 --- a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py @@ -11,7 +11,7 @@ @pytest.mark.skip() def test_create_from_url(client, tmp_path, object_predictions, model_run_with_data_rows, - configured_project_with_one_data_row, + configured_project, annotation_import_test_helpers): name = str(uuid.uuid4()) file_name = f"{name}.json" @@ -39,7 +39,7 @@ def test_create_from_url(client, tmp_path, object_predictions, annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( name=name, predictions=url, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, priority=5) assert annotation_import.model_run_id == model_run_with_data_rows.uid @@ -48,7 +48,7 @@ def test_create_from_url(client, tmp_path, object_predictions, assert annotation_import.statuses assert batch - assert batch.project().uid == configured_project_with_one_data_row.uid + assert batch.project().uid == configured_project.uid assert mal_prediction_import mal_prediction_import.wait_until_done() @@ -59,7 +59,7 @@ def test_create_from_url(client, tmp_path, object_predictions, @pytest.mark.skip() def test_create_from_objects(model_run_with_data_rows, - configured_project_with_one_data_row, + configured_project, object_predictions, annotation_import_test_helpers): name = str(uuid.uuid4()) @@ -74,7 +74,7 @@ def test_create_from_objects(model_run_with_data_rows, annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( name=name, predictions=predictions, - project_id=configured_project_with_one_data_row.uid, + project_id=configured_project.uid, priority=5) assert annotation_import.model_run_id == model_run_with_data_rows.uid @@ -83,7 +83,7 @@ def test_create_from_objects(model_run_with_data_rows, assert annotation_import.statuses assert batch - assert batch.project().uid == configured_project_with_one_data_row.uid + assert batch.project().uid == configured_project.uid assert mal_prediction_import mal_prediction_import.wait_until_done()
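For reference, a minimal sketch of the indirect-parametrization pattern the updated tests rely on, assuming (as in the conftest above) that the `configured_project` fixture builds its project and ontology from the `MediaType` passed via `request.param`, and that inference fixtures such as `text_inference` return one prediction dict per generated data row:

# Illustrative sketch only; fixture names come from the conftest described earlier.
import pytest
from labelbox import MediaType
from labelbox.schema.annotation_import import MALPredictionImport


@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True)
def test_example_text_import(client, configured_project, text_inference):
    # indirect=True routes MediaType.Text into the fixture via request.param,
    # so the project and its ontology are created for text data rows only.
    prediction = text_inference[0]  # one prediction dict per data row
    import_job = MALPredictionImport.create(
        client=client,
        id=configured_project.uid,
        name="example-import",
        labels=[prediction],
    )
    import_job.wait_until_done()

Tests that accept any media type use the module-scoped `module_project` fixture instead; only tests whose predictions require a specific media type (text, video, etc.) parametrize `configured_project` this way.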