Skip to content

[PLT-600] Remove hacks and simplified annotation import library #1759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from .base import ConfidenceValue, BaseMetric

from labelbox import pydantic_compat
from typing_extensions import Annotated

ScalarMetricValue = pydantic_compat.confloat(ge=0, le=100_000_000)
ScalarMetricValue = Annotated[float, pydantic_compat.confloat(ge=0, le=100_000_000)]
ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue]


Expand All @@ -27,11 +28,11 @@ class ScalarMetric(BaseMetric):
For backwards compatibility, metric_name is optional.
The metric_name will be set to a default name in the editor if it is not set.
This is not recommended and support for empty metric_name fields will be removed.
aggregation will be ignored wihtout providing a metric name.
aggregation will be ignored without providing a metric name.
"""
metric_name: Optional[str] = None
value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN

@pydantic_compat.validator('metric_name')
def validate_metric_name(cls, name: Union[str, None]):
Expand Down
30 changes: 25 additions & 5 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
from labelbox import pydantic_compat
from ...annotation_types.types import Cuid

subclass_registry = {}

class SubclassRegistryBase(pydantic_compat.BaseModel):

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__name__ != "NDAnnotation":
subclass_registry[cls.__name__] = cls

class Config:
extra = "allow"


class DataRow(_CamelCaseMixin):
id: str = None
Expand All @@ -19,7 +31,7 @@ def must_set_one(cls, values):

class NDJsonBase(_CamelCaseMixin):
uuid: str = None
data_row: DataRow
data_row: Optional[DataRow] = None

@pydantic_compat.validator('uuid', pre=True, always=True)
def set_id(cls, v):
Expand All @@ -28,10 +40,18 @@ def set_id(cls, v):
def dict(self, *args, **kwargs):
""" Pop missing id or missing globalKey from dataRow """
res = super().dict(*args, **kwargs)
if not self.data_row.id:
res['dataRow'].pop('id')
if not self.data_row.global_key:
res['dataRow'].pop('globalKey')
if self.data_row and not self.data_row.id:
if "data_row" in res:
res["data_row"].pop("id")
else:
res['dataRow'].pop('id')
if self.data_row and not self.data_row.global_key:
if "data_row" in res:
res["data_row"].pop("global_key")
else:
res['dataRow'].pop('globalKey')
if not self.data_row:
del res["dataRow"]
return res


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio
from ...annotation_types.types import Cuid
from ...annotation_types.data import TextData, VideoData, ImageData
from labelbox.data.serialization.ndjson.base import SubclassRegistryBase


class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
Expand Down Expand Up @@ -174,7 +175,7 @@ def from_common(cls, prompt_text: PromptText, name: str,
# ====== End of subclasses


class NDText(NDAnnotation, NDTextSubclass):
class NDText(NDAnnotation, NDTextSubclass, SubclassRegistryBase):

@classmethod
def from_common(cls,
Expand All @@ -198,7 +199,7 @@ def from_common(cls,
)


class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, SubclassRegistryBase):

@classmethod
def from_common(
Expand Down Expand Up @@ -234,7 +235,7 @@ def from_common(
confidence=confidence)


class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, SubclassRegistryBase):

@classmethod
def from_common(
Expand Down Expand Up @@ -265,7 +266,7 @@ def from_common(
confidence=confidence)


class NDPromptText(NDAnnotation, NDPromptTextSubclass):
class NDPromptText(NDAnnotation, NDPromptTextSubclass, SubclassRegistryBase):

@classmethod
def from_common(
Expand Down Expand Up @@ -404,8 +405,6 @@ def from_common(
annotation.confidence)


# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list,
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass,
NDTextSubclass]

Expand All @@ -418,7 +417,6 @@ def from_common(
NDPromptText.update_forward_refs()
NDTextSubclass.update_forward_refs()

# Make sure to keep NDChecklist prior to NDRadio in the list,
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used

NDClassificationType = Union[NDChecklist, NDRadio, NDText]
NDPromptClassificationType = Union[NDPromptText]
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ...annotation_types.collection import LabelCollection, LabelGenerator
from ...annotation_types.relationship import RelationshipAnnotation
from .label import NDLabel
import copy

logger = logging.getLogger(__name__)

Expand All @@ -33,7 +34,9 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
Returns:
LabelGenerator containing the ndjson data.
"""
data = NDLabel(**{"annotations": json_data})

data = copy.deepcopy(json_data)
data = NDLabel(**{"annotations": data})
res = data.to_common()
return res

Expand Down Expand Up @@ -106,6 +109,7 @@ def serialize(
if not isinstance(annotation, RelationshipAnnotation):
uuid_safe_annotations.append(annotation)
label.annotations = uuid_safe_annotations

for annotation in NDLabel.from_common([label]):
annotation_uuid = getattr(annotation, "uuid", None)

Expand Down
50 changes: 48 additions & 2 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation

from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric
from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText
from .classification import NDClassification, NDClassificationType, NDPromptClassification, NDPromptClassificationType, NDPromptText, NDChecklistSubclass, NDRadioSubclass
from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
from .relationship import NDRelationship
from .base import DataRow
from labelbox.utils import camel_case
from labelbox.data.serialization.ndjson.base import SubclassRegistryBase, subclass_registry
from contextlib import suppress

AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments,
Expand All @@ -32,8 +35,51 @@


class NDLabel(pydantic_compat.BaseModel):
annotations: List[AnnotationType]
annotations: List[SubclassRegistryBase]

def __init__(self, **kwargs):
# NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
# Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields
# we essentially have to infer the type against all sub classes that have the SubclasssRegistryBase inheritance.
# It works by checking if the keys of our annotations any required keys inside subclasses.
# More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows
# depending on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error)
# Previous strategies hacked but dont work for pydantic V2 they also make this part of the code less complicated prior solutions depended on order
# of how classes were shown on Python file to work. This should open the door to cut out a lot of the library specifically some subclasses.

for index in range(len(kwargs["annotations"])):
annotation = kwargs["annotations"][index]
if isinstance(annotation, dict):
item_annotation_keys = annotation.keys()
key_subclass_combos = defaultdict(list)
for subclass in subclass_registry.values():
subclass = subclass

# Get all required keys from subclass
annotation_keys = []
for k, field in subclass.__fields__.items():
# must account for alias
if hasattr(field, "alias") and field.alias == "answers" and "answers" in item_annotation_keys:
annotation_keys.append("answers")
elif field.required is True and k != "uuid":
annotation_keys.append(camel_case(k))
key_subclass_combos[subclass].extend(annotation_keys)
# Sort by subclass that has the most keys i.e. the one with the most keys if a match is likely our class
key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True))

# Choose the keys from our dict we supplied that matches the required keys of a subclass
for subclass, key_subclass_combo in key_subclass_combos.items():
check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo)
if check_required_keys:
# Keep trying subclasses until we find one that has valid values
with suppress(pydantic_compat.ValidationError):
annotation = subclass(**annotation)
break
if isinstance(annotation, dict):
raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}")
kwargs["annotations"][index] = annotation
super().__init__(**kwargs)

class _Relationship(pydantic_compat.BaseModel):
"""This object holds information about the relationship"""
ndjson: NDRelationship
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Union, Type

from labelbox.data.annotation_types.data import ImageData, TextData
from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase
from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase, SubclassRegistryBase
from labelbox.data.annotation_types.metrics.scalar import (
ScalarMetric, ScalarMetricAggregation, ScalarMetricValue,
ScalarMetricConfidenceValue)
Expand All @@ -26,7 +26,7 @@ def dict(self, *args, **kwargs):
return res


class NDConfusionMatrixMetric(BaseNDMetric):
class NDConfusionMatrixMetric(BaseNDMetric, SubclassRegistryBase):
metric_value: Union[ConfusionMatrixMetricValue,
ConfusionMatrixMetricConfidenceValue]
metric_name: str
Expand All @@ -53,10 +53,10 @@ def from_common(
data_row=DataRow(id=data.uid, global_key=data.global_key))


class NDScalarMetric(BaseNDMetric):
class NDScalarMetric(BaseNDMetric, SubclassRegistryBase):
metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
metric_name: Optional[str]
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN

def to_common(self) -> ScalarMetric:
return ScalarMetric(value=self.metric_value,
Expand Down
Loading
Loading