Skip to content

Commit 25bf82f

Browse files
author
Val Brodsky
committed
Deal with TypedArray special validations
1 parent 76dd198 commit 25bf82f

File tree

3 files changed

+30
-33
lines changed

3 files changed

+30
-33
lines changed

libs/labelbox/src/labelbox/data/annotation_types/data/raster.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -155,29 +155,22 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
155155
"One of url, im_bytes, file_path, arr must not be None.")
156156
return self.url
157157

158-
@model_validator(mode='before')
159-
@classmethod
160-
def validate_args(cls, values):
161-
file_path = values.get("file_path")
162-
im_bytes = values.get("im_bytes")
163-
url = values.get("url")
164-
arr = values.get("arr")
165-
uid = values.get('uid')
166-
global_key = values.get('global_key')
167-
if uid == file_path == im_bytes == url == global_key == None and arr is None:
158+
@model_validator(mode='after')
159+
def validate_args(self):
160+
if self.uid == self.file_path == self.im_bytes == self.url == self.global_key == None and self.arr is None:
168161
raise ValueError(
169162
"One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required."
170163
)
171-
if arr is not None:
172-
if arr.dtype != np.uint8:
164+
if self.arr is not None:
165+
if self.arr.dtype != np.uint8:
173166
raise TypeError(
174167
"Numpy array representing segmentation mask must be np.uint8"
175168
)
176-
elif len(arr.shape) != 3:
169+
elif len(self.arr.shape) != 3:
177170
raise ValueError(
178171
"unsupported image format. Must be 3D ([H,W,C])."
179172
f"Use {cls.__name__}.from_2D_arr to construct from 2D")
180-
return values
173+
return self
181174

182175
def __repr__(self) -> str:
183176
symbol_or_none = lambda data: '...' if data is not None else None

libs/labelbox/src/labelbox/data/annotation_types/data/video.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,20 +148,13 @@ def frames_to_video(self,
148148
out.release()
149149
return file_path
150150

151-
@model_validator(mode='before')
152-
@classmethod
153-
def validate_data(cls, values):
154-
file_path = values.get("file_path")
155-
url = values.get("url")
156-
frames = values.get("frames")
157-
uid = values.get("uid")
158-
global_key = values.get("global_key")
159-
160-
if uid == file_path == frames == url == global_key == None:
151+
@model_validator(mode='after')
152+
def validate_data(self):
153+
if self.uid == self.file_path == self.frames == self.url == self.global_key == None:
161154
raise ValueError(
162155
"One of `file_path`, `frames`, `uid`, `global_key` or `url` required."
163156
)
164-
return values
157+
return self
165158

166159
def __repr__(self) -> str:
167160
return f"VideoData(file_path={self.file_path}," \

libs/labelbox/src/labelbox/data/annotation_types/types.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import sys
2-
from typing import Generic, TypeVar, Any
2+
from typing import Generic, TypeVar, Any, Type
33

44
from labelbox.typing_imports import Annotated
55
from packaging import version
66
import numpy as np
77

8-
from pydantic import ValidationInfo, Field
8+
from pydantic import Field, GetCoreSchemaHandler, TypeAdapter
9+
from pydantic_core import core_schema
910

1011
Cuid = Annotated[str, Field(min_length=25, max_length=25)]
1112

@@ -15,18 +16,28 @@
1516

1617
class _TypedArray(np.ndarray, Generic[DType, DShape]):
1718

19+
# @classmethod
20+
# def __get_validators__(cls):
21+
# yield cls.validate
22+
1823
@classmethod
19-
def __get_validators__(cls):
20-
yield cls.validate
24+
def __get_pydantic_core_schema__(
25+
cls, source_type: Type[Any],
26+
handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
27+
28+
# assert source is CompressedString
29+
return core_schema.with_info_after_validator_function(
30+
function=cls.validate,
31+
schema=core_schema.any_schema(),
32+
field_name=source_type.__args__[-1].__args__[0])
2133

2234
@classmethod
23-
def validate(cls, val, field_info: ValidationInfo):
35+
def validate(cls, val, info):
2436
if not isinstance(val, np.ndarray):
2537
raise TypeError(f"Expected numpy array. Found {type(val)}")
2638

27-
actual_dtype = cls.model_fields[field_info.name].type_.__args__[0]
28-
29-
if val.dtype != actual_dtype:
39+
actual_type = info.field_name
40+
if str(val.dtype) != actual_type:
3041
raise TypeError(
3142
f"Expected numpy array have type {actual_dtype}. Found {val.dtype}"
3243
)

0 commit comments

Comments
 (0)