|
1 | 1 | from collections import defaultdict
|
| 2 | +import copy |
2 | 3 | from itertools import groupby
|
3 | 4 | from operator import itemgetter
|
4 |
| -from typing import Dict, Generator, List, Tuple, Union |
| 5 | +from typing import Generator, List, Tuple, Union |
| 6 | +from uuid import uuid4 |
5 | 7 |
|
6 | 8 | from pydantic import BaseModel
|
7 | 9 |
|
8 | 10 | from ...annotation_types.annotation import (
|
9 | 11 | ClassificationAnnotation,
|
10 | 12 | ObjectAnnotation,
|
11 | 13 | )
|
12 |
| -from ...annotation_types.collection import LabelCollection, LabelGenerator |
13 |
| -from ...annotation_types.data.generic_data_row_data import GenericDataRowData |
| 14 | +from ...annotation_types.collection import LabelCollection |
14 | 15 | from ...annotation_types.label import Label
|
15 | 16 | from ...annotation_types.llm_prompt_response.prompt import (
|
16 | 17 | PromptClassificationAnnotation,
|
|
23 | 24 | VideoMaskAnnotation,
|
24 | 25 | VideoObjectAnnotation,
|
25 | 26 | )
|
26 |
| -from .base import DataRow |
27 | 27 | from .classification import (
|
28 | 28 | NDChecklistSubclass,
|
29 | 29 | NDClassification,
|
|
60 | 60 | class NDLabel(BaseModel):
|
61 | 61 | annotations: AnnotationType
|
62 | 62 |
|
63 |
| - class _Relationship(BaseModel): |
64 |
| - """This object holds information about the relationship""" |
65 |
| - |
66 |
| - ndjson: NDRelationship |
67 |
| - source: str |
68 |
| - target: str |
69 |
| - |
70 |
| - class _AnnotationGroup(BaseModel): |
71 |
| - """Stores all the annotations and relationships per datarow""" |
72 |
| - |
73 |
| - data_row: DataRow = None |
74 |
| - ndjson_annotations: Dict[str, AnnotationType] = {} |
75 |
| - relationships: List["NDLabel._Relationship"] = [] |
76 |
| - |
77 |
| - def to_common(self) -> LabelGenerator: |
78 |
| - annotation_groups = defaultdict(NDLabel._AnnotationGroup) |
79 |
| - |
80 |
| - for ndjson_annotation in self.annotations: |
81 |
| - key = ( |
82 |
| - ndjson_annotation.data_row.id |
83 |
| - or ndjson_annotation.data_row.global_key |
84 |
| - ) |
85 |
| - group = annotation_groups[key] |
86 |
| - |
87 |
| - if isinstance(ndjson_annotation, NDRelationship): |
88 |
| - group.relationships.append( |
89 |
| - NDLabel._Relationship( |
90 |
| - ndjson=ndjson_annotation, |
91 |
| - source=ndjson_annotation.relationship.source, |
92 |
| - target=ndjson_annotation.relationship.target, |
93 |
| - ) |
94 |
| - ) |
95 |
| - else: |
96 |
| - # if this is the first object in this group, we |
97 |
| - # take note of the DataRow this group belongs to |
98 |
| - # and store it in the _AnnotationGroupTuple |
99 |
| - if not group.ndjson_annotations: |
100 |
| - group.data_row = ndjson_annotation.data_row |
101 |
| - |
102 |
| - # if this assertion fails and it's a valid case, |
103 |
| - # we need to change the value type of |
104 |
| - # `_AnnotationGroupTuple.ndjson_objects` to accept a list of objects |
105 |
| - # and adapt the code to support duplicate UUIDs |
106 |
| - assert ( |
107 |
| - ndjson_annotation.uuid not in group.ndjson_annotations |
108 |
| - ), f"UUID '{ndjson_annotation.uuid}' is not unique" |
109 |
| - |
110 |
| - group.ndjson_annotations[ndjson_annotation.uuid] = ( |
111 |
| - ndjson_annotation |
112 |
| - ) |
113 |
| - |
114 |
| - return LabelGenerator( |
115 |
| - data=self._generate_annotations(annotation_groups) |
116 |
| - ) |
117 |
| - |
118 | 63 | @classmethod
|
119 | 64 | def from_common(
|
120 | 65 | cls, data: LabelCollection
|
121 | 66 | ) -> Generator["NDLabel", None, None]:
|
122 | 67 | for label in data:
|
| 68 | + if all( |
| 69 | + isinstance(model, RelationshipAnnotation) |
| 70 | + for model in label.annotations |
| 71 | + ): |
| 72 | + yield from cls._create_relationship_annotations(label) |
123 | 73 | yield from cls._create_non_video_annotations(label)
|
124 | 74 | yield from cls._create_video_annotations(label)
|
125 | 75 |
|
126 |
| - def _generate_annotations( |
127 |
| - self, annotation_groups: Dict[str, _AnnotationGroup] |
128 |
| - ) -> Generator[Label, None, None]: |
129 |
| - for _, group in annotation_groups.items(): |
130 |
| - relationship_annotations: Dict[str, ObjectAnnotation] = {} |
131 |
| - annotations = [] |
132 |
| - # first, we iterate through all the NDJSON objects and store the |
133 |
| - # deserialized objects in the _AnnotationGroupTuple |
134 |
| - # object *if* the object can be used in a relationship |
135 |
| - for uuid, ndjson_annotation in group.ndjson_annotations.items(): |
136 |
| - if isinstance(ndjson_annotation, NDSegments): |
137 |
| - annotations.extend( |
138 |
| - NDSegments.to_common( |
139 |
| - ndjson_annotation, |
140 |
| - ndjson_annotation.name, |
141 |
| - ndjson_annotation.schema_id, |
142 |
| - ) |
143 |
| - ) |
144 |
| - elif isinstance(ndjson_annotation, NDVideoMasks): |
145 |
| - annotations.append( |
146 |
| - NDVideoMasks.to_common(ndjson_annotation) |
147 |
| - ) |
148 |
| - elif isinstance(ndjson_annotation, NDObjectType.__args__): |
149 |
| - annotation = NDObject.to_common(ndjson_annotation) |
150 |
| - annotations.append(annotation) |
151 |
| - relationship_annotations[uuid] = annotation |
152 |
| - elif isinstance( |
153 |
| - ndjson_annotation, NDClassificationType.__args__ |
154 |
| - ): |
155 |
| - annotations.extend( |
156 |
| - NDClassification.to_common(ndjson_annotation) |
157 |
| - ) |
158 |
| - elif isinstance( |
159 |
| - ndjson_annotation, (NDScalarMetric, NDConfusionMatrixMetric) |
160 |
| - ): |
161 |
| - annotations.append( |
162 |
| - NDMetricAnnotation.to_common(ndjson_annotation) |
163 |
| - ) |
164 |
| - elif isinstance(ndjson_annotation, NDPromptClassificationType): |
165 |
| - annotation = NDPromptClassification.to_common( |
166 |
| - ndjson_annotation |
167 |
| - ) |
168 |
| - annotations.append(annotation) |
169 |
| - elif isinstance(ndjson_annotation, NDMessageTask): |
170 |
| - annotations.append(ndjson_annotation.to_common()) |
171 |
| - else: |
172 |
| - raise TypeError( |
173 |
| - f"Unsupported annotation. {type(ndjson_annotation)}" |
174 |
| - ) |
175 |
| - |
176 |
| - # after all the annotations have been discovered, we can now create |
177 |
| - # the relationship objects and use references to the objects |
178 |
| - # involved |
179 |
| - for relationship in group.relationships: |
180 |
| - try: |
181 |
| - source, target = ( |
182 |
| - relationship_annotations[relationship.source], |
183 |
| - relationship_annotations[relationship.target], |
184 |
| - ) |
185 |
| - except KeyError: |
186 |
| - raise ValueError( |
187 |
| - f"Relationship object refers to nonexistent object with UUID '{relationship.source}' and/or '{relationship.target}'" |
188 |
| - ) |
189 |
| - annotations.append( |
190 |
| - NDRelationship.to_common( |
191 |
| - relationship.ndjson, source, target |
192 |
| - ) |
193 |
| - ) |
194 |
| - |
195 |
| - yield Label( |
196 |
| - annotations=annotations, |
197 |
| - data=GenericDataRowData, |
198 |
| - ) |
199 |
| - |
200 | 76 | @staticmethod
|
201 | 77 | def _get_consecutive_frames(
|
202 | 78 | frames_indices: List[int],
|
@@ -317,3 +193,26 @@ def _create_non_video_annotations(cls, label: Label):
|
317 | 193 | raise TypeError(
|
318 | 194 | f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`"
|
319 | 195 | )
|
| 196 | + |
| 197 | + def _create_relationship_annotations(cls, label: Label): |
| 198 | + relationship_annotations = [ |
| 199 | + annotation |
| 200 | + for annotation in label.annotations |
| 201 | + if isinstance(annotation, RelationshipAnnotation) |
| 202 | + ] |
| 203 | + for relationship_annotation in relationship_annotations: |
| 204 | + uuid1 = uuid4() |
| 205 | + uuid2 = uuid4() |
| 206 | + source = copy.copy(relationship_annotation.value.source) |
| 207 | + target = copy.copy(relationship_annotation.value.target) |
| 208 | + if not isinstance(source, ObjectAnnotation) or not isinstance( |
| 209 | + target, ObjectAnnotation |
| 210 | + ): |
| 211 | + raise TypeError( |
| 212 | + f"Unable to create relationship with non ObjectAnnotations. `Source: {type(source)} Target: {type(target)}`" |
| 213 | + ) |
| 214 | + if not source._uuid: |
| 215 | + source._uuid = uuid1 |
| 216 | + if not target._uuid: |
| 217 | + target._uuid = uuid2 |
| 218 | + yield relationship_annotation |
0 commit comments