diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index fe45bed86..55d6b5e62 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -426,6 +426,42 @@ def from_common(cls, segment): ) +class NDSegments(NDBaseObject): + segments: List[NDSegment] + + def to_common(self, name: str, feature_schema_id: Cuid): + result = [] + for idx, segment in enumerate(self.segments): + result.extend( + segment.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=idx, + uuid=self.uuid, + ) + ) + return result + + @classmethod + def from_common( + cls, + segments: List[VideoObjectAnnotation], + data: GenericDataRowData, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + ) -> "NDSegments": + segments = [NDSegment.from_common(segment) for segment in segments] + + return cls( + segments=segments, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=extra.get("uuid"), + ) + + class _URIMask(BaseModel): instanceURI: str colorRGB: Tuple[int, int, int] @@ -693,7 +729,18 @@ def from_common( obj = cls.lookup_object(annotation) # if it is video segments - if obj == NDVideoMasks: + if obj == NDSegments: + first_video_annotation = annotation[0][0] + args = dict( + segments=annotation, + data=data, + name=first_video_annotation.name, + feature_schema_id=first_video_annotation.feature_schema_id, + extra=first_video_annotation.extra, + ) + + return obj.from_common(**args) + elif obj == NDVideoMasks: return obj.from_common(annotation, data) subclasses = [