Skip to content

Commit e169894

Browse files
committed
feedback
1 parent be38463 commit e169894

File tree

9 files changed

+77
-60
lines changed

9 files changed

+77
-60
lines changed

libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
from .base_data import BaseData
44

55

6-
class ConversationData(BaseData):
7-
pass
6+
class ConversationData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["DicomData"] = "ConversationData"

libs/labelbox/src/labelbox/data/serialization/ndjson/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33

44
from labelbox.utils import _CamelCaseMixin, is_exactly_one_set
55
from ...annotation_types.types import Cuid
6-
from pydantic import field_validator, model_validator, model_serializer, ConfigDict, BaseModel, Field
7-
from uuid import UUID, uuid4
6+
from pydantic import model_validator, ConfigDict, BaseModel, Field
7+
from uuid import uuid4
8+
import threading
89

910
subclass_registry = {}
1011

11-
class SubclassRegistryBase(BaseModel):
12+
class _SubclassRegistryBase(BaseModel):
1213

1314
model_config = ConfigDict(extra="allow")
1415

1516
def __init_subclass__(cls, **kwargs):
1617
super().__init_subclass__(**kwargs)
1718
if cls.__name__ != "NDAnnotation":
18-
subclass_registry[cls.__name__] = cls
19+
with threading.Lock():
20+
subclass_registry[cls.__name__] = cls
1921

2022
class DataRow(_CamelCaseMixin):
2123
id: Optional[str] = None

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ...annotation_types.data import TextData, VideoData, ImageData
1212
from pydantic import model_validator, Field, BaseModel, ConfigDict, model_serializer
1313
from pydantic.alias_generators import to_camel
14-
from .base import SubclassRegistryBase
14+
from .base import _SubclassRegistryBase
1515

1616

1717
class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
@@ -170,7 +170,7 @@ def from_common(cls, prompt_text: PromptText, name: str,
170170
# ====== End of subclasses
171171

172172

173-
class NDText(NDAnnotation, NDTextSubclass, SubclassRegistryBase):
173+
class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase):
174174

175175
@classmethod
176176
def from_common(cls,
@@ -194,7 +194,7 @@ def from_common(cls,
194194
)
195195

196196

197-
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, SubclassRegistryBase):
197+
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase):
198198

199199
@model_serializer(mode="wrap")
200200
def serialize_model(self, handler):
@@ -237,7 +237,7 @@ def from_common(
237237
confidence=confidence)
238238

239239

240-
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, SubclassRegistryBase):
240+
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase):
241241

242242
@classmethod
243243
def from_common(
@@ -275,7 +275,7 @@ def serialize_model(self, handler):
275275
return res
276276

277277

278-
class NDPromptText(NDAnnotation, NDPromptTextSubclass, SubclassRegistryBase):
278+
class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase):
279279

280280
@classmethod
281281
def from_common(

libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
3434
Returns:
3535
LabelGenerator containing the ndjson data.
3636
"""
37-
data = copy.copy(json_data)
38-
data = NDLabel(**{"annotations": data})
37+
data = NDLabel(**{"annotations": copy.copy(json_data)})
3938
res = data.to_common()
4039
return res
4140

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from operator import itemgetter
33
from typing import Dict, Generator, List, Tuple, Union
44
from collections import defaultdict
5-
from typing_extensions import Unpack
6-
import warnings
7-
85

96
from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation
107
from ...annotation_types.relationship import RelationshipAnnotation
@@ -23,10 +20,9 @@
2320
from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks
2421
from .relationship import NDRelationship
2522
from .base import DataRow
26-
from pydantic import BaseModel, ConfigDict, model_serializer, ValidationError
27-
from .base import subclass_registry, SubclassRegistryBase, NDAnnotation
23+
from pydantic import BaseModel, ValidationError
24+
from .base import subclass_registry, _SubclassRegistryBase
2825
from pydantic_core import PydanticUndefined
29-
from pydantic.alias_generators import to_camel
3026
from contextlib import suppress
3127

3228
AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType,
@@ -36,48 +32,51 @@
3632

3733

3834
class NDLabel(BaseModel):
39-
annotations: List[SubclassRegistryBase]
35+
annotations: List[_SubclassRegistryBase]
4036

4137
def __init__(self, **kwargs):
4238
# NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
4339
# Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields
44-
# we essentially have to infer the type against all sub classes that have the SubclasssRegistryBase inheritance.
40+
# we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance.
4541
# It works by checking if the keys of our annotations we are missing in matches any required subclass.
4642
# More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows
4743
# depend on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error)
48-
# Previous strategies hacked but dont work for pydantic V2
49-
for index in range(len(kwargs["annotations"])):
50-
annotation = kwargs["annotations"][index]
44+
45+
for index, annotation in enumerate(kwargs["annotations"]):
5146
if isinstance(annotation, dict):
5247
item_annotation_keys = annotation.keys()
5348
key_subclass_combos = defaultdict(list)
54-
for name, subclass in subclass_registry.items():
55-
subclass: BaseModel = subclass
49+
for subclass in subclass_registry.values():
5650

5751
# Get all required keys from subclass
5852
annotation_keys = []
5953
for k, field in subclass.model_fields.items():
60-
# must account for alias
61-
if hasattr(field, "validation_alias") and field.validation_alias == "answers" and "answers" in item_annotation_keys:
62-
annotation_keys.append("answers")
63-
elif field.default == PydanticUndefined and k != "uuid":
64-
annotation_keys.append(to_camel(k))
54+
if field.default == PydanticUndefined and k != "uuid":
55+
if hasattr(field, "alias") and field.alias in item_annotation_keys:
56+
annotation_keys.append(field.alias)
57+
else:
58+
annotation_keys.append(k)
59+
6560
key_subclass_combos[subclass].extend(annotation_keys)
66-
67-
# Sort by subclass that has the most keys
61+
62+
# Sort by subclass that has the most keys i.e. the one with the most keys that matches is most likely our subclass
6863
key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True))
69-
70-
# Choose the key that our dict we are passing in has all the keys to match
64+
7165
for subclass, key_subclass_combo in key_subclass_combos.items():
66+
# Choose the keys from our dict we supplied that matches the required keys of a subclass
7267
check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo)
7368
if check_required_keys:
74-
# Keep trying subclasses until we find one that has valid values
69+
# Keep trying subclasses until we find one that has valid values (does not throw an validation error)
7570
with suppress(ValidationError):
7671
annotation = subclass(**annotation)
7772
break
78-
kwargs["annotations"][index] = annotation
73+
if isinstance(annotation, dict):
74+
raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}")
75+
76+
kwargs["annotations"][index] = annotation
7977
super().__init__(**kwargs)
8078

79+
8180
class _Relationship(BaseModel):
8281
"""This object holds information about the relationship"""
8382
ndjson: NDRelationship

libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
ConfusionMatrixAggregation, ConfusionMatrixMetric,
1010
ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue)
1111
from pydantic import ConfigDict, model_serializer
12-
from .base import SubclassRegistryBase
12+
from .base import _SubclassRegistryBase
1313

1414

1515
class BaseNDMetric(NDJsonBase):
@@ -27,7 +27,7 @@ def serialize_model(self, handler):
2727
return res
2828

2929

30-
class NDConfusionMatrixMetric(BaseNDMetric, SubclassRegistryBase):
30+
class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase):
3131
metric_value: Union[ConfusionMatrixMetricValue,
3232
ConfusionMatrixMetricConfidenceValue]
3333
metric_name: str
@@ -54,7 +54,7 @@ def from_common(
5454
data_row=DataRow(id=data.uid, global_key=data.global_key))
5555

5656

57-
class NDScalarMetric(BaseNDMetric, SubclassRegistryBase):
57+
class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase):
5858
metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
5959
metric_name: Optional[str] = None
6060
aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN

libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation
2020
from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance
2121
from .classification import NDClassification, NDSubclassification, NDSubclassificationType
22-
from .base import DataRow, NDAnnotation, NDJsonBase, SubclassRegistryBase
22+
from .base import DataRow, NDAnnotation, NDJsonBase, _SubclassRegistryBase
2323
from pydantic import BaseModel
2424

2525

@@ -48,7 +48,7 @@ class Bbox(BaseModel):
4848
width: float
4949

5050

51-
class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
51+
class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
5252
point: _Point
5353

5454
def to_common(self) -> Point:
@@ -79,7 +79,7 @@ def from_common(
7979
custom_metrics=custom_metrics)
8080

8181

82-
class NDFramePoint(VideoSupported, SubclassRegistryBase):
82+
class NDFramePoint(VideoSupported, _SubclassRegistryBase):
8383
point: _Point
8484
classifications: List[NDSubclassificationType] = []
8585

@@ -109,7 +109,7 @@ def from_common(
109109
classifications=classifications)
110110

111111

112-
class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
112+
class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
113113
line: List[_Point]
114114

115115
def to_common(self) -> Line:
@@ -140,7 +140,7 @@ def from_common(
140140
custom_metrics=custom_metrics)
141141

142142

143-
class NDFrameLine(VideoSupported, SubclassRegistryBase):
143+
class NDFrameLine(VideoSupported, _SubclassRegistryBase):
144144
line: List[_Point]
145145
classifications: List[NDSubclassificationType] = []
146146

@@ -173,7 +173,7 @@ def from_common(
173173
classifications=classifications)
174174

175175

176-
class NDDicomLine(NDFrameLine, SubclassRegistryBase):
176+
class NDDicomLine(NDFrameLine, _SubclassRegistryBase):
177177

178178
def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int,
179179
group_key: str) -> DICOMObjectAnnotation:
@@ -187,7 +187,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int,
187187
group_key=group_key)
188188

189189

190-
class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
190+
class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
191191
polygon: List[_Point]
192192

193193
def to_common(self) -> Polygon:
@@ -218,7 +218,7 @@ def from_common(
218218
custom_metrics=custom_metrics)
219219

220220

221-
class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
221+
class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
222222
bbox: Bbox
223223

224224
def to_common(self) -> Rectangle:
@@ -254,7 +254,7 @@ def from_common(
254254
custom_metrics=custom_metrics)
255255

256256

257-
class NDDocumentRectangle(NDRectangle, SubclassRegistryBase):
257+
class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase):
258258
page: int
259259
unit: str
260260

@@ -293,7 +293,7 @@ def from_common(
293293
custom_metrics=custom_metrics)
294294

295295

296-
class NDFrameRectangle(VideoSupported, SubclassRegistryBase):
296+
class NDFrameRectangle(VideoSupported, _SubclassRegistryBase):
297297
bbox: Bbox
298298
classifications: List[NDSubclassificationType] = []
299299

@@ -398,7 +398,7 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str,
398398
]
399399

400400

401-
class NDSegments(NDBaseObject, SubclassRegistryBase):
401+
class NDSegments(NDBaseObject, _SubclassRegistryBase):
402402
segments: List[NDSegment]
403403

404404
def to_common(self, name: str, feature_schema_id: Cuid):
@@ -425,7 +425,7 @@ def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData,
425425
uuid=extra.get('uuid'))
426426

427427

428-
class NDDicomSegments(NDBaseObject, DicomSupported, SubclassRegistryBase):
428+
class NDDicomSegments(NDBaseObject, DicomSupported, _SubclassRegistryBase):
429429
segments: List[NDDicomSegment]
430430

431431
def to_common(self, name: str, feature_schema_id: Cuid):
@@ -463,7 +463,7 @@ class _PNGMask(BaseModel):
463463
png: str
464464

465465

466-
class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
466+
class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
467467
mask: Union[_URIMask, _PNGMask]
468468

469469
def to_common(self) -> Mask:
@@ -517,7 +517,7 @@ class NDVideoMasksFramesInstances(BaseModel):
517517
instances: List[MaskInstance]
518518

519519

520-
class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, SubclassRegistryBase):
520+
class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, _SubclassRegistryBase):
521521
masks: NDVideoMasksFramesInstances
522522

523523
def to_common(self) -> VideoMaskAnnotation:
@@ -545,7 +545,7 @@ def from_common(cls, annotation, data):
545545
)
546546

547547

548-
class NDDicomMasks(NDVideoMasks, DicomSupported, SubclassRegistryBase):
548+
class NDDicomMasks(NDVideoMasks, DicomSupported, _SubclassRegistryBase):
549549

550550
def to_common(self) -> DICOMMaskAnnotation:
551551
return DICOMMaskAnnotation(
@@ -569,7 +569,7 @@ class Location(BaseModel):
569569
end: int
570570

571571

572-
class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
572+
class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
573573
location: Location
574574

575575
def to_common(self) -> TextEntity:
@@ -601,7 +601,7 @@ def from_common(
601601
custom_metrics=custom_metrics)
602602

603603

604-
class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, SubclassRegistryBase):
604+
class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase):
605605
name: str
606606
text_selections: List[DocumentTextSelection]
607607

@@ -633,7 +633,7 @@ def from_common(
633633
custom_metrics=custom_metrics)
634634

635635

636-
class NDConversationEntity(NDTextEntity, SubclassRegistryBase):
636+
class NDConversationEntity(NDTextEntity, _SubclassRegistryBase):
637637
message_id: str
638638

639639
def to_common(self) -> ConversationEntity:

libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ...annotation_types.relationship import RelationshipAnnotation
66
from ...annotation_types.relationship import Relationship
77
from .objects import NDObjectType
8-
from .base import DataRow, SubclassRegistryBase
8+
from .base import DataRow, _SubclassRegistryBase
99

1010
SUPPORTED_ANNOTATIONS = NDObjectType
1111

@@ -16,7 +16,7 @@ class _Relationship(BaseModel):
1616
type: str
1717

1818

19-
class NDRelationship(NDAnnotation, SubclassRegistryBase):
19+
class NDRelationship(NDAnnotation, _SubclassRegistryBase):
2020
relationship: _Relationship
2121

2222
@staticmethod

0 commit comments

Comments
 (0)