From 51ecfeab2efa15402d949b5799e21f77ea26ee95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20J=C3=B3=C5=BAwiak?= Date: Mon, 9 Sep 2024 15:24:35 +0200 Subject: [PATCH] [PTDT-2553] Added integration tests for MMC MAL/GT imports --- .../tests/data/annotation_import/conftest.py | 495 +++++++++++++++++- .../test_generic_data_types.py | 6 + 2 files changed, 500 insertions(+), 1 deletion(-) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 6543f54bf..2342a759a 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -1,4 +1,5 @@ import uuid +from typing import Union from labelbox.schema.model_run import ModelRun from labelbox.schema.ontology import Ontology @@ -152,6 +153,22 @@ def llm_human_preference_data_row(global_key): return llm_human_preference_data_row +@pytest.fixture(scope="module") +def mmc_data_row_url(): + return "https://storage.googleapis.com/labelbox-datasets/conversational_model_evaluation_sample/offline-model-chat-evaluation.json" + + +@pytest.fixture(scope="module", autouse=True) +def offline_model_evaluation_data_row_factory(mmc_data_row_url: str): + def offline_model_evaluation_data_row(global_key: str): + return { + "row_data": mmc_data_row_url, + "global_key": global_key, + } + + return offline_model_evaluation_data_row + + @pytest.fixture(scope="module", autouse=True) def data_row_json_by_media_type( audio_data_row_factory, @@ -163,6 +180,7 @@ def data_row_json_by_media_type( document_data_row_factory, text_data_row_factory, video_data_row_factory, + offline_model_evaluation_data_row_factory, ): return { MediaType.Audio: audio_data_row_factory, @@ -174,6 +192,7 @@ def data_row_json_by_media_type( MediaType.Document: document_data_row_factory, MediaType.Text: text_data_row_factory, MediaType.Video: video_data_row_factory, + OntologyKind.ModelEvaluation: offline_model_evaluation_data_row_factory, } @@ -345,6 +364,26 @@ def normalized_ontology_by_media_type(): ], } + radio_index = { + "required": False, + "instructions": "radio_index", + "name": "radio_index", + "type": "radio", + "scope": "index", + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], + } + prompt_text = { "instructions": "prompt-text", "name": "prompt-text", @@ -403,6 +442,27 @@ def normalized_ontology_by_media_type(): "type": "response-text", } + message_single_selection_task = { + "required": False, + "name": "message-single-selection", + "tool": "message-single-selection", + "classifications": [], + } + + message_multi_selection_task = { + "required": False, + "name": "message-multi-selection", + "tool": "message-multi-selection", + "classifications": [], + } + + message_ranking_task = { + "required": False, + "name": "message-ranking", + "tool": "message-ranking", + "classifications": [], + } + return { MediaType.Image: { "tools": [ @@ -516,6 +576,21 @@ def normalized_ontology_by_media_type(): response_checklist, ], }, + OntologyKind.ModelEvaluation: { + "tools": [ + message_single_selection_task, + message_multi_selection_task, + message_ranking_task, + ], + "classifications": [ + radio, + checklist, + free_form_text, + radio_index, + checklist_index, + free_form_text_index, + ], + }, "all": { "tools": [ bbox_tool, @@ -695,6 +770,45 @@ def _create_prompt_response_project( return prompt_response_project, ontology +def _create_offline_mmc_project( + client: Client, rand_gen, data_row_json, normalized_ontology +) -> Tuple[Project, Ontology, Dataset]: + dataset = client.create_dataset(name=rand_gen(str)) + + project = client.create_offline_model_evaluation_project( + name=f"offline-mmc-{rand_gen(str)}", + ) + + ontology = client.create_ontology( + name=f"offline-mmc-{rand_gen(str)}", + normalized=normalized_ontology, + media_type=MediaType.Conversational, + ontology_kind=OntologyKind.ModelEvaluation, + ) + + project.connect_ontology(ontology) + + data_row_data = [ + data_row_json(rand_gen(str)) for _ in range(DATA_ROW_COUNT) + ] + + task = dataset.create_data_rows(data_row_data) + task.wait_till_done() + global_keys = [row["global_key"] for row in task.result] + data_row_ids = [row["id"] for row in task.result] + + project.create_batch( + rand_gen(str), + data_row_ids, # sample of data row objects + 5, # priority between 1(Highest) - 5(lowest) + ) + project.data_row_ids = data_row_ids + project.data_row_data = data_row_data + project.global_keys = global_keys + + return project, ontology, dataset + + def _create_project( client: Client, rand_gen, @@ -753,7 +867,10 @@ def configured_project( ): """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. The project will have 10 data rows.""" - media_type = getattr(request, "param", MediaType.Image) + media_type: Union[MediaType, OntologyKind] = getattr( + request, "param", MediaType.Image + ) + dataset = None if ( @@ -776,6 +893,13 @@ def configured_project( media_type, normalized_ontology_by_media_type, ) + elif media_type == OntologyKind.ModelEvaluation: + project, ontology, dataset = _create_offline_mmc_project( + client, + rand_gen, + data_row_json_by_media_type[media_type], + normalized_ontology_by_media_type[media_type], + ) else: project, ontology, dataset = _create_project( client, @@ -827,6 +951,13 @@ def configured_project_by_global_key( media_type, normalized_ontology_by_media_type, ) + elif media_type == OntologyKind.ModelEvaluation: + project, ontology, dataset = _create_offline_mmc_project( + client, + rand_gen, + data_row_json_by_media_type[media_type], + normalized_ontology_by_media_type[media_type], + ) else: project, ontology, dataset = _create_project( client, @@ -988,6 +1119,31 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): return base_annotations +@pytest.fixture +def mmc_example_data_row_message_ids(mmc_data_row_url: str): + data_row_content = requests.get(mmc_data_row_url).json() + + human_id = next( + actor_id + for actor_id, actor_metadata in data_row_content["actors"].items() + if actor_metadata["role"] == "human" + ) + + return { + message_id: [ + { + "id": child_msg_id, + "model_config_name": data_row_content["actors"][ + data_row_content["messages"][child_msg_id]["actorId"] + ]["metadata"]["modelConfigName"], + } + for child_msg_id in message_metadata["childMessageIds"] + ] + for message_id, message_metadata in data_row_content["messages"].items() + if message_metadata["actorId"] == human_id + } + + # Each inference represents a feature type that adds to the base annotation created with prediction_id_mapping @pytest.fixture def polygon_inference(prediction_id_mapping): @@ -1303,6 +1459,31 @@ def checklist_inference_index(prediction_id_mapping): return checklists +@pytest.fixture +def checklist_inference_index_mmc( + prediction_id_mapping, mmc_example_data_row_message_ids +): + checklists = [] + for feature in prediction_id_mapping: + if "checklist_index" not in feature: + return None + checklist = feature["checklist_index"].copy() + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ], + "messageId": next( + iter(mmc_example_data_row_message_ids.keys()) + ), + } + ) + del checklist["tool"] + checklists.append(checklist) + return checklists + + @pytest.fixture def prompt_text_inference(prediction_id_mapping): prompt_texts = [] @@ -1333,6 +1514,45 @@ def radio_response_inference(prediction_id_mapping): return response_radios +@pytest.fixture +def radio_inference(prediction_id_mapping): + radios = [] + for feature in prediction_id_mapping: + if "radio" not in feature: + continue + radio = feature["radio"].copy() + radio.update( + { + "answer": {"name": "first_radio_answer"}, + } + ) + del radio["tool"] + radios.append(radio) + return radios + + +@pytest.fixture +def radio_inference_index_mmc( + prediction_id_mapping, mmc_example_data_row_message_ids +): + radios = [] + for feature in prediction_id_mapping: + if "radio_index" not in feature: + continue + radio = feature["radio_index"].copy() + radio.update( + { + "answer": {"name": "first_radio_answer"}, + "messageId": next( + iter(mmc_example_data_row_message_ids.keys()) + ), + } + ) + del radio["tool"] + radios.append(radio) + return radios + + @pytest.fixture def checklist_response_inference(prediction_id_mapping): response_checklists = [] @@ -1402,6 +1622,28 @@ def text_inference_index(prediction_id_mapping): return texts +@pytest.fixture +def text_inference_index_mmc( + prediction_id_mapping, mmc_example_data_row_message_ids +): + texts = [] + for feature in prediction_id_mapping: + if "text_index" not in feature: + continue + text = feature["text_index"].copy() + text.update( + { + "answer": "free form text...", + "messageId": next( + iter(mmc_example_data_row_message_ids.keys()) + ), + } + ) + del text["tool"] + texts.append(text) + return texts + + @pytest.fixture def video_checklist_inference(prediction_id_mapping): checklists = [] @@ -1437,6 +1679,118 @@ def video_checklist_inference(prediction_id_mapping): return checklists +@pytest.fixture +def message_single_selection_inference( + prediction_id_mapping, mmc_example_data_row_message_ids +): + some_parent_id, some_child_ids = next( + iter(mmc_example_data_row_message_ids.items()) + ) + + res = [] + for feature in prediction_id_mapping: + if "message-single-selection" not in feature: + continue + selection = feature["message-single-selection"].copy() + selection.update( + { + "messageEvaluationTask": { + "format": "message-single-selection", + "data": { + "messageId": some_child_ids[0]["id"], + "parentMessageId": some_parent_id, + "modelConfigName": some_child_ids[0][ + "model_config_name" + ], + }, + } + } + ) + del selection["tool"] + res.append(selection) + + return res + + +@pytest.fixture +def message_multi_selection_inference( + prediction_id_mapping, mmc_example_data_row_message_ids +): + some_parent_id, some_child_ids = next( + iter(mmc_example_data_row_message_ids.items()) + ) + + res = [] + for feature in prediction_id_mapping: + if "message-multi-selection" not in feature: + continue + selection = feature["message-multi-selection"].copy() + selection.update( + { + "messageEvaluationTask": { + "format": "message-multi-selection", + "data": { + "parentMessageId": some_parent_id, + "selectedMessages": [ + { + "messageId": child_id["id"], + "modelConfigName": child_id[ + "model_config_name" + ], + } + for child_id in some_child_ids + ], + }, + } + } + ) + del selection["tool"] + res.append(selection) + + return res + + +@pytest.fixture +def message_ranking_inference( + prediction_id_mapping, mmc_example_data_row_message_ids +): + some_parent_id, some_child_ids = next( + iter(mmc_example_data_row_message_ids.items()) + ) + + res = [] + for feature in prediction_id_mapping: + if "message-ranking" not in feature: + continue + selection = feature["message-ranking"].copy() + selection.update( + { + "messageEvaluationTask": { + "format": "message-ranking", + "data": { + "parentMessageId": some_parent_id, + "rankedMessages": [ + { + "messageId": child_id["id"], + "modelConfigName": child_id[ + "model_config_name" + ], + "order": idx, + } + for idx, child_id in enumerate( + some_child_ids, start=1 + ) + ], + }, + } + } + ) + del selection["tool"] + res.append(selection) + + return res + + @pytest.fixture def annotations_by_media_type( polygon_inference, @@ -1456,6 +1810,13 @@ def annotations_by_media_type( checklist_response_inference, radio_response_inference, text_response_inference, + message_single_selection_inference, + message_multi_selection_inference, + message_ranking_inference, + checklist_inference_index_mmc, + radio_inference, + radio_inference_index_mmc, + text_inference_index_mmc, ): return { MediaType.Audio: [checklist_inference, text_inference], @@ -1493,6 +1854,17 @@ def annotations_by_media_type( checklist_response_inference, radio_response_inference, ], + OntologyKind.ModelEvaluation: [ + message_single_selection_inference, + message_multi_selection_inference, + message_ranking_inference, + radio_inference, + checklist_inference, + text_inference, + radio_inference_index_mmc, + checklist_inference_index_mmc, + text_inference_index_mmc, + ], } @@ -2162,6 +2534,125 @@ def expected_export_v2_llm_response_creation(): return expected_annotations +@pytest.fixture +def expected_exports_v2_mmc(mmc_example_data_row_message_ids): + some_parent_id, some_child_ids = next( + iter(mmc_example_data_row_message_ids.items()) + ) + + return { + "objects": [ + { + "name": "message-single-selection", + "annotation_kind": "MessageSingleSelection", + "classifications": [], + "selected_message": { + "message_id": some_child_ids[0]["id"], + "model_config_name": some_child_ids[0]["model_config_name"], + "parent_message_id": some_parent_id, + }, + }, + { + "name": "message-multi-selection", + "annotation_kind": "MessageMultiSelection", + "classifications": [], + "selected_messages": { + "messages": [ + { + "message_id": child_id["id"], + "model_config_name": child_id["model_config_name"], + } + for child_id in some_child_ids + ], + "parent_message_id": some_parent_id, + }, + }, + { + "name": "message-ranking", + "annotation_kind": "MessageRanking", + "classifications": [], + "ranked_messages": { + "ranked_messages": [ + { + "message_id": child_id["id"], + "model_config_name": child_id["model_config_name"], + "order": idx, + } + for idx, child_id in enumerate(some_child_ids, start=1) + ], + "parent_message_id": some_parent_id, + }, + }, + ], + "classifications": [ + { + "name": "radio", + "value": "radio", + "radio_answer": { + "name": "first_radio_answer", + "value": "first_radio_answer", + "classifications": [], + }, + }, + { + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], + }, + { + "name": "text", + "value": "text", + "text_answer": {"content": "free form text..."}, + }, + { + "name": "radio_index", + "value": "radio_index", + "message_id": some_parent_id, + "conversational_radio_answer": { + "name": "first_radio_answer", + "value": "first_radio_answer", + "classifications": [], + }, + }, + { + "name": "checklist_index", + "value": "checklist_index", + "message_id": some_parent_id, + "conversational_checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], + }, + { + "name": "text_index", + "value": "text_index", + "message_id": some_parent_id, + "conversational_text_answer": {"content": "free form text..."}, + }, + ], + "relationships": [], + } + + @pytest.fixture def exports_v2_by_media_type( expected_export_v2_image, @@ -2175,6 +2666,7 @@ def exports_v2_by_media_type( expected_export_v2_llm_prompt_response_creation, expected_export_v2_llm_prompt_creation, expected_export_v2_llm_response_creation, + expected_exports_v2_mmc, ): return { MediaType.Image: expected_export_v2_image, @@ -2188,6 +2680,7 @@ def exports_v2_by_media_type( MediaType.LLMPromptResponseCreation: expected_export_v2_llm_prompt_response_creation, MediaType.LLMPromptCreation: expected_export_v2_llm_prompt_creation, OntologyKind.ResponseCreation: expected_export_v2_llm_response_creation, + OntologyKind.ModelEvaluation: expected_exports_v2_mmc, } diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index f8f0c449a..9de67bd4e 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -41,6 +41,7 @@ def validate_iso_format(date_string: str): (MediaType.LLMPromptResponseCreation, GenericDataRowData), (MediaType.LLMPromptCreation, GenericDataRowData), (OntologyKind.ResponseCreation, GenericDataRowData), + (OntologyKind.ModelEvaluation, GenericDataRowData), ], ) def test_generic_data_row_type_by_data_row_id( @@ -76,6 +77,7 @@ def test_generic_data_row_type_by_data_row_id( # (MediaType.LLMPromptResponseCreation, GenericDataRowData), # (MediaType.LLMPromptCreation, GenericDataRowData), (OntologyKind.ResponseCreation, GenericDataRowData), + (OntologyKind.ModelEvaluation, GenericDataRowData), ], ) def test_generic_data_row_type_by_global_key( @@ -115,6 +117,7 @@ def test_generic_data_row_type_by_global_key( ), (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + (OntologyKind.ModelEvaluation, OntologyKind.ModelEvaluation), ], indirect=["configured_project"], ) @@ -191,6 +194,7 @@ def test_import_media_types( (MediaType.Document, MediaType.Document), (MediaType.Dicom, MediaType.Dicom), (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + (OntologyKind.ModelEvaluation, OntologyKind.ModelEvaluation), ], indirect=["configured_project_by_global_key"], ) @@ -275,6 +279,7 @@ def test_import_media_types_by_global_key( ), (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + (OntologyKind.ModelEvaluation, OntologyKind.ModelEvaluation), ], indirect=["configured_project"], ) @@ -309,6 +314,7 @@ def test_import_mal_annotations( (MediaType.Document, MediaType.Document), (MediaType.Dicom, MediaType.Dicom), (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + (OntologyKind.ModelEvaluation, OntologyKind.ModelEvaluation), ], indirect=["configured_project_by_global_key"], )