|
| 1 | +from typing import Any, Dict, List, Union, Optional |
| 2 | + |
| 3 | +from pydantic import BaseModel, validator |
| 4 | + |
| 5 | +from labelbox.utils import camel_case |
| 6 | +from ...annotation_types.annotation import ClassificationAnnotation, VideoClassificationAnnotation |
| 7 | +from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio |
| 8 | +from ...annotation_types.types import Cuid |
| 9 | +from ...annotation_types.data import TextData, VideoData, RasterData |
| 10 | +from .base import NDAnnotation |
| 11 | + |
| 12 | + |
| 13 | +class NDFeature(BaseModel): |
| 14 | + schema_id: Cuid |
| 15 | + |
| 16 | + @validator('schema_id', pre=True, always=True) |
| 17 | + def validate_id(cls, v): |
| 18 | + if v is None: |
| 19 | + raise ValueError( |
| 20 | + "Schema ids are not set. Use `LabelGenerator.assign_schema_ids`, `LabelList.assign_schema_ids`, or `Label.assign_schema_ids`." |
| 21 | + ) |
| 22 | + return v |
| 23 | + |
| 24 | + class Config: |
| 25 | + allow_population_by_field_name = True |
| 26 | + alias_generator = camel_case |
| 27 | + |
| 28 | + |
| 29 | +class FrameLocation(BaseModel): |
| 30 | + end: int |
| 31 | + start: int |
| 32 | + |
| 33 | + |
| 34 | +class VideoSupported(BaseModel): |
| 35 | + #Note that frames are only allowed as top level inferences for video |
| 36 | + frames: Optional[List[FrameLocation]] = None |
| 37 | + |
| 38 | + def dict(self, *args, **kwargs): |
| 39 | + res = super().dict(*args, **kwargs) |
| 40 | + # This means these are no video frames .. |
| 41 | + if self.frames is None: |
| 42 | + res.pop('frames') |
| 43 | + return res |
| 44 | + |
| 45 | + |
| 46 | +class NDTextSubclass(NDFeature): |
| 47 | + answer: str |
| 48 | + |
| 49 | + def to_common(self) -> Text: |
| 50 | + return Text(answer=self.answer) |
| 51 | + |
| 52 | + @classmethod |
| 53 | + def from_common(cls, text: Text, schema_id: Cuid) -> "NDTextSubclass": |
| 54 | + return cls(answer=text.answer, schema_id=schema_id) |
| 55 | + |
| 56 | + |
| 57 | +class NDChecklistSubclass(NDFeature): |
| 58 | + answer: List[NDFeature] |
| 59 | + |
| 60 | + def to_common(self) -> Checklist: |
| 61 | + return Checklist(answer=[ |
| 62 | + ClassificationAnswer(schema_id=answer.schema_id) |
| 63 | + for answer in self.answer |
| 64 | + ]) |
| 65 | + |
| 66 | + @classmethod |
| 67 | + def from_common(cls, checklist: Checklist, |
| 68 | + schema_id: Cuid) -> "NDChecklistSubclass": |
| 69 | + return cls(answer=[ |
| 70 | + NDFeature(schema_id=answer.schema_id) for answer in checklist.answer |
| 71 | + ], |
| 72 | + schema_id=schema_id) |
| 73 | + |
| 74 | + |
| 75 | +class NDRadioSubclass(NDFeature): |
| 76 | + answer: NDFeature |
| 77 | + |
| 78 | + def to_common(self) -> Radio: |
| 79 | + return Radio(answer=ClassificationAnswer( |
| 80 | + schema_id=self.answer.schema_id)) |
| 81 | + |
| 82 | + @classmethod |
| 83 | + def from_common(cls, radio: Radio, schema_id: Cuid) -> "NDRadioSubclass": |
| 84 | + return cls(answer=NDFeature(schema_id=radio.answer.schema_id), |
| 85 | + schema_id=schema_id) |
| 86 | + |
| 87 | + |
| 88 | +### ====== End of subclasses |
| 89 | + |
| 90 | + |
| 91 | +class NDText(NDAnnotation, NDTextSubclass): |
| 92 | + |
| 93 | + @classmethod |
| 94 | + def from_common(cls, text: Text, schema_id: Cuid, extra: Dict[str, Any], |
| 95 | + data: Union[TextData, RasterData]) -> "NDText": |
| 96 | + return cls( |
| 97 | + answer=text.answer, |
| 98 | + dataRow={'id': data.uid}, |
| 99 | + schema_id=schema_id, |
| 100 | + uuid=extra.get('uuid'), |
| 101 | + ) |
| 102 | + |
| 103 | + |
| 104 | +class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): |
| 105 | + |
| 106 | + @classmethod |
| 107 | + def from_common( |
| 108 | + cls, checklist: Checklist, schema_id: Cuid, extra: Dict[str, Any], |
| 109 | + data: Union[VideoData, TextData, RasterData]) -> "NDChecklist": |
| 110 | + return cls(answer=[ |
| 111 | + NDFeature(schema_id=answer.schema_id) for answer in checklist.answer |
| 112 | + ], |
| 113 | + dataRow={'id': data.uid}, |
| 114 | + schema_id=schema_id, |
| 115 | + uuid=extra.get('uuid'), |
| 116 | + frames=extra.get('frames')) |
| 117 | + |
| 118 | + |
| 119 | +class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): |
| 120 | + |
| 121 | + @classmethod |
| 122 | + def from_common(cls, radio: Radio, schema_id: Cuid, extra: Dict[str, Any], |
| 123 | + data: Union[VideoData, TextData, RasterData]) -> "NDRadio": |
| 124 | + return cls(answer=NDFeature(schema_id=radio.answer.schema_id), |
| 125 | + dataRow={'id': data.uid}, |
| 126 | + schema_id=schema_id, |
| 127 | + uuid=extra.get('uuid'), |
| 128 | + frames=extra.get('frames')) |
| 129 | + |
| 130 | + |
| 131 | +class NDSubclassification: |
| 132 | + |
| 133 | + @classmethod |
| 134 | + def from_common( |
| 135 | + cls, annotation: ClassificationAnnotation |
| 136 | + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: |
| 137 | + classify_obj = cls.lookup_subclassification(annotation) |
| 138 | + if classify_obj is None: |
| 139 | + raise TypeError( |
| 140 | + f"Unable to convert object to MAL format. `{type(annotation.value)}`" |
| 141 | + ) |
| 142 | + return classify_obj.from_common(annotation.value, annotation.schema_id) |
| 143 | + |
| 144 | + @staticmethod |
| 145 | + def to_common( |
| 146 | + annotation: "NDClassificationType") -> ClassificationAnnotation: |
| 147 | + return ClassificationAnnotation(value=annotation.to_common(), |
| 148 | + schema_id=annotation.schema_id) |
| 149 | + |
| 150 | + @staticmethod |
| 151 | + def lookup_subclassification( |
| 152 | + annotation: ClassificationAnnotation |
| 153 | + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: |
| 154 | + if isinstance(annotation, Dropdown): |
| 155 | + raise TypeError("Dropdowns are not supported for MAL") |
| 156 | + return { |
| 157 | + Text: NDTextSubclass, |
| 158 | + Checklist: NDChecklistSubclass, |
| 159 | + Radio: NDRadioSubclass, |
| 160 | + }.get(type(annotation.value)) |
| 161 | + |
| 162 | + |
| 163 | +class NDClassification: |
| 164 | + |
| 165 | + @staticmethod |
| 166 | + def to_common( |
| 167 | + annotation: "NDClassificationType" |
| 168 | + ) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]: |
| 169 | + common = ClassificationAnnotation(value=annotation.to_common(), |
| 170 | + schema_id=annotation.schema_id, |
| 171 | + extra={'uuid': annotation.uuid}) |
| 172 | + if getattr(annotation, 'frames', None) is None: |
| 173 | + return [common] |
| 174 | + results = [] |
| 175 | + for frame in annotation.frames: |
| 176 | + for idx in range(frame.start, frame.end + 1, 1): |
| 177 | + results.append( |
| 178 | + VideoClassificationAnnotation(frame=idx, **common.dict())) |
| 179 | + return results |
| 180 | + |
| 181 | + @classmethod |
| 182 | + def from_common( |
| 183 | + cls, annotation: Union[ClassificationAnnotation, |
| 184 | + VideoClassificationAnnotation], |
| 185 | + data: Union[VideoData, TextData, RasterData] |
| 186 | + ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: |
| 187 | + classify_obj = cls.lookup_classification(annotation) |
| 188 | + if classify_obj is None: |
| 189 | + raise TypeError( |
| 190 | + f"Unable to convert object to MAL format. `{type(annotation.value)}`" |
| 191 | + ) |
| 192 | + return classify_obj.from_common(annotation.value, annotation.schema_id, |
| 193 | + annotation.extra, data) |
| 194 | + |
| 195 | + @staticmethod |
| 196 | + def lookup_classification( |
| 197 | + annotation: Union[ClassificationAnnotation, |
| 198 | + VideoClassificationAnnotation] |
| 199 | + ) -> Union[NDText, NDChecklist, NDRadio]: |
| 200 | + if isinstance(annotation, Dropdown): |
| 201 | + raise TypeError("Dropdowns are not supported for MAL") |
| 202 | + return { |
| 203 | + Text: NDText, |
| 204 | + Checklist: NDChecklist, |
| 205 | + Radio: NDRadio, |
| 206 | + Dropdown: NDChecklist, |
| 207 | + }.get(type(annotation.value)) |
| 208 | + |
| 209 | + |
| 210 | +NDSubclassificationType = Union[NDRadioSubclass, NDChecklistSubclass, |
| 211 | + NDTextSubclass] |
| 212 | +NDClassificationType = Union[NDRadio, NDChecklist, NDText] |
0 commit comments