Skip to content

[PLT-1463] Remove deserialize completely #1818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lbox-develop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: LBox Develop

on:
push:
branches: [develop]
branches: [develop, v6]
Copy link
Collaborator Author

@Gabefire Gabefire Sep 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add our V6 branch here just so tests can run we can remove once V6 is merged onto develop

pull_request:
branches: [develop]
branches: [develop, v6]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-package-develop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: Labelbox Python SDK Staging (Develop)

on:
push:
branches: [develop]
branches: [develop, v6]
pull_request:
branches: [develop]
branches: [develop, v6]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand Down
12 changes: 0 additions & 12 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,6 @@

from ....annotated_types import Cuid

subclass_registry = {}


class _SubclassRegistryBase(BaseModel):
model_config = ConfigDict(extra="allow")

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.__name__ != "NDAnnotation":
with threading.Lock():
subclass_registry[cls.__name__] = cls


class DataRow(_CamelCaseMixin):
id: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
model_serializer,
)
from pydantic.alias_generators import to_camel
from .base import _SubclassRegistryBase


class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
Expand Down Expand Up @@ -224,7 +223,7 @@ def from_common(
# ====== End of subclasses


class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase):
class NDText(NDAnnotation, NDTextSubclass):
@classmethod
def from_common(
cls,
Expand All @@ -249,9 +248,7 @@ def from_common(
)


class NDChecklist(
NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase
):
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
@model_serializer(mode="wrap")
def serialize_model(self, handler):
res = handler(self)
Expand Down Expand Up @@ -298,9 +295,7 @@ def from_common(
)


class NDRadio(
NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase
):
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
@classmethod
def from_common(
cls,
Expand Down Expand Up @@ -343,7 +338,7 @@ def serialize_model(self, handler):
return res


class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase):
class NDPromptText(NDAnnotation, NDPromptTextSubclass):
@classmethod
def from_common(
cls,
Expand Down
14 changes: 0 additions & 14 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,6 @@


class NDJsonConverter:
@staticmethod
def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
"""
Converts ndjson data (prediction import format) into the common labelbox format.

Args:
json_data: An iterable representing the ndjson data
Returns:
LabelGenerator containing the ndjson data.
"""
data = NDLabel(**{"annotations": copy.copy(json_data)})
res = data.to_common()
return res

@staticmethod
def serialize(
labels: LabelCollection,
Expand Down
64 changes: 1 addition & 63 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from .relationship import NDRelationship
from .base import DataRow
from pydantic import BaseModel, ValidationError
from .base import subclass_registry, _SubclassRegistryBase
from pydantic_core import PydanticUndefined
from contextlib import suppress

Expand All @@ -67,68 +66,7 @@


class NDLabel(BaseModel):
annotations: List[_SubclassRegistryBase]

def __init__(self, **kwargs):
# NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
# Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields
# we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance.
# It works by checking if the keys of our annotations we are missing in matches any required subclass.
# More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows
# depend on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error)

for index, annotation in enumerate(kwargs["annotations"]):
if isinstance(annotation, dict):
item_annotation_keys = annotation.keys()
key_subclass_combos = defaultdict(list)
for subclass in subclass_registry.values():
# Get all required keys from subclass
annotation_keys = []
for k, field in subclass.model_fields.items():
if field.default == PydanticUndefined and k != "uuid":
if (
hasattr(field, "alias")
and field.alias in item_annotation_keys
):
annotation_keys.append(field.alias)
elif (
hasattr(field, "validation_alias")
and field.validation_alias
in item_annotation_keys
):
annotation_keys.append(field.validation_alias)
else:
annotation_keys.append(k)

key_subclass_combos[subclass].extend(annotation_keys)

# Sort by subclass that has the most keys i.e. the one with the most keys that matches is most likely our subclass
key_subclass_combos = dict(
sorted(
key_subclass_combos.items(),
key=lambda x: len(x[1]),
reverse=True,
)
)

for subclass, key_subclass_combo in key_subclass_combos.items():
# Choose the keys from our dict we supplied that matches the required keys of a subclass
check_required_keys = all(
key in list(item_annotation_keys)
for key in key_subclass_combo
)
if check_required_keys:
# Keep trying subclasses until we find one that has valid values (does not throw an validation error)
with suppress(ValidationError):
annotation = subclass(**annotation)
break
if isinstance(annotation, dict):
raise ValueError(
f"Could not find subclass for fields: {item_annotation_keys}"
)

kwargs["annotations"][index] = annotation
super().__init__(**kwargs)
annotations: AnnotationType

class _Relationship(BaseModel):
"""This object holds information about the relationship"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ConfusionMatrixMetricConfidenceValue,
)
from pydantic import ConfigDict, model_serializer
from .base import _SubclassRegistryBase


class BaseNDMetric(NDJsonBase):
Expand All @@ -33,7 +32,7 @@ def serialize_model(self, handler):
return res


class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase):
class NDConfusionMatrixMetric(BaseNDMetric):
metric_value: Union[
ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue
]
Expand Down Expand Up @@ -65,7 +64,7 @@ def from_common(
)


class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase):
class NDScalarMetric(BaseNDMetric):
metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
metric_name: Optional[str] = None
aggregation: Optional[ScalarMetricAggregation] = (
Expand Down
4 changes: 2 additions & 2 deletions libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from labelbox.utils import _CamelCaseMixin

from .base import _SubclassRegistryBase, DataRow, NDAnnotation
from .base import DataRow, NDAnnotation
from ...annotation_types.mmc import (
MessageSingleSelectionTask,
MessageMultiSelectionTask,
Expand All @@ -20,7 +20,7 @@ class MessageTaskData(_CamelCaseMixin):
]


class NDMessageTask(NDAnnotation, _SubclassRegistryBase):
class NDMessageTask(NDAnnotation):
message_evaluation_task: MessageTaskData

def to_common(self) -> MessageEvaluationTaskAnnotation:
Expand Down
Loading
Loading