From a4fe5b936277a36c24ac44e79a1b0d0b97b60d91 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 1 Oct 2024 09:52:00 -0700 Subject: [PATCH 1/3] Refactor get datarows tests so that I can reuse some fixtures --- libs/labelbox/src/labelbox/schema/dataset.py | 38 ++---- libs/labelbox/tests/integration/conftest.py | 56 +++++---- .../tests/integration/test_data_rows.py | 118 +++++++++--------- 3 files changed, 105 insertions(+), 107 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 04877c885..16c993dfa 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -1,57 +1,43 @@ -from datetime import datetime -from typing import Dict, Generator, List, Optional, Any, Final, Tuple, Union -import os import json import logging -from collections.abc import Iterable -from string import Template -import time +import os import warnings - -from labelbox import parser -from itertools import islice - from concurrent.futures import ThreadPoolExecutor, as_completed -from io import StringIO -import requests +from itertools import islice +from string import Template +from typing import Any, Dict, List, Optional, Tuple, Union +import labelbox.schema.internal.data_row_uploader as data_row_uploader from labelbox.exceptions import ( InvalidQueryError, LabelboxError, - ResourceNotFoundError, ResourceCreationError, + ResourceNotFoundError, ) +from labelbox.orm import query from labelbox.orm.comparison import Comparison -from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental +from labelbox.orm.db_object import DbObject, Deletable, Updateable from labelbox.orm.model import Entity, Field, Relationship -from labelbox.orm import query -from labelbox.exceptions import MalformedQueryException from labelbox.pagination import PaginatedCollection from labelbox.schema.data_row import DataRow -from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.export_filters import DatasetExportFilters, build_filters from labelbox.schema.export_params import ( CatalogExportParams, validate_catalog_export_params, ) from labelbox.schema.export_task import ExportTask -from labelbox.schema.identifiable import UniqueId, GlobalKey -from labelbox.schema.task import Task, DataUpsertTask -from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.internal.data_row_upsert_item import ( + DataRowCreateItem, DataRowItemBase, DataRowUpsertItem, - DataRowCreateItem, -) -import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.schema.internal.descriptor_file_creator import ( - DescriptorFileCreator, ) from labelbox.schema.internal.datarow_upload_constants import ( FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES, ) +from labelbox.schema.task import DataUpsertTask, Task logger = logging.getLogger(__name__) @@ -359,7 +345,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow": ) if len(data_rows) > 1: logger.warning( - f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", + "More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", external_id, ) return data_rows[0] diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index c917a6164..622bedc27 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -1,42 +1,35 @@ -from collections import defaultdict -from itertools import islice -import json import os import sys -import re import time -import uuid -import requests -from types import SimpleNamespace -from typing import Type, List -from enum import Enum -from typing import Tuple +from collections import defaultdict +from datetime import datetime, timezone +from itertools import islice +from typing import Type import pytest -import requests +from constants import ( + CAPTURE_DT_SCHEMA_ID, + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, +) -from labelbox import Dataset, DataRow -from labelbox import LabelingFrontend from labelbox import ( - OntologyBuilder, - Tool, - Option, Classification, + Client, + Dataset, + LabelingFrontend, MediaType, + OntologyBuilder, + Option, PromptResponseClassification, ResponseOption, + Tool, ) -from labelbox.orm import query -from labelbox.pagination import PaginatedCollection -from labelbox.schema.annotation_import import LabelImport -from labelbox.schema.catalog import Catalog -from labelbox.schema.enums import AnnotationImportState -from labelbox.schema.invite import Invite -from labelbox.schema.quality_mode import QualityMode +from labelbox.schema.data_row import DataRowMetadataField +from labelbox.schema.ontology_kind import OntologyKind from labelbox.schema.queue_mode import QueueMode from labelbox.schema.user import User -from labelbox import Client -from labelbox.schema.ontology_kind import OntologyKind @pytest.fixture @@ -835,3 +828,16 @@ def print_perf_summary(): for aaa in islice(sorted_dict, num_of_entries) ] print("\nTop slowest fixtures:\n", slowest_fixtures, file=sys.stderr) + + +@pytest.fixture +def make_metadata_fields(): + msg = "A message" + time = datetime.now(timezone.utc) + + fields = [ + DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID), + DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), + DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), + ] + return fields diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 7f69c2995..d2bbec072 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -1,41 +1,39 @@ -from tempfile import NamedTemporaryFile -import uuid -from datetime import datetime import json -import requests import os - +import uuid +from datetime import datetime +from tempfile import NamedTemporaryFile from unittest.mock import patch + import pytest +import requests +from constants import ( + CAPTURE_DT_SCHEMA_ID, + CUSTOM_TEXT_SCHEMA_NAME, + EXPECTED_METADATA_SCHEMA_IDS, + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, +) -from labelbox.schema.media_type import MediaType -from labelbox import DataRow, AssetAttachment +from labelbox import AssetAttachment, DataRow from labelbox.exceptions import ( + InvalidQueryError, MalformedQueryException, ResourceCreationError, - InvalidQueryError, ) -from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.data_row_metadata import ( DataRowMetadataField, DataRowMetadataKind, ) - -SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" -TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" -TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" -CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -EXPECTED_METADATA_SCHEMA_IDS = [ - SPLIT_SCHEMA_ID, - TEST_SPLIT_ID, - TEXT_SCHEMA_ID, - CAPTURE_DT_SCHEMA_ID, -].sort() -CUSTOM_TEXT_SCHEMA_NAME = "custom_text" +from labelbox.schema.media_type import MediaType +from labelbox.schema.task import Task @pytest.fixture -def mdo(client): +def mdo( + client, +): mdo = client.get_data_row_metadata_ontology() try: mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string) @@ -93,18 +91,6 @@ def tile_content(): } -def make_metadata_fields(): - msg = "A message" - time = datetime.utcnow() - - fields = [ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ] - return fields - - def make_metadata_fields_dict(): msg = "A message" time = datetime.utcnow() @@ -375,15 +361,17 @@ def test_create_data_row_with_invalid_input(dataset, image_url): dataset.create_data_row("asdf") -def test_create_data_row_with_metadata(mdo, dataset, image_url): +def test_create_data_row_with_metadata( + mdo, dataset, image_url, make_metadata_fields +): client = dataset.client assert len(list(dataset.data_rows())) == 0 data_row = dataset.create_data_row( - row_data=image_url, metadata_fields=make_metadata_fields() + row_data=image_url, metadata_fields=make_metadata_fields ) - assert len(list(dataset.data_rows())) == 1 + assert len([dr for dr in dataset.data_rows()]) == 1 assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() @@ -398,7 +386,7 @@ def test_create_data_row_with_metadata(mdo, dataset, image_url): assert len(metadata) == 3 assert [ m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() for m in metadata: assert mdo._parse_upsert(m) @@ -426,13 +414,15 @@ def test_create_data_row_with_metadata_dict(mdo, dataset, image_url): assert len(metadata) == 3 assert [ m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() for m in metadata: assert mdo._parse_upsert(m) -def test_create_data_row_with_invalid_metadata(dataset, image_url): - fields = make_metadata_fields() +def test_create_data_row_with_invalid_metadata( + dataset, image_url, make_metadata_fields +): + fields = make_metadata_fields # make the payload invalid by providing the same schema id more than once fields.append( DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") @@ -442,7 +432,9 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url): dataset.create_data_row(row_data=image_url, metadata_fields=fields) -def test_create_data_rows_with_metadata(mdo, dataset, image_url): +def test_create_data_rows_with_metadata( + mdo, dataset, image_url, make_metadata_fields +): client = dataset.client assert len(list(dataset.data_rows())) == 0 @@ -451,12 +443,12 @@ def test_create_data_rows_with_metadata(mdo, dataset, image_url): { DataRow.row_data: image_url, DataRow.external_id: "row1", - DataRow.metadata_fields: make_metadata_fields(), + DataRow.metadata_fields: make_metadata_fields, }, { DataRow.row_data: image_url, DataRow.external_id: "row2", - "metadata_fields": make_metadata_fields(), + "metadata_fields": make_metadata_fields, }, { DataRow.row_data: image_url, @@ -490,7 +482,7 @@ def test_create_data_rows_with_metadata(mdo, dataset, image_url): assert len(metadata) == 3 assert [ m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() for m in metadata: assert mdo._parse_upsert(m) @@ -565,8 +557,10 @@ def create_data_row(data_rows): ) -def test_create_data_rows_with_invalid_metadata(dataset, image_url): - fields = make_metadata_fields() +def test_create_data_rows_with_invalid_metadata( + dataset, image_url, make_metadata_fields +): + fields = make_metadata_fields # make the payload invalid by providing the same schema id more than once fields.append( DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") @@ -585,8 +579,10 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url): ) -def test_create_data_rows_with_metadata_missing_value(dataset, image_url): - fields = make_metadata_fields() +def test_create_data_rows_with_metadata_missing_value( + dataset, image_url, make_metadata_fields +): + fields = make_metadata_fields fields.append({"schemaId": "some schema id"}) with pytest.raises(ValueError) as exc: @@ -601,8 +597,10 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url): ) -def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): - fields = make_metadata_fields() +def test_create_data_rows_with_metadata_missing_schema_id( + dataset, image_url, make_metadata_fields +): + fields = make_metadata_fields fields.append({"value": "some value"}) with pytest.raises(ValueError) as exc: @@ -617,8 +615,10 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): ) -def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): - fields = make_metadata_fields() +def test_create_data_rows_with_metadata_wrong_type( + dataset, image_url, make_metadata_fields +): + fields = make_metadata_fields fields.append("Neither DataRowMetadataField or dict") with pytest.raises(ValueError) as exc: @@ -944,7 +944,11 @@ def test_does_not_update_not_provided_attachment_fields(data_row): assert attachment.attachment_type == "RAW_TEXT" -def test_create_data_rows_result(client, dataset, image_url): +def test_create_data_rows_result( + client, + dataset, + image_url, +): task = dataset.create_data_rows( [ { @@ -963,12 +967,14 @@ def test_create_data_rows_result(client, dataset, image_url): client.get_data_row(result["id"]) -def test_create_data_rows_local_file(dataset, sample_image): +def test_create_data_rows_local_file( + dataset, sample_image, make_metadata_fields +): task = dataset.create_data_rows( [ { DataRow.row_data: sample_image, - DataRow.metadata_fields: make_metadata_fields(), + DataRow.metadata_fields: make_metadata_fields, } ] ) From f1b30504613f6dac913c860eb82aac530a22be2d Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 1 Oct 2024 11:59:49 -0700 Subject: [PATCH 2/3] Add test for mmc data rows --- libs/labelbox/tests/integration/constants.py | 11 ++++ .../tests/integration/test_mmc_data_rows.py | 59 +++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 libs/labelbox/tests/integration/constants.py create mode 100644 libs/labelbox/tests/integration/test_mmc_data_rows.py diff --git a/libs/labelbox/tests/integration/constants.py b/libs/labelbox/tests/integration/constants.py new file mode 100644 index 000000000..d48d31c16 --- /dev/null +++ b/libs/labelbox/tests/integration/constants.py @@ -0,0 +1,11 @@ +SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" +TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" +TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" +CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" +EXPECTED_METADATA_SCHEMA_IDS = [ + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, + CAPTURE_DT_SCHEMA_ID, +] +CUSTOM_TEXT_SCHEMA_NAME = "custom_text" diff --git a/libs/labelbox/tests/integration/test_mmc_data_rows.py b/libs/labelbox/tests/integration/test_mmc_data_rows.py new file mode 100644 index 000000000..2f3886cae --- /dev/null +++ b/libs/labelbox/tests/integration/test_mmc_data_rows.py @@ -0,0 +1,59 @@ +import json +import random + +import pytest +from constants import EXPECTED_METADATA_SCHEMA_IDS + + +@pytest.fixture +def mmc_data_row(dataset, make_metadata_fields, embedding): + row_data = { + "type": "application/vnd.labelbox.conversational.model-chat-evaluation", + "draft": True, + "rootMessageIds": ["root1"], + "actors": {}, + "messages": {}, + } + + vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] + embeddings = [{"embedding_id": embedding.id, "vector": vector}] + + content_all = { + "row_data": row_data, + "attachments": [{"type": "RAW_TEXT", "value": "attachment value"}], + "metadata_fields": make_metadata_fields, + "embeddings": embeddings, + } + task = dataset.create_data_rows([content_all]) + task.wait_till_done() + assert task.status == "COMPLETE" + + data_row = list(dataset.data_rows())[0] + + yield data_row + + data_row.delete() + + +def test_mmc(mmc_data_row, embedding): + data_row = mmc_data_row + assert json.loads(data_row.row_data) == { + "type": "application/vnd.labelbox.conversational.model-chat-evaluation", + "draft": True, + "rootMessageIds": ["root1"], + "actors": {}, + "messages": {}, + } + + metadata_fields = data_row.metadata_fields + metadata = data_row.metadata + assert len(metadata_fields) == 3 + assert len(metadata) == 3 + assert [ + m["schemaId"] for m in metadata_fields + ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() + + attachments = list(data_row.attachments()) + assert len(attachments) == 1 + + assert embedding.get_imported_vector_count() == 1 From c6fec9228beea42f3b7eeab98aabfeddac2f0d6b Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 1 Oct 2024 12:09:49 -0700 Subject: [PATCH 3/3] Dealing with consts in tests --- libs/labelbox/tests/integration/conftest.py | 43 ++++++-- libs/labelbox/tests/integration/constants.py | 11 -- .../tests/integration/test_data_rows.py | 100 +++++++++++------- .../tests/integration/test_mmc_data_rows.py | 9 +- 4 files changed, 96 insertions(+), 67 deletions(-) delete mode 100644 libs/labelbox/tests/integration/constants.py diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index 622bedc27..10b05681e 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -7,12 +7,6 @@ from typing import Type import pytest -from constants import ( - CAPTURE_DT_SCHEMA_ID, - SPLIT_SCHEMA_ID, - TEST_SPLIT_ID, - TEXT_SCHEMA_ID, -) from labelbox import ( Classification, @@ -32,6 +26,30 @@ from labelbox.schema.user import User +@pytest.fixture +def constants(): + SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" + TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" + TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" + CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" + EXPECTED_METADATA_SCHEMA_IDS = [ + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, + CAPTURE_DT_SCHEMA_ID, + ] + CUSTOM_TEXT_SCHEMA_NAME = "custom_text" + + return { + "SPLIT_SCHEMA_ID": SPLIT_SCHEMA_ID, + "TEST_SPLIT_ID": TEST_SPLIT_ID, + "TEXT_SCHEMA_ID": TEXT_SCHEMA_ID, + "CAPTURE_DT_SCHEMA_ID": CAPTURE_DT_SCHEMA_ID, + "EXPECTED_METADATA_SCHEMA_IDS": EXPECTED_METADATA_SCHEMA_IDS, + "CUSTOM_TEXT_SCHEMA_NAME": CUSTOM_TEXT_SCHEMA_NAME, + } + + @pytest.fixture def project_based_user(client, rand_gen): email = rand_gen(str) @@ -831,13 +849,18 @@ def print_perf_summary(): @pytest.fixture -def make_metadata_fields(): +def make_metadata_fields(constants): msg = "A message" time = datetime.now(timezone.utc) fields = [ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), + DataRowMetadataField( + schema_id=constants["SPLIT_SCHEMA_ID"], + value=constants["TEST_SPLIT_ID"], + ), + DataRowMetadataField( + schema_id=constants["CAPTURE_DT_SCHEMA_ID"], value=time + ), + DataRowMetadataField(schema_id=constants["TEXT_SCHEMA_ID"], value=msg), ] return fields diff --git a/libs/labelbox/tests/integration/constants.py b/libs/labelbox/tests/integration/constants.py deleted file mode 100644 index d48d31c16..000000000 --- a/libs/labelbox/tests/integration/constants.py +++ /dev/null @@ -1,11 +0,0 @@ -SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" -TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" -TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" -CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -EXPECTED_METADATA_SCHEMA_IDS = [ - SPLIT_SCHEMA_ID, - TEST_SPLIT_ID, - TEXT_SCHEMA_ID, - CAPTURE_DT_SCHEMA_ID, -] -CUSTOM_TEXT_SCHEMA_NAME = "custom_text" diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index d2bbec072..9f0429269 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -1,20 +1,12 @@ import json import os import uuid -from datetime import datetime +from datetime import datetime, timezone from tempfile import NamedTemporaryFile from unittest.mock import patch import pytest import requests -from constants import ( - CAPTURE_DT_SCHEMA_ID, - CUSTOM_TEXT_SCHEMA_NAME, - EXPECTED_METADATA_SCHEMA_IDS, - SPLIT_SCHEMA_ID, - TEST_SPLIT_ID, - TEXT_SCHEMA_ID, -) from labelbox import AssetAttachment, DataRow from labelbox.exceptions import ( @@ -33,10 +25,13 @@ @pytest.fixture def mdo( client, + constants, ): mdo = client.get_data_row_metadata_ontology() try: - mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string) + mdo.create_schema( + constants["CUSTOM_TEXT_SCHEMA_NAME"], DataRowMetadataKind.string + ) except MalformedQueryException: # Do nothing if already exists pass @@ -91,14 +86,18 @@ def tile_content(): } -def make_metadata_fields_dict(): +@pytest.fixture +def make_metadata_fields_dict(constants): msg = "A message" - time = datetime.utcnow() + time = datetime.now(timezone.utc) fields = [ - {"schema_id": SPLIT_SCHEMA_ID, "value": TEST_SPLIT_ID}, - {"schema_id": CAPTURE_DT_SCHEMA_ID, "value": time}, - {"schema_id": TEXT_SCHEMA_ID, "value": msg}, + { + "schema_id": constants["SPLIT_SCHEMA_ID"], + "value": constants["TEST_SPLIT_ID"], + }, + {"schema_id": constants["CAPTURE_DT_SCHEMA_ID"], "value": time}, + {"schema_id": constants["TEXT_SCHEMA_ID"], "value": msg}, ] return fields @@ -362,7 +361,12 @@ def test_create_data_row_with_invalid_input(dataset, image_url): def test_create_data_row_with_metadata( - mdo, dataset, image_url, make_metadata_fields + mdo, + dataset, + image_url, + make_metadata_fields, + constants, + make_metadata_fields_dict, ): client = dataset.client assert len(list(dataset.data_rows())) == 0 @@ -384,19 +388,21 @@ def test_create_data_row_with_metadata( metadata = data_row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [ - m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() + assert [m["schemaId"] for m in metadata_fields].sort() == constants[ + "EXPECTED_METADATA_SCHEMA_IDS" + ].sort() for m in metadata: assert mdo._parse_upsert(m) -def test_create_data_row_with_metadata_dict(mdo, dataset, image_url): +def test_create_data_row_with_metadata_dict( + mdo, dataset, image_url, constants, make_metadata_fields_dict +): client = dataset.client assert len(list(dataset.data_rows())) == 0 data_row = dataset.create_data_row( - row_data=image_url, metadata_fields=make_metadata_fields_dict() + row_data=image_url, metadata_fields=make_metadata_fields_dict ) assert len(list(dataset.data_rows())) == 1 @@ -412,20 +418,22 @@ def test_create_data_row_with_metadata_dict(mdo, dataset, image_url): metadata = data_row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [ - m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() + assert [m["schemaId"] for m in metadata_fields].sort() == constants[ + "EXPECTED_METADATA_SCHEMA_IDS" + ].sort() for m in metadata: assert mdo._parse_upsert(m) def test_create_data_row_with_invalid_metadata( - dataset, image_url, make_metadata_fields + dataset, image_url, constants, make_metadata_fields ): fields = make_metadata_fields # make the payload invalid by providing the same schema id more than once fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") + DataRowMetadataField( + schema_id=constants["TEXT_SCHEMA_ID"], value="some msg" + ) ) with pytest.raises(ResourceCreationError): @@ -433,7 +441,12 @@ def test_create_data_row_with_invalid_metadata( def test_create_data_rows_with_metadata( - mdo, dataset, image_url, make_metadata_fields + mdo, + dataset, + image_url, + constants, + make_metadata_fields, + make_metadata_fields_dict, ): client = dataset.client assert len(list(dataset.data_rows())) == 0 @@ -453,12 +466,12 @@ def test_create_data_rows_with_metadata( { DataRow.row_data: image_url, DataRow.external_id: "row3", - DataRow.metadata_fields: make_metadata_fields_dict(), + DataRow.metadata_fields: make_metadata_fields_dict, }, { DataRow.row_data: image_url, DataRow.external_id: "row4", - "metadata_fields": make_metadata_fields_dict(), + "metadata_fields": make_metadata_fields_dict, }, ] ) @@ -480,9 +493,9 @@ def test_create_data_rows_with_metadata( metadata = row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [ - m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() + assert [m["schemaId"] for m in metadata_fields].sort() == constants[ + "EXPECTED_METADATA_SCHEMA_IDS" + ].sort() for m in metadata: assert mdo._parse_upsert(m) @@ -499,14 +512,16 @@ def test_create_data_rows_with_metadata( ], ) def test_create_data_rows_with_named_metadata_field_class( - test_function, metadata_obj_type, mdo, dataset, image_url + test_function, metadata_obj_type, mdo, dataset, image_url, constants ): row_with_metadata_field = { DataRow.row_data: image_url, DataRow.external_id: "row1", DataRow.metadata_fields: [ DataRowMetadataField(name="split", value="test"), - DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value="hello"), + DataRowMetadataField( + name=constants["CUSTOM_TEXT_SCHEMA_NAME"], value="hello" + ), ], } @@ -515,7 +530,7 @@ def test_create_data_rows_with_named_metadata_field_class( DataRow.external_id: "row2", "metadata_fields": [ {"name": "split", "value": "test"}, - {"name": CUSTOM_TEXT_SCHEMA_NAME, "value": "hello"}, + {"name": constants["CUSTOM_TEXT_SCHEMA_NAME"], "value": "hello"}, ], } @@ -547,23 +562,26 @@ def create_data_row(data_rows): assert len(created_rows[0].metadata) == 2 metadata = created_rows[0].metadata - assert metadata[0].schema_id == SPLIT_SCHEMA_ID + assert metadata[0].schema_id == constants["SPLIT_SCHEMA_ID"] assert metadata[0].name == "test" assert metadata[0].value == mdo.reserved_by_name["split"]["test"].uid - assert metadata[1].name == CUSTOM_TEXT_SCHEMA_NAME + assert metadata[1].name == constants["CUSTOM_TEXT_SCHEMA_NAME"] assert metadata[1].value == "hello" assert ( - metadata[1].schema_id == mdo.custom_by_name[CUSTOM_TEXT_SCHEMA_NAME].uid + metadata[1].schema_id + == mdo.custom_by_name[constants["CUSTOM_TEXT_SCHEMA_NAME"]].uid ) def test_create_data_rows_with_invalid_metadata( - dataset, image_url, make_metadata_fields + dataset, image_url, constants, make_metadata_fields ): fields = make_metadata_fields # make the payload invalid by providing the same schema id more than once fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") + DataRowMetadataField( + schema_id=constants["TEXT_SCHEMA_ID"], value="some msg" + ) ) task = dataset.create_data_rows( @@ -574,7 +592,7 @@ def test_create_data_rows_with_invalid_metadata( assert task.status == "COMPLETE" assert len(task.failed_data_rows) == 1 assert ( - f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" + f"A schemaId can only be specified once per DataRow : [{constants['TEXT_SCHEMA_ID']}]" in task.failed_data_rows[0]["message"] ) diff --git a/libs/labelbox/tests/integration/test_mmc_data_rows.py b/libs/labelbox/tests/integration/test_mmc_data_rows.py index 2f3886cae..ee457a7fe 100644 --- a/libs/labelbox/tests/integration/test_mmc_data_rows.py +++ b/libs/labelbox/tests/integration/test_mmc_data_rows.py @@ -2,7 +2,6 @@ import random import pytest -from constants import EXPECTED_METADATA_SCHEMA_IDS @pytest.fixture @@ -35,7 +34,7 @@ def mmc_data_row(dataset, make_metadata_fields, embedding): data_row.delete() -def test_mmc(mmc_data_row, embedding): +def test_mmc(mmc_data_row, embedding, constants): data_row = mmc_data_row assert json.loads(data_row.row_data) == { "type": "application/vnd.labelbox.conversational.model-chat-evaluation", @@ -49,9 +48,9 @@ def test_mmc(mmc_data_row, embedding): metadata = data_row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [ - m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS.sort() + assert [m["schemaId"] for m in metadata_fields].sort() == constants[ + "EXPECTED_METADATA_SCHEMA_IDS" + ].sort() attachments = list(data_row.attachments()) assert len(attachments) == 1