Skip to content

Vb/pydaticv2 plt 600 #1657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libs/labelbox/src/labelbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,8 @@ def delete_model_config(self, id: str) -> bool:
params = {"id": id}
result = self.execute(query, params)
if not result:
raise labelbox.exceptions.ResourceNotFoundError(Entity.ModelConfig, params)
raise labelbox.exceptions.ResourceNotFoundError(
Entity.ModelConfig, params)
return result['deleteModelConfig']['success']

def create_dataset(self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import abc
from uuid import UUID, uuid4
from typing import Any, Dict, Optional
from labelbox import pydantic_compat
from pydantic import PrivateAttr

from .feature import FeatureSchema


class BaseAnnotation(FeatureSchema, abc.ABC):
""" Base annotation class. Shouldn't be directly instantiated
"""
_uuid: Optional[UUID] = pydantic_compat.PrivateAttr()
_uuid: Optional[UUID] = PrivateAttr()
extra: Dict[str, Any] = {}

def __init__(self, **data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
except:
from typing_extensions import Literal

from labelbox import pydantic_compat
from pydantic import BaseModel
from ..feature import FeatureSchema


# TODO: Replace when pydantic adds support for unions that don't coerce types
class _TempName(ConfidenceMixin, pydantic_compat.BaseModel):
class _TempName(ConfidenceMixin, BaseModel):
name: str

def dict(self, *args, **kwargs):
Expand Down Expand Up @@ -47,7 +47,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
return res


class Radio(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel):
""" A classification with only one selected option allowed

>>> Radio(answer = ClassificationAnswer(name = "dog"))
Expand All @@ -66,7 +66,7 @@ class Checklist(_TempName):
answer: List[ClassificationAnswer]


class Text(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel):
""" Free form text

>>> Text(answer = "some text answer")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABC
from typing import Optional, Dict, List, Any

from labelbox import pydantic_compat
from pydantic import BaseModel


class BaseData(pydantic_compat.BaseModel, ABC):
class BaseData(BaseModel, ABC):
"""
Base class for objects representing data.
This class shouldn't directly be used
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, Literal, Optional

from labelbox import pydantic_compat
from pydantic import BaseModel, model_validator
from labelbox.data.annotation_types.data.base_data import BaseData
from labelbox.utils import _NoCoercionMixin

Expand All @@ -14,7 +14,8 @@ class GenericDataRowData(BaseData, _NoCoercionMixin):
def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]:
return self.url

@pydantic_compat.root_validator(pre=True)
@model_validator(mode='before')
@classmethod
def validate_one_datarow_key_present(cls, data):
keys = ['external_id', 'global_key', 'uid']
count = sum([key in data for key in keys])
Expand Down
30 changes: 10 additions & 20 deletions libs/labelbox/src/labelbox/data/annotation_types/data/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import requests
import numpy as np

from labelbox import pydantic_compat
from pydantic import BaseModel, model_validator, ConfigDict, Extra
from labelbox.exceptions import InternalServerError
from .base_data import BaseData
from ..types import TypedArray


class RasterData(pydantic_compat.BaseModel, ABC):
class RasterData(BaseModel, ABC):
"""Represents an image or segmentation mask.
"""
im_bytes: Optional[bytes] = None
Expand Down Expand Up @@ -155,28 +155,22 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
"One of url, im_bytes, file_path, arr must not be None.")
return self.url

@pydantic_compat.root_validator()
def validate_args(cls, values):
file_path = values.get("file_path")
im_bytes = values.get("im_bytes")
url = values.get("url")
arr = values.get("arr")
uid = values.get('uid')
global_key = values.get('global_key')
if uid == file_path == im_bytes == url == global_key == None and arr is None:
@model_validator(mode='after')
def validate_args(self):
if self.file_path == self.im_bytes == self.url == self.global_key == None and self.arr is None:
raise ValueError(
"One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required."
)
if arr is not None:
if arr.dtype != np.uint8:
if self.arr is not None:
if self.arr.dtype != np.uint8:
raise TypeError(
"Numpy array representing segmentation mask must be np.uint8"
)
elif len(arr.shape) != 3:
elif len(self.arr.shape) != 3:
raise ValueError(
"unsupported image format. Must be 3D ([H,W,C])."
f"Use {cls.__name__}.from_2D_arr to construct from 2D")
return values
return self

def __repr__(self) -> str:
symbol_or_none = lambda data: '...' if data is not None else None
Expand All @@ -185,11 +179,7 @@ def __repr__(self) -> str:
f"url={self.url}," \
f"arr={symbol_or_none(self.arr)})"

class Config:
# Required for sharing references
copy_on_model_validation = 'none'
# Required for discriminating between data types
extra = 'forbid'
model_config = ConfigDict(extra='forbid',)


class MaskData(RasterData):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from requests.exceptions import ConnectTimeout
from google.api_core import retry

from labelbox import pydantic_compat
from pydantic import BaseModel, model_validator
from labelbox.exceptions import InternalServerError
from labelbox.typing_imports import Literal
from labelbox.utils import _NoCoercionMixin
Expand Down Expand Up @@ -90,7 +90,8 @@ def create_url(self, signer: Callable[[bytes], str]) -> None:
"One of url, im_bytes, file_path, numpy must not be None.")
return self.url

@pydantic_compat.root_validator
@model_validator(mode='before')
@classmethod
def validate_date(cls, values):
file_path = values.get("file_path")
text = values.get("text")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from PIL import Image
from pyproj import Transformer
from pygeotile.point import Point as PygeoPoint
from labelbox import pydantic_compat
from pydantic import BaseModel, model_validator, field_validator, ConfigDict

from labelbox.data.annotation_types import Rectangle, Point, Line, Polygon
from .base_data import BaseData
Expand Down Expand Up @@ -40,7 +40,7 @@ class EPSG(Enum):
EPSG3857 = 3857


class TiledBounds(pydantic_compat.BaseModel):
class TiledBounds(BaseModel):
""" Bounds for a tiled image asset related to the relevant epsg.

Bounds should be Point objects.
Expand All @@ -54,7 +54,8 @@ class TiledBounds(pydantic_compat.BaseModel):
epsg: EPSG
bounds: List[Point]

@pydantic_compat.validator('bounds')
@field_validator('bounds')
@classmethod
def validate_bounds_not_equal(cls, bounds):
first_bound = bounds[0]
second_bound = bounds[1]
Expand All @@ -66,7 +67,8 @@ def validate_bounds_not_equal(cls, bounds):
return bounds

#validate bounds are within lat,lng range if they are EPSG4326
@pydantic_compat.root_validator
@model_validator(mode='before')
@classmethod
def validate_bounds_lat_lng(cls, values):
epsg = values.get('epsg')
bounds = values.get('bounds')
Expand All @@ -82,7 +84,7 @@ def validate_bounds_lat_lng(cls, values):
return values


class TileLayer(pydantic_compat.BaseModel):
class TileLayer(BaseModel):
""" Url that contains the tile layer. Must be in the format:

https://c.tile.openstreetmap.org/{z}/{x}/{y}.png
Expand All @@ -98,7 +100,8 @@ class TileLayer(pydantic_compat.BaseModel):
def asdict(self) -> Dict[str, str]:
return {"tileLayerUrl": self.url, "name": self.name}

@pydantic_compat.validator('url')
@field_validator('url')
@classmethod
def validate_url(cls, url):
xyz_format = "/{z}/{x}/{y}"
if xyz_format not in url:
Expand Down Expand Up @@ -343,7 +346,8 @@ def _validate_num_tiles(self, xstart: float, ystart: float, xend: float,
f"Max allowed tiles are {max_tiles}"
f"Increase max tiles or reduce zoom level.")

@pydantic_compat.validator('zoom_levels')
@field_validator('zoom_levels')
@classmethod
def validate_zoom_levels(cls, zoom_levels):
if zoom_levels[0] > zoom_levels[1]:
raise ValueError(
Expand All @@ -352,13 +356,12 @@ def validate_zoom_levels(cls, zoom_levels):
return zoom_levels


class EPSGTransformer(pydantic_compat.BaseModel):
class EPSGTransformer(BaseModel):
"""Transformer class between different EPSG's. Useful when wanting to project
in different formats.
"""

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)

transformer: Any

Expand Down
20 changes: 6 additions & 14 deletions libs/labelbox/src/labelbox/data/annotation_types/data/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .base_data import BaseData
from ..types import TypedArray

from labelbox import pydantic_compat
from pydantic import model_validator, ConfigDict, Extra

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,25 +148,17 @@ def frames_to_video(self,
out.release()
return file_path

@pydantic_compat.root_validator
def validate_data(cls, values):
file_path = values.get("file_path")
url = values.get("url")
frames = values.get("frames")
uid = values.get("uid")
global_key = values.get("global_key")

if uid == file_path == frames == url == global_key == None:
@model_validator(mode='after')
def validate_data(self):
if self.uid == self.file_path == self.frames == self.url == self.global_key == None:
raise ValueError(
"One of `file_path`, `frames`, `uid`, `global_key` or `url` required."
)
return values
return self

def __repr__(self) -> str:
return f"VideoData(file_path={self.file_path}," \
f"frames={'...' if self.frames is not None else None}," \
f"url={self.url})"

class Config:
# Required for discriminating between data types
extra = 'forbid'
model_config = ConfigDict(extra='forbid',)
13 changes: 7 additions & 6 deletions libs/labelbox/src/labelbox/data/annotation_types/feature.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Optional

from labelbox import pydantic_compat
from pydantic import BaseModel, model_validator

from .types import Cuid


class FeatureSchema(pydantic_compat.BaseModel):
class FeatureSchema(BaseModel):
"""
Class that represents a feature schema.
Could be a annotation, a subclass, or an option.
Expand All @@ -14,13 +14,14 @@ class FeatureSchema(pydantic_compat.BaseModel):
name: Optional[str] = None
feature_schema_id: Optional[Cuid] = None

@pydantic_compat.root_validator
def must_set_one(cls, values):
if values['feature_schema_id'] is None and values['name'] is None:
@model_validator(mode='after')
def must_set_one(self):
if self.feature_schema_id is None and self.name is None:
raise ValueError(
"Must set either feature_schema_id or name for all feature schemas"
)
return values

return self

def dict(self, *args, **kwargs):
res = super().dict(*args, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import geojson
import numpy as np
from labelbox import pydantic_compat
from pydantic import BaseModel

from shapely import geometry as geom


class Geometry(pydantic_compat.BaseModel, ABC):
class Geometry(BaseModel, ABC):
"""Abstract base class for geometry objects
"""
extra: Dict[str, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .point import Point
from .geometry import Geometry

from labelbox import pydantic_compat
from pydantic import field_validator


class Line(Geometry):
Expand Down Expand Up @@ -65,7 +65,8 @@ def draw(self,
color=color,
thickness=thickness)

@pydantic_compat.validator('points')
@field_validator('points')
@classmethod
def is_geom_valid(cls, points):
if len(points) < 2:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..data import MaskData
from .geometry import Geometry

from labelbox import pydantic_compat
from pydantic import field_validator


class Mask(Geometry):
Expand Down Expand Up @@ -122,7 +122,8 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
"""
return self.mask.create_url(signer)

@pydantic_compat.validator('color')
@field_validator('color')
@classmethod
def is_valid_color(cls, color):
if isinstance(color, (tuple, list)):
if len(color) == 1:
Expand Down
Loading
Loading