1
1
from typing import Any , Dict , List , Union , Optional
2
2
3
+ from labelbox .data .annotation_types import ImageData , TextData , VideoData
3
4
from labelbox .data .mixins import ConfidenceMixin , CustomMetric , CustomMetricsMixin
4
5
from labelbox .data .serialization .ndjson .base import DataRow , NDAnnotation
5
6
7
+ from ....annotated_types import Cuid
8
+
6
9
from ...annotation_types .annotation import ClassificationAnnotation
7
10
from ...annotation_types .video import VideoClassificationAnnotation
8
11
from ...annotation_types .llm_prompt_response .prompt import PromptClassificationAnnotation , PromptText
9
12
from ...annotation_types .classification .classification import ClassificationAnswer , Text , Checklist , Radio
10
- from ...annotation_types .types import Cuid
11
- from ...annotation_types .data import TextData , VideoData , ImageData
12
13
from pydantic import model_validator , Field , BaseModel , ConfigDict , model_serializer
13
14
from pydantic .alias_generators import to_camel
14
15
from .base import _SubclassRegistryBase
@@ -18,11 +19,12 @@ class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
18
19
name : Optional [str ] = None
19
20
schema_id : Optional [Cuid ] = None
20
21
classifications : Optional [List ['NDSubclassificationType' ]] = None
21
- model_config = ConfigDict (populate_by_name = True , alias_generator = to_camel )
22
+ model_config = ConfigDict (populate_by_name = True , alias_generator = to_camel )
22
23
23
24
@model_validator (mode = "after" )
24
25
def must_set_one (self ):
25
- if (not hasattr (self , "schema_id" ) or self .schema_id is None ) and (not hasattr (self , "name" ) or self .name is None ):
26
+ if (not hasattr (self , "schema_id" ) or self .schema_id
27
+ is None ) and (not hasattr (self , "name" ) or self .name is None ):
26
28
raise ValueError ("Schema id or name are not set. Set either one." )
27
29
return self
28
30
@@ -102,7 +104,10 @@ def from_common(cls, checklist: Checklist, name: str,
102
104
NDAnswer (name = answer .name ,
103
105
schema_id = answer .feature_schema_id ,
104
106
confidence = answer .confidence ,
105
- classifications = [NDSubclassification .from_common (annot ) for annot in answer .classifications ] if answer .classifications else None ,
107
+ classifications = [
108
+ NDSubclassification .from_common (annot )
109
+ for annot in answer .classifications
110
+ ] if answer .classifications else None ,
106
111
custom_metrics = answer .custom_metrics )
107
112
for answer in checklist .answer
108
113
],
@@ -152,8 +157,8 @@ class NDPromptTextSubclass(NDAnswer):
152
157
153
158
def to_common (self ) -> PromptText :
154
159
return PromptText (answer = self .answer ,
155
- confidence = self .confidence ,
156
- custom_metrics = self .custom_metrics )
160
+ confidence = self .confidence ,
161
+ custom_metrics = self .custom_metrics )
157
162
158
163
@classmethod
159
164
def from_common (cls , prompt_text : PromptText , name : str ,
@@ -194,7 +199,8 @@ def from_common(cls,
194
199
)
195
200
196
201
197
- class NDChecklist (NDAnnotation , NDChecklistSubclass , VideoSupported , _SubclassRegistryBase ):
202
+ class NDChecklist (NDAnnotation , NDChecklistSubclass , VideoSupported ,
203
+ _SubclassRegistryBase ):
198
204
199
205
@model_serializer (mode = "wrap" )
200
206
def serialize_model (self , handler ):
@@ -237,7 +243,8 @@ def from_common(
237
243
confidence = confidence )
238
244
239
245
240
- class NDRadio (NDAnnotation , NDRadioSubclass , VideoSupported , _SubclassRegistryBase ):
246
+ class NDRadio (NDAnnotation , NDRadioSubclass , VideoSupported ,
247
+ _SubclassRegistryBase ):
241
248
242
249
@classmethod
243
250
def from_common (
@@ -266,35 +273,32 @@ def from_common(
266
273
frames = extra .get ('frames' ),
267
274
message_id = message_id ,
268
275
confidence = confidence )
269
-
276
+
270
277
@model_serializer (mode = "wrap" )
271
278
def serialize_model (self , handler ):
272
279
res = handler (self )
273
280
if "classifications" in res and res ["classifications" ] == []:
274
281
del res ["classifications" ]
275
282
return res
276
-
277
-
283
+
284
+
278
285
class NDPromptText (NDAnnotation , NDPromptTextSubclass , _SubclassRegistryBase ):
279
-
286
+
280
287
@classmethod
281
- def from_common (
282
- cls ,
283
- uuid : str ,
284
- text : PromptText ,
285
- name ,
286
- data : Dict ,
287
- feature_schema_id : Cuid ,
288
- confidence : Optional [float ] = None
289
- ) -> "NDPromptText" :
290
- return cls (
291
- answer = text .answer ,
292
- data_row = DataRow (id = data .uid , global_key = data .global_key ),
293
- name = name ,
294
- schema_id = feature_schema_id ,
295
- uuid = uuid ,
296
- confidence = text .confidence ,
297
- custom_metrics = text .custom_metrics )
288
+ def from_common (cls ,
289
+ uuid : str ,
290
+ text : PromptText ,
291
+ name ,
292
+ data : Dict ,
293
+ feature_schema_id : Cuid ,
294
+ confidence : Optional [float ] = None ) -> "NDPromptText" :
295
+ return cls (answer = text .answer ,
296
+ data_row = DataRow (id = data .uid , global_key = data .global_key ),
297
+ name = name ,
298
+ schema_id = feature_schema_id ,
299
+ uuid = uuid ,
300
+ confidence = text .confidence ,
301
+ custom_metrics = text .custom_metrics )
298
302
299
303
300
304
class NDSubclassification :
@@ -350,7 +354,8 @@ def to_common(
350
354
for frame in annotation .frames :
351
355
for idx in range (frame .start , frame .end + 1 , 1 ):
352
356
results .append (
353
- VideoClassificationAnnotation (frame = idx , ** common .model_dump (exclude_none = True )))
357
+ VideoClassificationAnnotation (
358
+ frame = idx , ** common .model_dump (exclude_none = True )))
354
359
return results
355
360
356
361
@classmethod
@@ -382,6 +387,7 @@ def lookup_classification(
382
387
Radio : NDRadio
383
388
}.get (type (annotation .value ))
384
389
390
+
385
391
class NDPromptClassification :
386
392
387
393
@staticmethod
@@ -404,8 +410,7 @@ def from_common(
404
410
data : Union [VideoData , TextData , ImageData ]
405
411
) -> Union [NDTextSubclass , NDChecklistSubclass , NDRadioSubclass ]:
406
412
return NDPromptText .from_common (str (annotation ._uuid ), annotation .value ,
407
- annotation .name ,
408
- data ,
413
+ annotation .name , data ,
409
414
annotation .feature_schema_id ,
410
415
annotation .confidence )
411
416
@@ -427,4 +432,4 @@ def from_common(
427
432
# Make sure to keep NDChecklist prior to NDRadio in the list,
428
433
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
429
434
NDClassificationType = Union [NDChecklist , NDRadio , NDText ]
430
- NDPromptClassificationType = Union [NDPromptText ]
435
+ NDPromptClassificationType = Union [NDPromptText ]
0 commit comments