diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py index 5f1279fd6..ccaa46854 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py @@ -4,8 +4,9 @@ from .base import ConfidenceValue, BaseMetric from labelbox import pydantic_compat +from typing_extensions import Annotated -ScalarMetricValue = pydantic_compat.confloat(ge=0, le=100_000_000) +ScalarMetricValue = Annotated[float, pydantic_compat.confloat(ge=0, le=100_000_000)] ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue] @@ -27,11 +28,11 @@ class ScalarMetric(BaseMetric): For backwards compatibility, metric_name is optional. The metric_name will be set to a default name in the editor if it is not set. This is not recommended and support for empty metric_name fields will be removed. - aggregation will be ignored wihtout providing a metric name. + aggregation will be ignored without providing a metric name. """ metric_name: Optional[str] = None value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN @pydantic_compat.validator('metric_name') def validate_metric_name(cls, name: Union[str, None]): diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py index 2a9186e02..21ef33dfe 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py @@ -5,6 +5,18 @@ from labelbox import pydantic_compat from ...annotation_types.types import Cuid +subclass_registry = {} + +class SubclassRegistryBase(pydantic_compat.BaseModel): + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if cls.__name__ != "NDAnnotation": + subclass_registry[cls.__name__] = cls + + class Config: + extra = "allow" + class DataRow(_CamelCaseMixin): id: str = None @@ -19,7 +31,7 @@ def must_set_one(cls, values): class NDJsonBase(_CamelCaseMixin): uuid: str = None - data_row: DataRow + data_row: Optional[DataRow] = None @pydantic_compat.validator('uuid', pre=True, always=True) def set_id(cls, v): @@ -28,10 +40,18 @@ def set_id(cls, v): def dict(self, *args, **kwargs): """ Pop missing id or missing globalKey from dataRow """ res = super().dict(*args, **kwargs) - if not self.data_row.id: - res['dataRow'].pop('id') - if not self.data_row.global_key: - res['dataRow'].pop('globalKey') + if self.data_row and not self.data_row.id: + if "data_row" in res: + res["data_row"].pop("id") + else: + res['dataRow'].pop('id') + if self.data_row and not self.data_row.global_key: + if "data_row" in res: + res["data_row"].pop("global_key") + else: + res['dataRow'].pop('globalKey') + if not self.data_row: + del res["dataRow"] return res diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index 46b8fc91f..56cc25c4a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -11,6 +11,7 @@ from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData +from labelbox.data.serialization.ndjson.base import SubclassRegistryBase class NDAnswer(ConfidenceMixin, CustomMetricsMixin): @@ -174,7 +175,7 @@ def from_common(cls, prompt_text: PromptText, name: str, # ====== End of subclasses -class NDText(NDAnnotation, NDTextSubclass): +class NDText(NDAnnotation, NDTextSubclass, SubclassRegistryBase): @classmethod def from_common(cls, @@ -198,7 +199,7 @@ def from_common(cls, ) -class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): +class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, SubclassRegistryBase): @classmethod def from_common( @@ -234,7 +235,7 @@ def from_common( confidence=confidence) -class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): +class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, SubclassRegistryBase): @classmethod def from_common( @@ -265,7 +266,7 @@ def from_common( confidence=confidence) -class NDPromptText(NDAnnotation, NDPromptTextSubclass): +class NDPromptText(NDAnnotation, NDPromptTextSubclass, SubclassRegistryBase): @classmethod def from_common( @@ -404,8 +405,6 @@ def from_common( annotation.confidence) -# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, -# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass, NDTextSubclass] @@ -418,7 +417,6 @@ def from_common( NDPromptText.update_forward_refs() NDTextSubclass.update_forward_refs() -# Make sure to keep NDChecklist prior to NDRadio in the list, -# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used + NDClassificationType = Union[NDChecklist, NDRadio, NDText] NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index 2ffeb9727..6092a88b6 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -15,6 +15,7 @@ from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.relationship import RelationshipAnnotation from .label import NDLabel +import copy logger = logging.getLogger(__name__) @@ -33,7 +34,9 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: Returns: LabelGenerator containing the ndjson data. """ - data = NDLabel(**{"annotations": json_data}) + + data = copy.deepcopy(json_data) + data = NDLabel(**{"annotations": data}) res = data.to_common() return res @@ -106,6 +109,7 @@ def serialize( if not isinstance(annotation, RelationshipAnnotation): uuid_safe_annotations.append(annotation) label.annotations = uuid_safe_annotations + for annotation in NDLabel.from_common([label]): annotation_uuid = getattr(annotation, "uuid", None) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 9d34c451b..884b06ee1 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -20,10 +20,13 @@ from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText +from .classification import NDClassification, NDClassificationType, NDPromptClassification, NDPromptClassificationType, NDPromptText, NDChecklistSubclass, NDRadioSubclass from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks from .relationship import NDRelationship from .base import DataRow +from labelbox.utils import camel_case +from labelbox.data.serialization.ndjson.base import SubclassRegistryBase, subclass_registry +from contextlib import suppress AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, @@ -32,8 +35,51 @@ class NDLabel(pydantic_compat.BaseModel): - annotations: List[AnnotationType] + annotations: List[SubclassRegistryBase] + + def __init__(self, **kwargs): + # NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 + # Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields + # we essentially have to infer the type against all sub classes that have the SubclasssRegistryBase inheritance. + # It works by checking if the keys of our annotations any required keys inside subclasses. + # More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows + # depending on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error) + # Previous strategies hacked but dont work for pydantic V2 they also make this part of the code less complicated prior solutions depended on order + # of how classes were shown on Python file to work. This should open the door to cut out a lot of the library specifically some subclasses. + + for index in range(len(kwargs["annotations"])): + annotation = kwargs["annotations"][index] + if isinstance(annotation, dict): + item_annotation_keys = annotation.keys() + key_subclass_combos = defaultdict(list) + for subclass in subclass_registry.values(): + subclass = subclass + # Get all required keys from subclass + annotation_keys = [] + for k, field in subclass.__fields__.items(): + # must account for alias + if hasattr(field, "alias") and field.alias == "answers" and "answers" in item_annotation_keys: + annotation_keys.append("answers") + elif field.required is True and k != "uuid": + annotation_keys.append(camel_case(k)) + key_subclass_combos[subclass].extend(annotation_keys) + # Sort by subclass that has the most keys i.e. the one with the most keys if a match is likely our class + key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True)) + + # Choose the keys from our dict we supplied that matches the required keys of a subclass + for subclass, key_subclass_combo in key_subclass_combos.items(): + check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo) + if check_required_keys: + # Keep trying subclasses until we find one that has valid values + with suppress(pydantic_compat.ValidationError): + annotation = subclass(**annotation) + break + if isinstance(annotation, dict): + raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}") + kwargs["annotations"][index] = annotation + super().__init__(**kwargs) + class _Relationship(pydantic_compat.BaseModel): """This object holds information about the relationship""" ndjson: NDRelationship diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py index 5abbf2761..e01de247e 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py @@ -1,7 +1,7 @@ from typing import Optional, Union, Type from labelbox.data.annotation_types.data import ImageData, TextData -from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase +from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase, SubclassRegistryBase from labelbox.data.annotation_types.metrics.scalar import ( ScalarMetric, ScalarMetricAggregation, ScalarMetricValue, ScalarMetricConfidenceValue) @@ -26,7 +26,7 @@ def dict(self, *args, **kwargs): return res -class NDConfusionMatrixMetric(BaseNDMetric): +class NDConfusionMatrixMetric(BaseNDMetric, SubclassRegistryBase): metric_value: Union[ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue] metric_name: str @@ -53,10 +53,10 @@ def from_common( data_row=DataRow(id=data.uid, global_key=data.global_key)) -class NDScalarMetric(BaseNDMetric): +class NDScalarMetric(BaseNDMetric, SubclassRegistryBase): metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] metric_name: Optional[str] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN def to_common(self) -> ScalarMetric: return ScalarMetric(value=self.metric_value, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index fd13a6bf6..7ddbe47af 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -19,8 +19,10 @@ from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance -from .classification import NDClassification, NDSubclassification, NDSubclassificationType -from .base import DataRow, NDAnnotation, NDJsonBase +from .classification import NDSubclassification, NDSubclassificationType +from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation, NDJsonBase, SubclassRegistryBase + + class NDBaseObject(NDAnnotation): @@ -48,7 +50,7 @@ class Bbox(pydantic_compat.BaseModel): width: float -class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): point: _Point def to_common(self) -> Point: @@ -109,7 +111,7 @@ def from_common( classifications=classifications) -class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): line: List[_Point] def to_common(self) -> Line: @@ -187,7 +189,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, group_key=group_key) -class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): polygon: List[_Point] def to_common(self) -> Polygon: @@ -218,7 +220,7 @@ def from_common( custom_metrics=custom_metrics) -class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): bbox: Bbox def to_common(self) -> Rectangle: @@ -254,7 +256,7 @@ def from_common( custom_metrics=custom_metrics) -class NDDocumentRectangle(NDRectangle): +class NDDocumentRectangle(NDRectangle, SubclassRegistryBase): page: int unit: str @@ -362,7 +364,6 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, @classmethod def from_common(cls, segment): nd_frame_object_type = cls.lookup_segment_object_type(segment) - return cls(keyframes=[ nd_frame_object_type.from_common( object_annotation.frame, object_annotation.value, [ @@ -398,7 +399,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, ] -class NDSegments(NDBaseObject): +class NDSegments(NDBaseObject, SubclassRegistryBase): segments: List[NDSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -425,7 +426,7 @@ def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData, uuid=extra.get('uuid')) -class NDDicomSegments(NDBaseObject, DicomSupported): +class NDDicomSegments(NDBaseObject, DicomSupported, SubclassRegistryBase): segments: List[NDDicomSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -463,7 +464,7 @@ class _PNGMask(pydantic_compat.BaseModel): png: str -class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): mask: Union[_URIMask, _PNGMask] def to_common(self) -> Mask: @@ -517,7 +518,7 @@ class NDVideoMasksFramesInstances(pydantic_compat.BaseModel): instances: List[MaskInstance] -class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin): +class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, SubclassRegistryBase): masks: NDVideoMasksFramesInstances def to_common(self) -> VideoMaskAnnotation: @@ -545,7 +546,7 @@ def from_common(cls, annotation, data): ) -class NDDicomMasks(NDVideoMasks, DicomSupported): +class NDDicomMasks(NDVideoMasks, DicomSupported, SubclassRegistryBase): def to_common(self) -> DICOMMaskAnnotation: return DICOMMaskAnnotation( @@ -569,7 +570,7 @@ class Location(pydantic_compat.BaseModel): end: int -class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): location: Location def to_common(self) -> TextEntity: @@ -601,7 +602,7 @@ def from_common( custom_metrics=custom_metrics) -class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): +class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase): name: str text_selections: List[DocumentTextSelection] @@ -633,7 +634,7 @@ def from_common( custom_metrics=custom_metrics) -class NDConversationEntity(NDTextEntity): +class NDConversationEntity(NDTextEntity, SubclassRegistryBase): message_id: str def to_common(self) -> ConversationEntity: @@ -773,12 +774,9 @@ def lookup_object( return result -# NOTE: Deserialization of subclasses in pydantic is a known PIA, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 -# I could implement the registry approach suggested there, but I found that if I list subclass (that has more attributes) before the parent class, it works -# This is a bit of a hack, but it works for now NDEntityType = Union[NDConversationEntity, NDTextEntity] NDObjectType = Union[NDLine, NDPolygon, NDPoint, NDDocumentRectangle, NDRectangle, NDMask, NDEntityType, NDDocumentEntity] -NDFrameObjectType = NDFrameRectangle, NDFramePoint, NDFrameLine +NDFrameObjectType = Union[NDFrameRectangle, NDFramePoint, NDFrameLine] NDDicomObjectType = NDDicomLine diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py index 82976aedb..f42cc3b7a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py @@ -5,7 +5,7 @@ from ...annotation_types.relationship import RelationshipAnnotation from ...annotation_types.relationship import Relationship from .objects import NDObjectType -from .base import DataRow +from labelbox.data.serialization.ndjson.base import DataRow, SubclassRegistryBase SUPPORTED_ANNOTATIONS = NDObjectType @@ -16,7 +16,7 @@ class _Relationship(pydantic_compat.BaseModel): type: str -class NDRelationship(NDAnnotation): +class NDRelationship(NDAnnotation, SubclassRegistryBase): relationship: _Relationship @staticmethod