From 02625c6c9c0ff9759596c7793f0b38c8994f4a97 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:59:08 -0500 Subject: [PATCH 1/7] set up embeddings to be reusable --- libs/labelbox/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 272b1c9c4..d79f6f854 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -1063,7 +1063,7 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.delete() -@pytest.fixture +@pytest.fixture(scope="module") def embedding(client: Client): uuid_str = uuid.uuid4().hex embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) From 69a38172d02524e60d1e6d5163670b5a011b2dbf Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:22:29 -0500 Subject: [PATCH 2/7] set up embeddings to be reusable --- libs/labelbox/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index d79f6f854..77d619317 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -1063,7 +1063,7 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.delete() -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def embedding(client: Client): uuid_str = uuid.uuid4().hex embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) From e3fbb039c3c305ddd92e59ebd4f7579f7e45f50a Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:27:45 -0500 Subject: [PATCH 3/7] set up embeddings to be reusable --- libs/labelbox/tests/conftest.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/libs/labelbox/tests/conftest.py 
b/libs/labelbox/tests/conftest.py index 77d619317..ef518b0b2 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -35,6 +35,8 @@ from labelbox.schema.quality_mode import QualityMode from labelbox.schema.queue_mode import QueueMode from labelbox.schema.user import User +from labelbox.exceptions import LabelboxError +from contextlib import suppress from labelbox import Client IMG_URL = "https://picsum.photos/200/300.jpg" @@ -1064,7 +1066,15 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, @pytest.fixture(scope="session") -def embedding(client: Client): +def embedding(client: Client, environ): + + # Remove all embeddings on staging + if environ == Environ.STAGING: + embeddings = client.get_embeddings() + for embedding in embeddings: + with suppress(LabelboxError): + embedding.delete() + uuid_str = uuid.uuid4().hex embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) yield embedding From b1138a8631a559cf0a3b969ae83fb86f63cb4acf Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:38:13 -0500 Subject: [PATCH 4/7] fixed --- libs/labelbox/tests/conftest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index ef518b0b2..db47cc071 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -1067,18 +1067,18 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, @pytest.fixture(scope="session") def embedding(client: Client, environ): - + + uuid_str = uuid.uuid4().hex + embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) + yield embedding # Remove all embeddings on staging if environ == Environ.STAGING: embeddings = client.get_embeddings() for embedding in embeddings: with suppress(LabelboxError): embedding.delete() - - uuid_str = uuid.uuid4().hex - embedding = 
client.create_embedding(f"sdk-int-{uuid_str}", 8) - yield embedding - embedding.delete() + else: + embedding.delete() @pytest.fixture From e43723d4b3244cbf602af17c819127e673155a23 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:01:48 -0500 Subject: [PATCH 5/7] simplified library --- .../data/annotation_types/metrics/scalar.py | 7 +-- .../data/serialization/ndjson/base.py | 30 +++++++++-- .../serialization/ndjson/classification.py | 14 +++--- .../data/serialization/ndjson/converter.py | 6 ++- .../data/serialization/ndjson/label.py | 50 ++++++++++++++++++- .../data/serialization/ndjson/metric.py | 8 +-- .../data/serialization/ndjson/objects.py | 38 +++++++------- .../data/serialization/ndjson/relationship.py | 4 +- 8 files changed, 111 insertions(+), 46 deletions(-) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py index 5f1279fd6..ccaa46854 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py @@ -4,8 +4,9 @@ from .base import ConfidenceValue, BaseMetric from labelbox import pydantic_compat +from typing_extensions import Annotated -ScalarMetricValue = pydantic_compat.confloat(ge=0, le=100_000_000) +ScalarMetricValue = Annotated[float, pydantic_compat.confloat(ge=0, le=100_000_000)] ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue] @@ -27,11 +28,11 @@ class ScalarMetric(BaseMetric): For backwards compatibility, metric_name is optional. The metric_name will be set to a default name in the editor if it is not set. This is not recommended and support for empty metric_name fields will be removed. - aggregation will be ignored wihtout providing a metric name. + aggregation will be ignored without providing a metric name. 
""" metric_name: Optional[str] = None value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN @pydantic_compat.validator('metric_name') def validate_metric_name(cls, name: Union[str, None]): diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py index 2a9186e02..21ef33dfe 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py @@ -5,6 +5,18 @@ from labelbox import pydantic_compat from ...annotation_types.types import Cuid +subclass_registry = {} + +class SubclassRegistryBase(pydantic_compat.BaseModel): + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if cls.__name__ != "NDAnnotation": + subclass_registry[cls.__name__] = cls + + class Config: + extra = "allow" + class DataRow(_CamelCaseMixin): id: str = None @@ -19,7 +31,7 @@ def must_set_one(cls, values): class NDJsonBase(_CamelCaseMixin): uuid: str = None - data_row: DataRow + data_row: Optional[DataRow] = None @pydantic_compat.validator('uuid', pre=True, always=True) def set_id(cls, v): @@ -28,10 +40,18 @@ def set_id(cls, v): def dict(self, *args, **kwargs): """ Pop missing id or missing globalKey from dataRow """ res = super().dict(*args, **kwargs) - if not self.data_row.id: - res['dataRow'].pop('id') - if not self.data_row.global_key: - res['dataRow'].pop('globalKey') + if self.data_row and not self.data_row.id: + if "data_row" in res: + res["data_row"].pop("id") + else: + res['dataRow'].pop('id') + if self.data_row and not self.data_row.global_key: + if "data_row" in res: + res["data_row"].pop("global_key") + else: + res['dataRow'].pop('globalKey') + if not self.data_row: + del res["dataRow"] return res diff --git 
a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 46b8fc91f..56cc25c4a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -11,6 +11,7 @@ from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData +from labelbox.data.serialization.ndjson.base import SubclassRegistryBase class NDAnswer(ConfidenceMixin, CustomMetricsMixin): @@ -174,7 +175,7 @@ def from_common(cls, prompt_text: PromptText, name: str, # ====== End of subclasses -class NDText(NDAnnotation, NDTextSubclass): +class NDText(NDAnnotation, NDTextSubclass, SubclassRegistryBase): @classmethod def from_common(cls, @@ -198,7 +199,7 @@ def from_common(cls, ) -class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): +class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, SubclassRegistryBase): @classmethod def from_common( @@ -234,7 +235,7 @@ def from_common( confidence=confidence) -class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): +class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, SubclassRegistryBase): @classmethod def from_common( @@ -265,7 +266,7 @@ def from_common( confidence=confidence) -class NDPromptText(NDAnnotation, NDPromptTextSubclass): +class NDPromptText(NDAnnotation, NDPromptTextSubclass, SubclassRegistryBase): @classmethod def from_common( @@ -404,8 +405,6 @@ def from_common( annotation.confidence) -# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, -# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass, NDTextSubclass] @@ -418,7 +417,6 @@ def 
from_common( NDPromptText.update_forward_refs() NDTextSubclass.update_forward_refs() -# Make sure to keep NDChecklist prior to NDRadio in the list, -# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used + NDClassificationType = Union[NDChecklist, NDRadio, NDText] NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index a7c54b109..e4edd2561 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -15,6 +15,7 @@ from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.relationship import RelationshipAnnotation from .label import NDLabel +import copy logger = logging.getLogger(__name__) @@ -33,7 +34,9 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: Returns: LabelGenerator containing the ndjson data. 
""" - data = NDLabel(**{"annotations": json_data}) + + data = copy.deepcopy(json_data) + data = NDLabel(**{"annotations": data}) res = data.to_common() return res @@ -108,7 +111,6 @@ def serialize( label.annotations = uuid_safe_annotations for example in NDLabel.from_common([label]): annotation_uuid = getattr(example, "uuid", None) - res = example.dict( by_alias=True, exclude={"uuid"} if annotation_uuid == "None" else None, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 9d34c451b..884b06ee1 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -20,10 +20,13 @@ from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText +from .classification import NDClassification, NDClassificationType, NDPromptClassification, NDPromptClassificationType, NDPromptText, NDChecklistSubclass, NDRadioSubclass from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks from .relationship import NDRelationship from .base import DataRow +from labelbox.utils import camel_case +from labelbox.data.serialization.ndjson.base import SubclassRegistryBase, subclass_registry +from contextlib import suppress AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, @@ -32,8 +35,51 @@ class NDLabel(pydantic_compat.BaseModel): - annotations: List[AnnotationType] + annotations: List[SubclassRegistryBase] + + def __init__(self, **kwargs): + # NOTE: Deserialization of subclasses in pydantic is difficult, see here 
https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 + # Below implements the subclass registry as mentioned in the article. The Python dicts we pass in can be missing certain fields, + # so we essentially have to infer the type against all subclasses that have the SubclassRegistryBase inheritance. + # It works by checking whether the keys of our annotations match the required keys inside subclasses. + # More keys are prioritized over fewer keys (closer match). This is used when importing json to our base models; not a lot of customer workflows + # depend on this method, but it works for all our existing tests with the bonus of added validation (if no subclass is found it throws an error). + # Previous strategies were hacky and don't work for pydantic V2; prior solutions also depended on the order + # in which classes appeared in the Python file to work. This should open the door to cut out a lot of the library, specifically some subclasses. + + for index in range(len(kwargs["annotations"])): + annotation = kwargs["annotations"][index] + if isinstance(annotation, dict): + item_annotation_keys = annotation.keys() + key_subclass_combos = defaultdict(list) + for subclass in subclass_registry.values(): + subclass = subclass + # Get all required keys from subclass + annotation_keys = [] + for k, field in subclass.__fields__.items(): + # must account for alias + if hasattr(field, "alias") and field.alias == "answers" and "answers" in item_annotation_keys: + annotation_keys.append("answers") + elif field.required is True and k != "uuid": + annotation_keys.append(camel_case(k)) + key_subclass_combos[subclass].extend(annotation_keys) + # Sort by subclass that has the most keys i.e.
the one with the most keys if a match is likely our class + key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True)) + + # Choose the keys from our dict we supplied that matches the required keys of a subclass + for subclass, key_subclass_combo in key_subclass_combos.items(): + check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo) + if check_required_keys: + # Keep trying subclasses until we find one that has valid values + with suppress(pydantic_compat.ValidationError): + annotation = subclass(**annotation) + break + if isinstance(annotation, dict): + raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}") + kwargs["annotations"][index] = annotation + super().__init__(**kwargs) + class _Relationship(pydantic_compat.BaseModel): """This object holds information about the relationship""" ndjson: NDRelationship diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py index 5abbf2761..e01de247e 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py @@ -1,7 +1,7 @@ from typing import Optional, Union, Type from labelbox.data.annotation_types.data import ImageData, TextData -from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase +from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase, SubclassRegistryBase from labelbox.data.annotation_types.metrics.scalar import ( ScalarMetric, ScalarMetricAggregation, ScalarMetricValue, ScalarMetricConfidenceValue) @@ -26,7 +26,7 @@ def dict(self, *args, **kwargs): return res -class NDConfusionMatrixMetric(BaseNDMetric): +class NDConfusionMatrixMetric(BaseNDMetric, SubclassRegistryBase): metric_value: Union[ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue] metric_name: str @@ -53,10 +53,10 @@ def from_common( 
data_row=DataRow(id=data.uid, global_key=data.global_key)) -class NDScalarMetric(BaseNDMetric): +class NDScalarMetric(BaseNDMetric, SubclassRegistryBase): metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] metric_name: Optional[str] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN def to_common(self) -> ScalarMetric: return ScalarMetric(value=self.metric_value, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index fd13a6bf6..7ddbe47af 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -19,8 +19,10 @@ from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance -from .classification import NDClassification, NDSubclassification, NDSubclassificationType -from .base import DataRow, NDAnnotation, NDJsonBase +from .classification import NDSubclassification, NDSubclassificationType +from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation, NDJsonBase, SubclassRegistryBase + + class NDBaseObject(NDAnnotation): @@ -48,7 +50,7 @@ class Bbox(pydantic_compat.BaseModel): width: float -class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): point: _Point def to_common(self) -> Point: @@ -109,7 +111,7 @@ def from_common( classifications=classifications) -class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): line: 
List[_Point] def to_common(self) -> Line: @@ -187,7 +189,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, group_key=group_key) -class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): polygon: List[_Point] def to_common(self) -> Polygon: @@ -218,7 +220,7 @@ def from_common( custom_metrics=custom_metrics) -class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): bbox: Bbox def to_common(self) -> Rectangle: @@ -254,7 +256,7 @@ def from_common( custom_metrics=custom_metrics) -class NDDocumentRectangle(NDRectangle): +class NDDocumentRectangle(NDRectangle, SubclassRegistryBase): page: int unit: str @@ -362,7 +364,6 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, @classmethod def from_common(cls, segment): nd_frame_object_type = cls.lookup_segment_object_type(segment) - return cls(keyframes=[ nd_frame_object_type.from_common( object_annotation.frame, object_annotation.value, [ @@ -398,7 +399,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, ] -class NDSegments(NDBaseObject): +class NDSegments(NDBaseObject, SubclassRegistryBase): segments: List[NDSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -425,7 +426,7 @@ def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData, uuid=extra.get('uuid')) -class NDDicomSegments(NDBaseObject, DicomSupported): +class NDDicomSegments(NDBaseObject, DicomSupported, SubclassRegistryBase): segments: List[NDDicomSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -463,7 +464,7 @@ class _PNGMask(pydantic_compat.BaseModel): png: str -class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): mask: Union[_URIMask, 
_PNGMask] def to_common(self) -> Mask: @@ -517,7 +518,7 @@ class NDVideoMasksFramesInstances(pydantic_compat.BaseModel): instances: List[MaskInstance] -class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin): +class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, SubclassRegistryBase): masks: NDVideoMasksFramesInstances def to_common(self) -> VideoMaskAnnotation: @@ -545,7 +546,7 @@ def from_common(cls, annotation, data): ) -class NDDicomMasks(NDVideoMasks, DicomSupported): +class NDDicomMasks(NDVideoMasks, DicomSupported, SubclassRegistryBase): def to_common(self) -> DICOMMaskAnnotation: return DICOMMaskAnnotation( @@ -569,7 +570,7 @@ class Location(pydantic_compat.BaseModel): end: int -class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): location: Location def to_common(self) -> TextEntity: @@ -601,7 +602,7 @@ def from_common( custom_metrics=custom_metrics) -class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): name: str text_selections: List[DocumentTextSelection] @@ -633,7 +634,7 @@ def from_common( custom_metrics=custom_metrics) -class NDConversationEntity(NDTextEntity): +class NDConversationEntity(NDTextEntity, SubclassRegistryBase): message_id: str def to_common(self) -> ConversationEntity: @@ -773,12 +774,9 @@ def lookup_object( return result -# NOTE: Deserialization of subclasses in pydantic is a known PIA, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 -# I could implement the registry approach suggested there, but I found that if I list subclass (that has more attributes) before the parent class, it works -# This is a bit of a hack, but it works for now NDEntityType = Union[NDConversationEntity, NDTextEntity] NDObjectType 
= Union[NDLine, NDPolygon, NDPoint, NDDocumentRectangle, NDRectangle, NDMask, NDEntityType, NDDocumentEntity] -NDFrameObjectType = NDFrameRectangle, NDFramePoint, NDFrameLine +NDFrameObjectType = Union[NDFrameRectangle, NDFramePoint, NDFrameLine] NDDicomObjectType = NDDicomLine diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py index 82976aedb..f42cc3b7a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py @@ -5,7 +5,7 @@ from ...annotation_types.relationship import RelationshipAnnotation from ...annotation_types.relationship import Relationship from .objects import NDObjectType -from .base import DataRow +from labelbox.data.serialization.ndjson.base import DataRow, SubclassRegistryBase SUPPORTED_ANNOTATIONS = NDObjectType @@ -16,7 +16,7 @@ class _Relationship(pydantic_compat.BaseModel): type: str -class NDRelationship(NDAnnotation): +class NDRelationship(NDAnnotation, SubclassRegistryBase): relationship: _Relationship @staticmethod From e2383ed16fe5a6bfa5ebd08f4c519ba2170abe5e Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:08:17 -0500 Subject: [PATCH 6/7] remove bad file --- libs/labelbox/tests/conftest.py | 1086 ------------------------------- 1 file changed, 1086 deletions(-) delete mode 100644 libs/labelbox/tests/conftest.py diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py deleted file mode 100644 index db47cc071..000000000 --- a/libs/labelbox/tests/conftest.py +++ /dev/null @@ -1,1086 +0,0 @@ -from datetime import datetime -from random import randint -from string import ascii_letters - -import json -import os -import re -import uuid -import time -import requests -import pytest -from types import SimpleNamespace -from typing import Type -from enum import Enum -from 
typing import Tuple - -from labelbox import Dataset, DataRow -from labelbox import MediaType -from labelbox.orm import query -from labelbox.pagination import PaginatedCollection -from labelbox.schema.invite import Invite -from labelbox.schema.quality_mode import QualityMode -from labelbox.schema.queue_mode import QueueMode -from labelbox import Client - -from labelbox import Dataset, DataRow -from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType -from labelbox.orm import query -from labelbox.pagination import PaginatedCollection -from labelbox.schema.annotation_import import LabelImport -from labelbox.schema.catalog import Catalog -from labelbox.schema.enums import AnnotationImportState -from labelbox.schema.invite import Invite -from labelbox.schema.quality_mode import QualityMode -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.user import User -from labelbox.exceptions import LabelboxError -from contextlib import suppress -from labelbox import Client - -IMG_URL = "https://picsum.photos/200/300.jpg" -MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg" -SMALL_DATASET_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 30 -DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 3 -EPHEMERAL_BASE_URL = "http://lb-api-public" -IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg" -EXTERNAL_ID = "my-image" - -pytest_plugins = [] - - -@pytest.fixture(scope="session") -def rand_gen(): - - def gen(field_type): - if field_type is str: - return "".join(ascii_letters[randint(0, - len(ascii_letters) - 1)] - for _ in range(16)) - - if field_type is datetime: - return datetime.now() - - raise Exception("Can't random generate for field type '%r'" % - field_type) - - return gen - - 
-class Environ(Enum): - LOCAL = 'local' - PROD = 'prod' - STAGING = 'staging' - CUSTOM = 'custom' - STAGING_EU = 'staging-eu' - EPHEMERAL = 'ephemeral' # Used for testing PRs with ephemeral environments - - -@pytest.fixture -def image_url() -> str: - return MASKABLE_IMG_URL - - -@pytest.fixture -def external_id() -> str: - return EXTERNAL_ID - - -def ephemeral_endpoint() -> str: - return os.getenv('LABELBOX_TEST_BASE_URL', EPHEMERAL_BASE_URL) - - -def graphql_url(environ: str) -> str: - if environ == Environ.LOCAL: - return 'http://localhost:3000/api/graphql' - elif environ == Environ.PROD: - return 'https://api.labelbox.com/graphql' - elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/graphql' - elif environ == Environ.CUSTOM: - graphql_api_endpoint = os.environ.get( - 'LABELBOX_TEST_GRAPHQL_API_ENDPOINT') - if graphql_api_endpoint is None: - raise Exception("Missing LABELBOX_TEST_GRAPHQL_API_ENDPOINT") - return graphql_api_endpoint - elif environ == Environ.EPHEMERAL: - return f"{ephemeral_endpoint()}/graphql" - return 'http://host.docker.internal:8080/graphql' - - -def rest_url(environ: str) -> str: - if environ == Environ.LOCAL: - return 'http://localhost:3000/api/v1' - elif environ == Environ.PROD: - return 'https://api.labelbox.com/api/v1' - elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/api/v1' - elif environ == Environ.CUSTOM: - rest_api_endpoint = os.environ.get('LABELBOX_TEST_REST_API_ENDPOINT') - if rest_api_endpoint is None: - raise Exception("Missing LABELBOX_TEST_REST_API_ENDPOINT") - return rest_api_endpoint - elif environ == Environ.EPHEMERAL: - return f"{ephemeral_endpoint()}/api/v1" - return 'http://host.docker.internal:8080/api/v1' - - -def testing_api_key(environ: Environ) -> str: - keys = [ - f"LABELBOX_TEST_API_KEY_{environ.value.upper()}", - "LABELBOX_TEST_API_KEY", - "LABELBOX_API_KEY" - ] - for key in keys: - value = os.environ.get(key) - if value is not None: - return value - raise Exception("Cannot 
find API to use for tests") - - -def service_api_key() -> str: - service_api_key = os.environ["SERVICE_API_KEY"] - if service_api_key is None: - raise Exception( - "SERVICE_API_KEY is missing and needed for admin client") - return service_api_key - - -class IntegrationClient(Client): - - def __init__(self, environ: str) -> None: - api_url = graphql_url(environ) - api_key = testing_api_key(environ) - rest_endpoint = rest_url(environ) - super().__init__(api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) - self.queries = [] - - def execute(self, query=None, params=None, check_naming=True, **kwargs): - if check_naming and query is not None: - assert re.match(r"\s*(?:query|mutation) \w+PyApi", - query) is not None - self.queries.append((query, params)) - if not kwargs.get('timeout'): - kwargs['timeout'] = 30.0 - return super().execute(query, params, **kwargs) - - -class AdminClient(Client): - - def __init__(self, env): - """ - The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. 
- """ - self._api_key = service_api_key() - self._admin_endpoint = f"{ephemeral_endpoint()}/admin/v1" - self._api_url = graphql_url(env) - self._rest_endpoint = rest_url(env) - - super().__init__(self._api_key, - self._api_url, - enable_experimental=True, - rest_endpoint=self._rest_endpoint) - - def _create_organization(self) -> str: - endpoint = f"{self._admin_endpoint}/organizations/" - response = requests.post( - endpoint, - headers=self.headers, - json={"name": f"Test Org {uuid.uuid4()}"}, - ) - - data = response.json() - if response.status_code not in [ - requests.codes.created, requests.codes.ok - ]: - raise Exception("Failed to create org, message: " + - str(data['message'])) - - return data['id'] - - def _create_user(self, organization_id=None) -> Tuple[str, str]: - if organization_id is None: - organization_id = self.organization_id - - endpoint = f"{self._admin_endpoint}/user-identities/" - identity_id = f"e2e+{uuid.uuid4()}" - - response = requests.post( - endpoint, - headers=self.headers, - json={ - "identityId": identity_id, - "email": "email@email.com", - "name": f"tester{uuid.uuid4()}", - "verificationStatus": "VERIFIED", - }, - ) - data = response.json() - if response.status_code not in [ - requests.codes.created, requests.codes.ok - ]: - raise Exception("Failed to create user, message: " + - str(data['message'])) - - user_identity_id = data['identityId'] - - endpoint = f"{self._admin_endpoint}/organizations/{organization_id}/users/" - response = requests.post( - endpoint, - headers=self.headers, - json={ - "identityId": user_identity_id, - "organizationRole": "Admin" - }, - ) - - data = response.json() - if response.status_code not in [ - requests.codes.created, requests.codes.ok - ]: - raise Exception("Failed to create link user to org, message: " + - str(data['message'])) - - user_id = data['id'] - - endpoint = f"{self._admin_endpoint}/users/{user_id}/token" - response = requests.get( - endpoint, - headers=self.headers, - ) - data = 
response.json() - if response.status_code not in [ - requests.codes.created, requests.codes.ok - ]: - raise Exception("Failed to create ephemeral user, message: " + - str(data['message'])) - - token = data["token"] - - return user_id, token - - def create_api_key_for_user(self) -> str: - organization_id = self._create_organization() - _, user_token = self._create_user(organization_id) - key_name = f"test-key+{uuid.uuid4()}" - query = """mutation CreateApiKeyPyApi($name: String!) { - createApiKey(data: {name: $name}) { - id - jwt - } - } - """ - params = {"name": key_name} - self.headers["Authorization"] = f"Bearer {user_token}" - res = self.execute(query, params, error_log_key="errors") - - return res["createApiKey"]["jwt"] - - -class EphemeralClient(Client): - - def __init__(self, environ=Environ.EPHEMERAL): - self.admin_client = AdminClient(environ) - self.api_key = self.admin_client.create_api_key_for_user() - api_url = graphql_url(environ) - rest_endpoint = rest_url(environ) - - super().__init__(self.api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) - - -@pytest.fixture -def ephmeral_client() -> EphemeralClient: - return EphemeralClient - - -@pytest.fixture -def admin_client() -> AdminClient: - return AdminClient - - -@pytest.fixture -def integration_client() -> IntegrationClient: - return IntegrationClient - - -@pytest.fixture(scope="session") -def environ() -> Environ: - """ - Checks environment variables for LABELBOX_ENVIRON to be - 'prod' or 'staging' - Make sure to set LABELBOX_TEST_ENVIRON in .github/workflows/python-package.yaml - """ - keys = [ - "LABELBOX_TEST_ENV", - "LABELBOX_TEST_ENVIRON", - "LABELBOX_ENV" - ] - for key in keys: - value = os.environ.get(key) - if value is not None: - return Environ(value) - raise Exception(f'Missing env key in: {os.environ}') - - -def cancel_invite(client, invite_id): - """ - Do not use. Only for testing. - """ - query_str = """mutation CancelInvitePyApi($where: WhereUniqueIdInput!) 
{ - cancelInvite(where: $where) {id}}""" - client.execute(query_str, {'where': {'id': invite_id}}, experimental=True) - - -def get_project_invites(client, project_id): - """ - Do not use. Only for testing. - """ - id_param = "projectId" - query_str = """query GetProjectInvitationsPyApi($from: ID, $first: PageSize, $%s: ID!) { - project(where: {id: $%s}) {id - invites(from: $from, first: $first) { nodes { %s - projectInvites { projectId projectRoleName } } nextCursor}}} - """ % (id_param, id_param, query.results_query_part(Invite)) - return PaginatedCollection(client, - query_str, {id_param: project_id}, - ['project', 'invites', 'nodes'], - Invite, - cursor_path=['project', 'invites', 'nextCursor']) - - -def get_invites(client): - """ - Do not use. Only for testing. - """ - query_str = """query GetOrgInvitationsPyApi($from: ID, $first: PageSize) { - organization { id invites(from: $from, first: $first) { - nodes { id createdAt organizationRoleName inviteeEmail } nextCursor }}}""" - invites = PaginatedCollection( - client, - query_str, {}, ['organization', 'invites', 'nodes'], - Invite, - cursor_path=['organization', 'invites', 'nextCursor'], - experimental=True) - return invites - - -@pytest.fixture -def queries(): - return SimpleNamespace(cancel_invite=cancel_invite, - get_project_invites=get_project_invites, - get_invites=get_invites) - - -@pytest.fixture(scope="session") -def admin_client(environ: str): - return AdminClient(environ) - - -@pytest.fixture(scope="session") -def client(environ: str): - if environ == Environ.EPHEMERAL: - return EphemeralClient() - return IntegrationClient(environ) - - -@pytest.fixture(scope="session") -def pdf_url(client): - pdf_url = client.upload_file('tests/assets/loremipsum.pdf') - return {"row_data": {"pdf_url": pdf_url,}, "global_key": str(uuid.uuid4())} - - -@pytest.fixture(scope="session") -def pdf_entity_data_row(client): - pdf_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf') - 
text_layer_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json' - ) - - return { - "row_data": { - "pdf_url": pdf_url, - "text_layer_url": text_layer_url - }, - "global_key": str(uuid.uuid4()) - } - - -@pytest.fixture() -def conversation_entity_data_row(client, rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", - } - - -@pytest.fixture -def project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - yield project - project.delete() - - -@pytest.fixture -def consensus_project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Consensus, - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - yield project - project.delete() - - -@pytest.fixture -def model_config(client, rand_gen, valid_model_id): - model_config = client.create_model_config( - name=rand_gen(str), - model_id=valid_model_id, - inference_params={"param": "value"}) - yield model_config - client.delete_model_config(model_config.uid) - - -@pytest.fixture -def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, - image_url): - project = consensus_project - dataset = initial_dataset - - data_rows = [] - for _ in range(3): - data_rows.append({ - DataRow.row_data: image_url, - DataRow.global_key: str(uuid.uuid4()) - }) - task = dataset.create_data_rows(data_rows) - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 3 - batch = project.create_batch( - rand_gen(str), - data_rows, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - 
- yield [project, batch, data_rows] - batch.delete() - - -@pytest.fixture -def dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - dataset.delete() - - -@pytest.fixture(scope='function') -def unique_dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - dataset.delete() - - -@pytest.fixture -def small_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": SMALL_DATASET_URL, - "external_id": "my-image" - }, - ] * 2) - task.wait_till_done() - - yield dataset - - -@pytest.fixture -def data_row(dataset, image_url, rand_gen): - global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) - task.wait_till_done() - dr = dataset.data_rows().get_one() - yield dr - dr.delete() - - -@pytest.fixture -def data_row_and_global_key(dataset, image_url, rand_gen): - global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) - task.wait_till_done() - dr = dataset.data_rows().get_one() - yield dr, global_key - dr.delete() - - -# can be used with -# @pytest.mark.parametrize('data_rows', [], indirect=True) -# if omitted, count defaults to 1 -@pytest.fixture -def data_rows(dataset, image_url, request, wait_for_data_row_processing, - client): - count = 1 - if hasattr(request, 'param'): - count = request.param - - datarows = [ - dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}") - for _ in range(count) - ] - - task = dataset.create_data_rows(datarows) - task.wait_till_done() - datarows = dataset.data_rows().get_many(count) - for dr in dataset.data_rows(): - wait_for_data_row_processing(client, dr) - - yield datarows - - for datarow in datarows: - datarow.delete() - - -@pytest.fixture -def iframe_url(environ) -> str: - if environ in 
[Environ.PROD, Environ.LOCAL]: - return 'https://editor.labelbox.com' - elif environ == Environ.STAGING: - return 'https://editor.lb-stage.xyz' - - -@pytest.fixture -def sample_image() -> str: - path_to_video = 'tests/integration/media/sample_image.jpg' - return path_to_video - - -@pytest.fixture -def sample_video() -> str: - path_to_video = 'tests/integration/media/cat.mp4' - return path_to_video - - -@pytest.fixture -def sample_bulk_conversation() -> list: - path_to_conversation = 'tests/integration/media/bulk_conversation.json' - with open(path_to_conversation) as json_file: - conversations = json.load(json_file) - return conversations - - -@pytest.fixture -def organization(client): - # Must have at least one seat open in your org to run these tests - org = client.get_organization() - # Clean up before and after incase this wasn't run for some reason. - for invite in get_invites(client): - if "@labelbox.com" in invite.email: - cancel_invite(client, invite.uid) - yield org - for invite in get_invites(client): - if "@labelbox.com" in invite.email: - cancel_invite(client, invite.uid) - - -@pytest.fixture -def configured_project_with_label(client, rand_gen, image_url, project, dataset, - data_row, wait_for_label_processing): - """Project with a connected dataset, having one datarow - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - One label is already created and yielded when using fixture - """ - project._wait_until_data_rows_are_processed( - data_row_ids=[data_row.uid], - wait_processing_max_seconds=DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS, - sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS) - - project.create_batch( - rand_gen(str), - [data_row.uid], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) - yield [project, dataset, 
data_row, label] - - for label in project.labels(): - label.delete() - - -def _create_label(project, data_row, ontology, wait_for_label_processing): - predictions = [{ - "uuid": str(uuid.uuid4()), - "schemaId": ontology.tools[0].feature_schema_id, - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - } - }] - - def create_label(): - """ Ad-hoc function to create a LabelImport - Creates a LabelImport task which will create a label - """ - upload_task = LabelImport.create_from_objects( - project.client, project.uid, f'label-import-{uuid.uuid4()}', - predictions) - upload_task.wait_until_done(sleep_time_seconds=5) - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - - project.create_label = create_label - project.create_label() - label = wait_for_label_processing(project)[0] - return label - - -def _setup_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - ontology_builder = OntologyBuilder(tools=[ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - ]) - project.setup(editor, ontology_builder.asdict()) - # TODO: ontology may not be synchronous after setup. 
remove sleep when api is more consistent - time.sleep(2) - return OntologyBuilder.from_project(project) - - -@pytest.fixture -def big_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": EXTERNAL_ID - }, - ] * 3) - task.wait_till_done() - - yield dataset - - -@pytest.fixture -def configured_batch_project_with_label(project, dataset, data_row, - wait_for_label_processing): - """Project with a batch having one datarow - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - One label is already created and yielded when using fixture - """ - data_rows = [dr.uid for dr in list(dataset.data_rows())] - project._wait_until_data_rows_are_processed(data_row_ids=data_rows, - sleep_interval=3) - project.create_batch("test-batch", data_rows) - project.data_row_ids = data_rows - - ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) - - yield [project, dataset, data_row, label] - - for label in project.labels(): - label.delete() - - -@pytest.fixture -def configured_batch_project_with_multiple_datarows(project, dataset, data_rows, - wait_for_label_processing): - """Project with a batch having multiple datarows - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - """ - global_keys = [dr.global_key for dr in data_rows] - - batch_name = f'batch {uuid.uuid4()}' - project.create_batch(batch_name, global_keys=global_keys) - - ontology = _setup_ontology(project) - for datarow in data_rows: - _create_label(project, datarow, ontology, wait_for_label_processing) - - yield [project, dataset, data_rows] - - for label in project.labels(): - label.delete() - - -# NOTE this is nice heuristics, also there is this logic _wait_until_data_rows_are_processed in Project -# in case we still have flakiness in the future, we can use it 
-@pytest.fixture -def wait_for_data_row_processing(): - """ - Do not use. Only for testing. - - Returns DataRow after waiting for it to finish processing media_attributes. - Some tests, specifically ones that rely on label export, rely on - DataRow be fully processed with media_attributes - """ - - def func(client, data_row, custom_check=None): - """ - added check_updated_at because when a data_row is updated from say - an image to pdf, it already has media_attributes and the loop does - not wait for processing to a pdf - """ - data_row_id = data_row.uid - timeout_seconds = 60 - while True: - data_row = client.get_data_row(data_row_id) - passed_custom_check = not custom_check or custom_check(data_row) - if data_row.media_attributes and passed_custom_check: - return data_row - timeout_seconds -= 2 - if timeout_seconds <= 0: - raise TimeoutError( - f"Timed out waiting for DataRow '{data_row_id}' to finish processing media_attributes" - ) - time.sleep(2) - - return func - - -@pytest.fixture -def wait_for_label_processing(): - """ - Do not use. Only for testing. - - Returns project's labels as a list after waiting for them to finish processing. 
- If `project.labels()` is called before label is fully processed, - it may return an empty set - """ - - def func(project): - timeout_seconds = 10 - while True: - labels = list(project.labels()) - if len(labels) > 0: - return labels - timeout_seconds -= 2 - if timeout_seconds <= 0: - raise TimeoutError( - f"Timed out waiting for label for project '{project.uid}' to finish processing" - ) - time.sleep(2) - - return func - - -@pytest.fixture -def initial_dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - - dataset.delete() - - -@pytest.fixture -def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(video_data_row) - data_row = wait_for_data_row_processing(client, data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -def create_video_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", - } - - -@pytest.fixture -def video_data_100_rows(client, rand_gen, wait_for_data_row_processing): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - for _ in range(100): - data_row = dataset.create_data_row(create_video_data_row(rand_gen)) - data_row = wait_for_data_row_processing(client, data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -@pytest.fixture() -def video_data_row(rand_gen): - return create_video_data_row(rand_gen) - - -class ExportV2Helpers: - - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - "project_details": True, - 
"performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - return task.result - - @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - "performance_details": False, - "label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - - return task.result - - @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - "performance_details": False, - "label_details": True - } - catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - - return task.result - - -@pytest.fixture -def export_v2_test_helpers() -> Type[ExportV2Helpers]: - return ExportV2Helpers() - - -@pytest.fixture -def big_dataset_data_row_ids(big_dataset: Dataset): - export_task = big_dataset.export() - export_task.wait_till_done() - stream = export_task.get_buffered_stream() - yield [dr.json["data_row"]["id"] for dr in stream] - - -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): - 
upload_invalid_data_rows_for_dataset(unique_dataset) - - yield unique_dataset - - -@pytest.fixture -def upload_invalid_data_rows_for_dataset(): - - def _upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) - task.wait_till_done() - - return _upload_invalid_data_rows_for_dataset - - -@pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): - dataset = initial_dataset - data_row_id = dataset.create_data_row(row_data=image_url).uid - project = project_with_empty_ontology - - batch = project.create_batch( - rand_gen(str), - [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = [data_row_id] - - yield project - - batch.delete() - - -@pytest.fixture -def project_with_empty_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) - yield project - - -@pytest.fixture -def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, - image_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - dataset = initial_dataset - data_row = dataset.create_data_row(row_data=image_url) - data_row_ids = [data_row.uid] - - project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids - - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - - ontology = OntologyBuilder() - tools = [ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - Tool(tool=Tool.Type.LINE, name="test-line-class"), - 
Tool(tool=Tool.Type.POINT, name="test-point-class"), - Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class") - ] - - options = [ - Option(value="first option answer"), - Option(value="second option answer"), - Option(value="third option answer") - ] - - classifications = [ - Classification(class_type=Classification.Type.TEXT, - name="test-text-class"), - Classification(class_type=Classification.Type.DROPDOWN, - name="test-dropdown-class", - options=options), - Classification(class_type=Classification.Type.RADIO, - name="test-radio-class", - options=options), - Classification(class_type=Classification.Type.CHECKLIST, - name="test-checklist-class", - options=options) - ] - - for t in tools: - for c in classifications: - t.add_classification(c) - ontology.add_tool(t) - for c in classifications: - ontology.add_classification(c) - - project.setup(editor, ontology.asdict()) - - yield [project, data_row] - project.delete() - - -@pytest.fixture(scope="session") -def embedding(client: Client, environ): - - uuid_str = uuid.uuid4().hex - embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) - yield embedding - # Remove all embeddings on staging - if environ == Environ.STAGING: - embeddings = client.get_embeddings() - for embedding in embeddings: - with suppress(LabelboxError): - embedding.delete() - else: - embedding.delete() - - -@pytest.fixture -def valid_model_id(): - return "2c903542-d1da-48fd-9db1-8c62571bd3d2" From d47df233c75b4cd0cf0198a16460dce78d753c2a Mon Sep 17 00:00:00 2001 From: Gabe <33893811+Gabefire@users.noreply.github.com> Date: Wed, 31 Jul 2024 11:10:02 -0500 Subject: [PATCH 7/7] Create conftest.py --- libs/labelbox/tests/conftest.py | 1086 +++++++++++++++++++++++++++++++ 1 file changed, 1086 insertions(+) create mode 100644 libs/labelbox/tests/conftest.py diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py new file mode 100644 index 000000000..db47cc071 --- /dev/null 
+++ b/libs/labelbox/tests/conftest.py @@ -0,0 +1,1086 @@ +from datetime import datetime +from random import randint +from string import ascii_letters + +import json +import os +import re +import uuid +import time +import requests +import pytest +from types import SimpleNamespace +from typing import Type +from enum import Enum +from typing import Tuple + +from labelbox import Dataset, DataRow +from labelbox import MediaType +from labelbox.orm import query +from labelbox.pagination import PaginatedCollection +from labelbox.schema.invite import Invite +from labelbox.schema.quality_mode import QualityMode +from labelbox.schema.queue_mode import QueueMode +from labelbox import Client + +from labelbox import Dataset, DataRow +from labelbox import LabelingFrontend +from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType +from labelbox.orm import query +from labelbox.pagination import PaginatedCollection +from labelbox.schema.annotation_import import LabelImport +from labelbox.schema.catalog import Catalog +from labelbox.schema.enums import AnnotationImportState +from labelbox.schema.invite import Invite +from labelbox.schema.quality_mode import QualityMode +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.user import User +from labelbox.exceptions import LabelboxError +from contextlib import suppress +from labelbox import Client + +IMG_URL = "https://picsum.photos/200/300.jpg" +MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg" +SMALL_DATASET_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" +DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 30 +DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 3 +EPHEMERAL_BASE_URL = "http://lb-api-public" +IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg" +EXTERNAL_ID = "my-image" + +pytest_plugins = [] + + 
+@pytest.fixture(scope="session") +def rand_gen(): + + def gen(field_type): + if field_type is str: + return "".join(ascii_letters[randint(0, + len(ascii_letters) - 1)] + for _ in range(16)) + + if field_type is datetime: + return datetime.now() + + raise Exception("Can't random generate for field type '%r'" % + field_type) + + return gen + + +class Environ(Enum): + LOCAL = 'local' + PROD = 'prod' + STAGING = 'staging' + CUSTOM = 'custom' + STAGING_EU = 'staging-eu' + EPHEMERAL = 'ephemeral' # Used for testing PRs with ephemeral environments + + +@pytest.fixture +def image_url() -> str: + return MASKABLE_IMG_URL + + +@pytest.fixture +def external_id() -> str: + return EXTERNAL_ID + + +def ephemeral_endpoint() -> str: + return os.getenv('LABELBOX_TEST_BASE_URL', EPHEMERAL_BASE_URL) + + +def graphql_url(environ: str) -> str: + if environ == Environ.LOCAL: + return 'http://localhost:3000/api/graphql' + elif environ == Environ.PROD: + return 'https://api.labelbox.com/graphql' + elif environ == Environ.STAGING: + return 'https://api.lb-stage.xyz/graphql' + elif environ == Environ.CUSTOM: + graphql_api_endpoint = os.environ.get( + 'LABELBOX_TEST_GRAPHQL_API_ENDPOINT') + if graphql_api_endpoint is None: + raise Exception("Missing LABELBOX_TEST_GRAPHQL_API_ENDPOINT") + return graphql_api_endpoint + elif environ == Environ.EPHEMERAL: + return f"{ephemeral_endpoint()}/graphql" + return 'http://host.docker.internal:8080/graphql' + + +def rest_url(environ: str) -> str: + if environ == Environ.LOCAL: + return 'http://localhost:3000/api/v1' + elif environ == Environ.PROD: + return 'https://api.labelbox.com/api/v1' + elif environ == Environ.STAGING: + return 'https://api.lb-stage.xyz/api/v1' + elif environ == Environ.CUSTOM: + rest_api_endpoint = os.environ.get('LABELBOX_TEST_REST_API_ENDPOINT') + if rest_api_endpoint is None: + raise Exception("Missing LABELBOX_TEST_REST_API_ENDPOINT") + return rest_api_endpoint + elif environ == Environ.EPHEMERAL: + return 
f"{ephemeral_endpoint()}/api/v1" + return 'http://host.docker.internal:8080/api/v1' + + +def testing_api_key(environ: Environ) -> str: + keys = [ + f"LABELBOX_TEST_API_KEY_{environ.value.upper()}", + "LABELBOX_TEST_API_KEY", + "LABELBOX_API_KEY" + ] + for key in keys: + value = os.environ.get(key) + if value is not None: + return value + raise Exception("Cannot find API to use for tests") + + +def service_api_key() -> str: + service_api_key = os.environ["SERVICE_API_KEY"] + if service_api_key is None: + raise Exception( + "SERVICE_API_KEY is missing and needed for admin client") + return service_api_key + + +class IntegrationClient(Client): + + def __init__(self, environ: str) -> None: + api_url = graphql_url(environ) + api_key = testing_api_key(environ) + rest_endpoint = rest_url(environ) + super().__init__(api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint) + self.queries = [] + + def execute(self, query=None, params=None, check_naming=True, **kwargs): + if check_naming and query is not None: + assert re.match(r"\s*(?:query|mutation) \w+PyApi", + query) is not None + self.queries.append((query, params)) + if not kwargs.get('timeout'): + kwargs['timeout'] = 30.0 + return super().execute(query, params, **kwargs) + + +class AdminClient(Client): + + def __init__(self, env): + """ + The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. 
+ """ + self._api_key = service_api_key() + self._admin_endpoint = f"{ephemeral_endpoint()}/admin/v1" + self._api_url = graphql_url(env) + self._rest_endpoint = rest_url(env) + + super().__init__(self._api_key, + self._api_url, + enable_experimental=True, + rest_endpoint=self._rest_endpoint) + + def _create_organization(self) -> str: + endpoint = f"{self._admin_endpoint}/organizations/" + response = requests.post( + endpoint, + headers=self.headers, + json={"name": f"Test Org {uuid.uuid4()}"}, + ) + + data = response.json() + if response.status_code not in [ + requests.codes.created, requests.codes.ok + ]: + raise Exception("Failed to create org, message: " + + str(data['message'])) + + return data['id'] + + def _create_user(self, organization_id=None) -> Tuple[str, str]: + if organization_id is None: + organization_id = self.organization_id + + endpoint = f"{self._admin_endpoint}/user-identities/" + identity_id = f"e2e+{uuid.uuid4()}" + + response = requests.post( + endpoint, + headers=self.headers, + json={ + "identityId": identity_id, + "email": "email@email.com", + "name": f"tester{uuid.uuid4()}", + "verificationStatus": "VERIFIED", + }, + ) + data = response.json() + if response.status_code not in [ + requests.codes.created, requests.codes.ok + ]: + raise Exception("Failed to create user, message: " + + str(data['message'])) + + user_identity_id = data['identityId'] + + endpoint = f"{self._admin_endpoint}/organizations/{organization_id}/users/" + response = requests.post( + endpoint, + headers=self.headers, + json={ + "identityId": user_identity_id, + "organizationRole": "Admin" + }, + ) + + data = response.json() + if response.status_code not in [ + requests.codes.created, requests.codes.ok + ]: + raise Exception("Failed to create link user to org, message: " + + str(data['message'])) + + user_id = data['id'] + + endpoint = f"{self._admin_endpoint}/users/{user_id}/token" + response = requests.get( + endpoint, + headers=self.headers, + ) + data = 
response.json() + if response.status_code not in [ + requests.codes.created, requests.codes.ok + ]: + raise Exception("Failed to create ephemeral user, message: " + + str(data['message'])) + + token = data["token"] + + return user_id, token + + def create_api_key_for_user(self) -> str: + organization_id = self._create_organization() + _, user_token = self._create_user(organization_id) + key_name = f"test-key+{uuid.uuid4()}" + query = """mutation CreateApiKeyPyApi($name: String!) { + createApiKey(data: {name: $name}) { + id + jwt + } + } + """ + params = {"name": key_name} + self.headers["Authorization"] = f"Bearer {user_token}" + res = self.execute(query, params, error_log_key="errors") + + return res["createApiKey"]["jwt"] + + +class EphemeralClient(Client): + + def __init__(self, environ=Environ.EPHEMERAL): + self.admin_client = AdminClient(environ) + self.api_key = self.admin_client.create_api_key_for_user() + api_url = graphql_url(environ) + rest_endpoint = rest_url(environ) + + super().__init__(self.api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint) + + +@pytest.fixture +def ephmeral_client() -> EphemeralClient: + return EphemeralClient + + +@pytest.fixture +def admin_client() -> AdminClient: + return AdminClient + + +@pytest.fixture +def integration_client() -> IntegrationClient: + return IntegrationClient + + +@pytest.fixture(scope="session") +def environ() -> Environ: + """ + Checks environment variables for LABELBOX_ENVIRON to be + 'prod' or 'staging' + Make sure to set LABELBOX_TEST_ENVIRON in .github/workflows/python-package.yaml + """ + keys = [ + "LABELBOX_TEST_ENV", + "LABELBOX_TEST_ENVIRON", + "LABELBOX_ENV" + ] + for key in keys: + value = os.environ.get(key) + if value is not None: + return Environ(value) + raise Exception(f'Missing env key in: {os.environ}') + + +def cancel_invite(client, invite_id): + """ + Do not use. Only for testing. + """ + query_str = """mutation CancelInvitePyApi($where: WhereUniqueIdInput!) 
{ + cancelInvite(where: $where) {id}}""" + client.execute(query_str, {'where': {'id': invite_id}}, experimental=True) + + +def get_project_invites(client, project_id): + """ + Do not use. Only for testing. + """ + id_param = "projectId" + query_str = """query GetProjectInvitationsPyApi($from: ID, $first: PageSize, $%s: ID!) { + project(where: {id: $%s}) {id + invites(from: $from, first: $first) { nodes { %s + projectInvites { projectId projectRoleName } } nextCursor}}} + """ % (id_param, id_param, query.results_query_part(Invite)) + return PaginatedCollection(client, + query_str, {id_param: project_id}, + ['project', 'invites', 'nodes'], + Invite, + cursor_path=['project', 'invites', 'nextCursor']) + + +def get_invites(client): + """ + Do not use. Only for testing. + """ + query_str = """query GetOrgInvitationsPyApi($from: ID, $first: PageSize) { + organization { id invites(from: $from, first: $first) { + nodes { id createdAt organizationRoleName inviteeEmail } nextCursor }}}""" + invites = PaginatedCollection( + client, + query_str, {}, ['organization', 'invites', 'nodes'], + Invite, + cursor_path=['organization', 'invites', 'nextCursor'], + experimental=True) + return invites + + +@pytest.fixture +def queries(): + return SimpleNamespace(cancel_invite=cancel_invite, + get_project_invites=get_project_invites, + get_invites=get_invites) + + +@pytest.fixture(scope="session") +def admin_client(environ: str): + return AdminClient(environ) + + +@pytest.fixture(scope="session") +def client(environ: str): + if environ == Environ.EPHEMERAL: + return EphemeralClient() + return IntegrationClient(environ) + + +@pytest.fixture(scope="session") +def pdf_url(client): + pdf_url = client.upload_file('tests/assets/loremipsum.pdf') + return {"row_data": {"pdf_url": pdf_url,}, "global_key": str(uuid.uuid4())} + + +@pytest.fixture(scope="session") +def pdf_entity_data_row(client): + pdf_url = client.upload_file( + 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf') + 
text_layer_url = client.upload_file( + 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json' + ) + + return { + "row_data": { + "pdf_url": pdf_url, + "text_layer_url": text_layer_url + }, + "global_key": str(uuid.uuid4()) + } + + +@pytest.fixture() +def conversation_entity_data_row(client, rand_gen): + return { + "row_data": + "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": + f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", + } + + +@pytest.fixture +def project(client, rand_gen): + project = client.create_project(name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image) + yield project + project.delete() + + +@pytest.fixture +def consensus_project(client, rand_gen): + project = client.create_project(name=rand_gen(str), + quality_mode=QualityMode.Consensus, + queue_mode=QueueMode.Batch, + media_type=MediaType.Image) + yield project + project.delete() + + +@pytest.fixture +def model_config(client, rand_gen, valid_model_id): + model_config = client.create_model_config( + name=rand_gen(str), + model_id=valid_model_id, + inference_params={"param": "value"}) + yield model_config + client.delete_model_config(model_config.uid) + + +@pytest.fixture +def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, + image_url): + project = consensus_project + dataset = initial_dataset + + data_rows = [] + for _ in range(3): + data_rows.append({ + DataRow.row_data: image_url, + DataRow.global_key: str(uuid.uuid4()) + }) + task = dataset.create_data_rows(data_rows) + task.wait_till_done() + assert task.status == "COMPLETE" + + data_rows = list(dataset.data_rows()) + assert len(data_rows) == 3 + batch = project.create_batch( + rand_gen(str), + data_rows, # sample of data row objects + 5 # priority between 1(Highest) - 5(lowest) + ) + 
+ yield [project, batch, data_rows] + batch.delete() + + +@pytest.fixture +def dataset(client, rand_gen): + dataset = client.create_dataset(name=rand_gen(str)) + yield dataset + dataset.delete() + + +@pytest.fixture(scope='function') +def unique_dataset(client, rand_gen): + dataset = client.create_dataset(name=rand_gen(str)) + yield dataset + dataset.delete() + + +@pytest.fixture +def small_dataset(dataset: Dataset): + task = dataset.create_data_rows([ + { + "row_data": SMALL_DATASET_URL, + "external_id": "my-image" + }, + ] * 2) + task.wait_till_done() + + yield dataset + + +@pytest.fixture +def data_row(dataset, image_url, rand_gen): + global_key = f"global-key-{rand_gen(str)}" + task = dataset.create_data_rows([ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key + }, + ]) + task.wait_till_done() + dr = dataset.data_rows().get_one() + yield dr + dr.delete() + + +@pytest.fixture +def data_row_and_global_key(dataset, image_url, rand_gen): + global_key = f"global-key-{rand_gen(str)}" + task = dataset.create_data_rows([ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key + }, + ]) + task.wait_till_done() + dr = dataset.data_rows().get_one() + yield dr, global_key + dr.delete() + + +# can be used with +# @pytest.mark.parametrize('data_rows', [], indirect=True) +# if omitted, count defaults to 1 +@pytest.fixture +def data_rows(dataset, image_url, request, wait_for_data_row_processing, + client): + count = 1 + if hasattr(request, 'param'): + count = request.param + + datarows = [ + dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}") + for _ in range(count) + ] + + task = dataset.create_data_rows(datarows) + task.wait_till_done() + datarows = dataset.data_rows().get_many(count) + for dr in dataset.data_rows(): + wait_for_data_row_processing(client, dr) + + yield datarows + + for datarow in datarows: + datarow.delete() + + +@pytest.fixture +def iframe_url(environ) -> str: + if environ in 
[Environ.PROD, Environ.LOCAL]: + return 'https://editor.labelbox.com' + elif environ == Environ.STAGING: + return 'https://editor.lb-stage.xyz' + + +@pytest.fixture +def sample_image() -> str: + path_to_video = 'tests/integration/media/sample_image.jpg' + return path_to_video + + +@pytest.fixture +def sample_video() -> str: + path_to_video = 'tests/integration/media/cat.mp4' + return path_to_video + + +@pytest.fixture +def sample_bulk_conversation() -> list: + path_to_conversation = 'tests/integration/media/bulk_conversation.json' + with open(path_to_conversation) as json_file: + conversations = json.load(json_file) + return conversations + + +@pytest.fixture +def organization(client): + # Must have at least one seat open in your org to run these tests + org = client.get_organization() + # Clean up before and after incase this wasn't run for some reason. + for invite in get_invites(client): + if "@labelbox.com" in invite.email: + cancel_invite(client, invite.uid) + yield org + for invite in get_invites(client): + if "@labelbox.com" in invite.email: + cancel_invite(client, invite.uid) + + +@pytest.fixture +def configured_project_with_label(client, rand_gen, image_url, project, dataset, + data_row, wait_for_label_processing): + """Project with a connected dataset, having one datarow + Project contains an ontology with 1 bbox tool + Additionally includes a create_label method for any needed extra labels + One label is already created and yielded when using fixture + """ + project._wait_until_data_rows_are_processed( + data_row_ids=[data_row.uid], + wait_processing_max_seconds=DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS, + sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS) + + project.create_batch( + rand_gen(str), + [data_row.uid], # sample of data row objects + 5 # priority between 1(Highest) - 5(lowest) + ) + ontology = _setup_ontology(project) + label = _create_label(project, data_row, ontology, + wait_for_label_processing) + yield [project, dataset, 
data_row, label] + + for label in project.labels(): + label.delete() + + +def _create_label(project, data_row, ontology, wait_for_label_processing): + predictions = [{ + "uuid": str(uuid.uuid4()), + "schemaId": ontology.tools[0].feature_schema_id, + "dataRow": { + "id": data_row.uid + }, + "bbox": { + "top": 20, + "left": 20, + "height": 50, + "width": 50 + } + }] + + def create_label(): + """ Ad-hoc function to create a LabelImport + Creates a LabelImport task which will create a label + """ + upload_task = LabelImport.create_from_objects( + project.client, project.uid, f'label-import-{uuid.uuid4()}', + predictions) + upload_task.wait_until_done(sleep_time_seconds=5) + assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" + assert len( + upload_task.errors + ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" + + project.create_label = create_label + project.create_label() + label = wait_for_label_processing(project)[0] + return label + + +def _setup_ontology(project): + editor = list( + project.client.get_labeling_frontends( + where=LabelingFrontend.name == "editor"))[0] + ontology_builder = OntologyBuilder(tools=[ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), + ]) + project.setup(editor, ontology_builder.asdict()) + # TODO: ontology may not be synchronous after setup. 
remove sleep when api is more consistent + time.sleep(2) + return OntologyBuilder.from_project(project) + + +@pytest.fixture +def big_dataset(dataset: Dataset): + task = dataset.create_data_rows([ + { + "row_data": IMAGE_URL, + "external_id": EXTERNAL_ID + }, + ] * 3) + task.wait_till_done() + + yield dataset + + +@pytest.fixture +def configured_batch_project_with_label(project, dataset, data_row, + wait_for_label_processing): + """Project with a batch having one datarow + Project contains an ontology with 1 bbox tool + Additionally includes a create_label method for any needed extra labels + One label is already created and yielded when using fixture + """ + data_rows = [dr.uid for dr in list(dataset.data_rows())] + project._wait_until_data_rows_are_processed(data_row_ids=data_rows, + sleep_interval=3) + project.create_batch("test-batch", data_rows) + project.data_row_ids = data_rows + + ontology = _setup_ontology(project) + label = _create_label(project, data_row, ontology, + wait_for_label_processing) + + yield [project, dataset, data_row, label] + + for label in project.labels(): + label.delete() + + +@pytest.fixture +def configured_batch_project_with_multiple_datarows(project, dataset, data_rows, + wait_for_label_processing): + """Project with a batch having multiple datarows + Project contains an ontology with 1 bbox tool + Additionally includes a create_label method for any needed extra labels + """ + global_keys = [dr.global_key for dr in data_rows] + + batch_name = f'batch {uuid.uuid4()}' + project.create_batch(batch_name, global_keys=global_keys) + + ontology = _setup_ontology(project) + for datarow in data_rows: + _create_label(project, datarow, ontology, wait_for_label_processing) + + yield [project, dataset, data_rows] + + for label in project.labels(): + label.delete() + + +# NOTE this is nice heuristics, also there is this logic _wait_until_data_rows_are_processed in Project +# in case we still have flakiness in the future, we can use it 
@pytest.fixture
def wait_for_data_row_processing():
    """
    Do not use. Only for testing.

    Returns a function that polls a DataRow until it finishes processing
    media_attributes. Some tests, specifically ones that rely on label
    export, rely on the DataRow being fully processed with media_attributes.
    """

    def func(client, data_row, custom_check=None):
        """Poll until ``media_attributes`` is set and ``custom_check`` passes.

        ``custom_check`` exists because a data row updated in place (say,
        from an image to a pdf) already has media_attributes, so the
        presence check alone would return before reprocessing finishes.
        """
        data_row_id = data_row.uid
        timeout_seconds = 60
        while True:
            data_row = client.get_data_row(data_row_id)
            passed_custom_check = not custom_check or custom_check(data_row)
            if data_row.media_attributes and passed_custom_check:
                return data_row
            timeout_seconds -= 2
            if timeout_seconds <= 0:
                raise TimeoutError(
                    f"Timed out waiting for DataRow '{data_row_id}' to finish processing media_attributes"
                )
            time.sleep(2)

    return func


@pytest.fixture
def wait_for_label_processing():
    """
    Do not use. Only for testing.

    Returns project's labels as a list after waiting for them to finish
    processing. If `project.labels()` is called before the label is fully
    processed, it may return an empty set.
    """

    def func(project):
        timeout_seconds = 10
        while True:
            labels = list(project.labels())
            if len(labels) > 0:
                return labels
            timeout_seconds -= 2
            if timeout_seconds <= 0:
                raise TimeoutError(
                    f"Timed out waiting for label for project '{project.uid}' to finish processing"
                )
            time.sleep(2)

    return func


@pytest.fixture
def initial_dataset(client, rand_gen):
    """Empty dataset with a random name; deleted on teardown."""
    dataset = client.create_dataset(name=rand_gen(str))
    yield dataset

    dataset.delete()


@pytest.fixture
def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing):
    """Dataset with a single fully-processed video data row.

    Yields ``(dataset, data_row_ids)``; the dataset is deleted on teardown.
    """
    dataset = client.create_dataset(name=rand_gen(str))
    data_row_ids = []
    data_row = dataset.create_data_row(video_data_row)
    data_row = wait_for_data_row_processing(client, data_row)
    data_row_ids.append(data_row.uid)
    yield dataset, data_row_ids
    dataset.delete()


def create_video_data_row(rand_gen):
    """Build a video data-row dict with a unique global_key."""
    return {
        "row_data":
            "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4",
        "global_key":
            f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}",
        "media_type":
            "VIDEO",
    }


@pytest.fixture
def video_data_100_rows(client, rand_gen, wait_for_data_row_processing):
    """Dataset with 100 fully-processed video data rows.

    Yields ``(dataset, data_row_ids)``; the dataset is deleted on teardown.
    """
    dataset = client.create_dataset(name=rand_gen(str))
    data_row_ids = []
    for _ in range(100):
        data_row = dataset.create_data_row(create_video_data_row(rand_gen))
        data_row = wait_for_data_row_processing(client, data_row)
        data_row_ids.append(data_row.uid)
    yield dataset, data_row_ids
    dataset.delete()


@pytest.fixture()
def video_data_row(rand_gen):
    """Single video data-row dict with a unique global_key."""
    return create_video_data_row(rand_gen)


class ExportV2Helpers:
    """Runs export_v2 tasks with retries.

    Export results can lag behind ingestion, so each helper retries the
    export until it returns a non-empty result (or retries are exhausted,
    in which case the last — possibly empty — result is returned).
    """

    @classmethod
    def _run_export_with_retries(cls, export_fn, num_retries, task_name,
                                 filters, params):
        """Invoke ``export_fn`` until it yields a non-empty result.

        ``export_fn`` is a bound ``export_v2`` callable (project, dataset,
        or catalog). Asserts each task completes without errors.
        """
        task = None
        while num_retries > 0:
            task = export_fn(task_name=task_name,
                             filters=filters,
                             params=params)
            task.wait_till_done()
            assert task.status == "COMPLETE"
            assert task.errors is None
            if len(task.result) == 0:
                num_retries -= 1
                time.sleep(5)
            else:
                break
        return task.result

    @classmethod
    def run_project_export_v2_task(cls,
                                   project,
                                   num_retries=5,
                                   task_name=None,
                                   filters=None,
                                   params=None):
        """Export a project; defaults request full project/row/label details."""
        # None defaults instead of mutable {} defaults; semantics unchanged.
        params = params if params else {
            "project_details": True,
            "performance_details": False,
            "data_row_details": True,
            "label_details": True
        }
        filters = filters if filters is not None else {}
        return cls._run_export_with_retries(project.export_v2, num_retries,
                                            task_name, filters, params)

    @classmethod
    def run_dataset_export_v2_task(cls,
                                   dataset,
                                   num_retries=5,
                                   task_name=None,
                                   filters=None,
                                   params=None):
        """Export a dataset; defaults request label details only."""
        params = params if params else {
            "performance_details": False,
            "label_details": True
        }
        filters = filters if filters is not None else {}
        return cls._run_export_with_retries(dataset.export_v2, num_retries,
                                            task_name, filters, params)

    @classmethod
    def run_catalog_export_v2_task(cls,
                                   client,
                                   num_retries=5,
                                   task_name=None,
                                   filters=None,
                                   params=None):
        """Export the whole catalog; defaults request label details only."""
        params = params if params else {
            "performance_details": False,
            "label_details": True
        }
        filters = filters if filters is not None else {}
        catalog = client.get_catalog()
        return cls._run_export_with_retries(catalog.export_v2, num_retries,
                                            task_name, filters, params)


@pytest.fixture
def export_v2_test_helpers() -> ExportV2Helpers:
    # Returns an instance, so the annotation is the class itself,
    # not Type[ExportV2Helpers].
    return ExportV2Helpers()


@pytest.fixture
def big_dataset_data_row_ids(big_dataset: Dataset):
    """IDs of every data row in ``big_dataset``, read via a buffered export."""
    export_task = big_dataset.export()
    export_task.wait_till_done()
    stream = export_task.get_buffered_stream()
    yield [dr.json["data_row"]["id"] for dr in stream]


@pytest.fixture(scope='function')
def dataset_with_invalid_data_rows(unique_dataset: Dataset,
                                   upload_invalid_data_rows_for_dataset):
    """``unique_dataset`` pre-populated with rows that fail ingestion."""
    upload_invalid_data_rows_for_dataset(unique_dataset)

    yield unique_dataset


@pytest.fixture
def upload_invalid_data_rows_for_dataset():
    """Returns a function that uploads two inaccessible-image rows."""

    def _upload_invalid_data_rows_for_dataset(dataset: Dataset):
        task = dataset.create_data_rows([
            {
                "row_data": 'gs://invalid-bucket/example.png',  # forbidden
                "external_id": "image-without-access.jpg"
            },
        ] * 2)
        task.wait_till_done()

    return _upload_invalid_data_rows_for_dataset


@pytest.fixture
def configured_project(project_with_empty_ontology, initial_dataset, rand_gen,
                       image_url):
    """Empty-ontology project with a one-row batch attached.

    Yields the project with ``data_row_ids`` set; the batch is deleted
    on teardown.
    """
    dataset = initial_dataset
    data_row_id = dataset.create_data_row(row_data=image_url).uid
    project = project_with_empty_ontology

    batch = project.create_batch(
        rand_gen(str),
        [data_row_id],  # sample of data row objects
        5  # priority between 1(Highest) - 5(lowest)
    )
    project.data_row_ids = [data_row_id]

    yield project

    batch.delete()


@pytest.fixture
def project_with_empty_ontology(project):
    """The ``project`` fixture set up with an ontology containing no tools
    or classifications."""
    editor = list(
        project.client.get_labeling_frontends(
            where=LabelingFrontend.name == "editor"))[0]
    empty_ontology = {"tools": [], "classifications": []}
    project.setup(editor, empty_ontology)
    yield project


@pytest.fixture
def configured_project_with_complex_ontology(client, initial_dataset, rand_gen,
                                             image_url):
    """Batch-mode image project with one data row and a rich ontology.

    The ontology contains bbox/line/point/polygon/ner tools, each carrying
    text/dropdown/radio/checklist classifications (which are also added at
    the top level). Yields ``[project, data_row]``; the project is deleted
    on teardown.
    """
    project = client.create_project(name=rand_gen(str),
                                    queue_mode=QueueMode.Batch,
                                    media_type=MediaType.Image)
    dataset = initial_dataset
    data_row = dataset.create_data_row(row_data=image_url)
    data_row_ids = [data_row.uid]

    project.create_batch(
        rand_gen(str),
        data_row_ids,  # sample of data row objects
        5  # priority between 1(Highest) - 5(lowest)
    )
    project.data_row_ids = data_row_ids

    editor = list(
        project.client.get_labeling_frontends(
            where=LabelingFrontend.name == "editor"))[0]

    ontology = OntologyBuilder()
    tools = [
        Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
        Tool(tool=Tool.Type.LINE, name="test-line-class"),
        Tool(tool=Tool.Type.POINT, name="test-point-class"),
        Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"),
        Tool(tool=Tool.Type.NER, name="test-ner-class")
    ]

    options = [
        Option(value="first option answer"),
        Option(value="second option answer"),
        Option(value="third option answer")
    ]

    classifications = [
        Classification(class_type=Classification.Type.TEXT,
                       name="test-text-class"),
        Classification(class_type=Classification.Type.DROPDOWN,
                       name="test-dropdown-class",
                       options=options),
        Classification(class_type=Classification.Type.RADIO,
                       name="test-radio-class",
                       options=options),
        Classification(class_type=Classification.Type.CHECKLIST,
                       name="test-checklist-class",
                       options=options)
    ]

    for t in tools:
        for c in classifications:
            t.add_classification(c)
        ontology.add_tool(t)
    for c in classifications:
        ontology.add_classification(c)

    project.setup(editor, ontology.asdict())

    yield [project, data_row]
    project.delete()


@pytest.fixture(scope="session")
def embedding(client: Client, environ):
    """Session-scoped custom embedding (8 dimensions).

    Teardown: on staging, sweep *all* embeddings (best-effort, errors
    suppressed) to clean up anything orphaned by earlier runs; elsewhere,
    delete only the embedding created here.
    """
    uuid_str = uuid.uuid4().hex
    # Local name distinct from the fixture name so the cleanup loop below
    # cannot shadow the object we yielded.
    created = client.create_embedding(f"sdk-int-{uuid_str}", 8)
    yield created
    if environ == Environ.STAGING:
        # Remove all embeddings on staging
        for existing in client.get_embeddings():
            with suppress(LabelboxError):
                existing.delete()
    else:
        created.delete()


@pytest.fixture
def valid_model_id():
    """Known-good model id accepted by create_model_config."""
    return "2c903542-d1da-48fd-9db1-8c62571bd3d2"