Skip to content

Commit 57f60bd

Browse files
author
Gareth
authored
Merge pull request #293 from Labelbox/gj/annot-types-polish
Annotation type documentation
2 parents 0736823 + f599fba commit 57f60bd

File tree

11 files changed

+217
-68
lines changed

11 files changed

+217
-68
lines changed
Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,92 @@
1+
import abc
12
from typing import Any, Dict, List, Union
23

34
from .classification import Checklist, Dropdown, Radio, Text
45
from .feature import FeatureSchema
5-
from .geometry import Geometry
6+
from .geometry import Geometry, Rectangle, Point
67
from .ner import TextEntity
78

89

9-
class BaseAnnotation(FeatureSchema):
10+
class BaseAnnotation(FeatureSchema, abc.ABC):
1011
""" Base annotation class. Shouldn't be directly instantiated
1112
"""
1213
extra: Dict[str, Any] = {}
1314

1415

1516
class ClassificationAnnotation(BaseAnnotation):
16-
"""Class representing classification annotations (annotations that don't have a location) """
17+
"""Classification annotations (non localized)
18+
19+
>>> ClassificationAnnotation(
20+
>>> value=Text(answer="my caption message"),
21+
>>> feature_schema_id="my-feature-schema-id"
22+
>>> )
23+
24+
Args:
25+
name (Optional[str])
26+
feature_schema_id (Optional[Cuid])
27+
value (Union[Text, Checklist, Radio, Dropdown])
28+
extra (Dict[str, Any])
29+
"""
1730

1831
value: Union[Text, Checklist, Radio, Dropdown]
1932

2033

2134
class ObjectAnnotation(BaseAnnotation):
22-
"""Class representing objects annotations (non classifications or annotations that have a location)
35+
"""Generic localized annotation (non classifications)
36+
37+
>>> ObjectAnnotation(
38+
>>> value=Rectangle(
39+
>>> start=Point(x=0, y=0),
40+
>>> end=Point(x=1, y=1)
41+
>>> ),
42+
>>> feature_schema_id="my-feature-schema-id"
43+
>>> )
44+
45+
Args:
46+
name (Optional[str])
47+
feature_schema_id (Optional[Cuid])
48+
value (Union[TextEntity, Geometry]): Localization of the annotation
49+
classifications (Optional[List[ClassificationAnnotation]]): Optional sub classification of the annotation
50+
extra (Dict[str, Any])
2351
"""
2452
value: Union[TextEntity, Geometry]
2553
classifications: List[ClassificationAnnotation] = []
2654

2755

2856
class VideoObjectAnnotation(ObjectAnnotation):
29-
"""
30-
Class for video objects annotations
57+
"""Video object annotation
58+
59+
>>> VideoObjectAnnotation(
60+
>>> keyframe=True,
61+
>>> frame=10,
62+
>>> value=Rectangle(
63+
>>> start=Point(x=0, y=0),
64+
>>> end=Point(x=1, y=1)
65+
>>> ),
66+
>>> feature_schema_id="my-feature-schema-id"
67+
>>>)
3168
3269
Args:
33-
frame: The frame index that this annotation corresponds to
34-
keyframe: Whether or not this annotation was a human generated or interpolated annotation
70+
name (Optional[str])
71+
feature_schema_id (Optional[Cuid])
72+
value (Geometry)
73+
frame (Int): The frame index that this annotation corresponds to
74+
keyframe (bool): Whether or not this annotation was a human generated or interpolated annotation
75+
classifications (List[ClassificationAnnotation]) = []
76+
extra (Dict[str, Any])
3577
"""
3678
frame: int
3779
keyframe: bool
3880

3981

4082
class VideoClassificationAnnotation(ClassificationAnnotation):
41-
"""
42-
Class for video classification annotations
83+
"""Video classification
4384
4485
Args:
45-
frame: The frame index that this annotation corresponds to
86+
name (Optional[str])
87+
feature_schema_id (Optional[Cuid])
88+
value (Union[Text, Checklist, Radio, Dropdown])
89+
frame (int): The frame index that this annotation corresponds to
90+
extra (Dict[str, Any])
4691
"""
4792
frame: int

labelbox/data/annotation_types/data/raster.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,53 @@
1-
from typing import Callable, Optional
2-
from io import BytesIO
31
from abc import ABC
4-
2+
from io import BytesIO
3+
from typing import Callable, Optional, Union
4+
from typing_extensions import Literal
55
import numpy as np
6-
from pydantic import BaseModel
76
import requests
7+
from PIL import Image
88
from google.api_core import retry
9-
from typing_extensions import Literal
9+
from pydantic import BaseModel
1010
from pydantic import root_validator
11-
from PIL import Image
1211

1312
from .base_data import BaseData
1413
from ..types import TypedArray
1514

1615

1716
class RasterData(BaseModel, ABC):
18-
"""
19-
Represents an image or segmentation mask.
17+
"""Represents an image or segmentation mask.
2018
"""
2119
im_bytes: Optional[bytes] = None
2220
file_path: Optional[str] = None
2321
url: Optional[str] = None
2422
arr: Optional[TypedArray[Literal['uint8']]] = None
2523

2624
@classmethod
27-
def from_2D_arr(cls, arr: TypedArray[Literal['uint8']], **kwargs):
25+
def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']],
26+
TypedArray[Literal['int']]], **kwargs):
27+
"""Construct from a 2D numpy array
28+
29+
Args:
30+
arr: uint8 compatible numpy array
31+
32+
Returns:
33+
RasterData
34+
"""
35+
2836
if len(arr.shape) != 2:
2937
raise ValueError(
30-
f"Found array with shape {arr.shape}. Expected two dimensions ([W,H])"
38+
f"Found array with shape {arr.shape}. Expected two dimensions [H, W]"
39+
)
40+
41+
if not np.issubdtype(arr.dtype, np.integer):
42+
raise ValueError("Array must be an integer subtype")
43+
44+
if np.can_cast(arr, np.uint8):
45+
arr = arr.astype(np.uint8)
46+
else:
47+
raise ValueError(
48+
"Could not cast array to uint8, check that values are between 0 and 255"
3149
)
50+
3251
arr = np.stack((arr,) * 3, axis=-1)
3352
return cls(arr=arr, **kwargs)
3453

@@ -153,10 +172,10 @@ def validate_args(cls, values):
153172

154173
def __repr__(self) -> str:
155174
symbol_or_none = lambda data: '...' if data is not None else None
156-
return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \
157-
f"file_path={self.file_path}," \
158-
f"url={self.url}," \
159-
f"arr={symbol_or_none(self.arr)})"
175+
return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \
176+
f"file_path={self.file_path}," \
177+
f"url={self.url}," \
178+
f"arr={symbol_or_none(self.arr)})"
160179

161180
class Config:
162181
# Required for sharing references
@@ -166,7 +185,25 @@ class Config:
166185

167186

168187
class MaskData(RasterData):
169-
...
188+
"""Used to represent a segmentation Mask
189+
190+
All segments within a mask must be mutually exclusive. At a
191+
single cell, only one class can be present. All Mask data is
192+
converted to a [H,W,3] image. Classes are
193+
194+
>>> # 3x3 mask with two classes and back ground
195+
>>> MaskData.from_2D_arr([
196+
>>> [0, 0, 0],
197+
>>> [1, 1, 1],
198+
>>> [2, 2, 2],
199+
>>>])
200+
201+
Args:
202+
im_bytes: Optional[bytes] = None
203+
file_path: Optional[str] = None
204+
url: Optional[str] = None
205+
arr: Optional[TypedArray[Literal['uint8']]] = None
206+
"""
170207

171208

172209
class ImageData(RasterData, BaseData):

labelbox/data/annotation_types/geometry/geometry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99

1010
class Geometry(BaseModel, ABC):
11-
"""
12-
Base class for geometry objects.
11+
"""Abstract base class for geometry objects
1312
"""
1413
extra: Dict[str, Any] = {}
1514

labelbox/data/annotation_types/geometry/line.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010

1111
class Line(Geometry):
12+
"""Line annotation
13+
14+
Args:
15+
points (List[Point]): A list of `Point` geometries
16+
17+
"""
1218
points: List[Point]
1319

1420
@property

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,30 @@
1111

1212

1313
class Mask(Geometry):
14-
# Mask data can be shared across multiple masks
14+
"""Mask used to represent a single class in a larger segmentation mask
15+
16+
Example of a mutually exclusive class
17+
18+
>>> arr = MaskData.from_2D_arr([
19+
>>> [0, 0, 0],
20+
>>> [1, 1, 1],
21+
>>> [2, 2, 2],
22+
>>>])
23+
>>> annotations = [
24+
>>> ObjectAnnotation(value=Mask(mask=arr, color=1), name="dog"),
25+
>>> ObjectAnnotation(value=Mask(mask=arr, color=2), name="cat"),
26+
>>>]
27+
28+
Args:
29+
mask (MaskData): An object containing the actual mask, `MaskData` can
30+
be shared across multiple `Masks` to more efficiently store data
31+
for mutually exclusive segmentations.
32+
color (Tuple[uint8, uint8, uint8]): RGB color or a single value
33+
indicating the values of the class in the `MaskData`
34+
"""
35+
1536
mask: MaskData
16-
color: Tuple[int, int, int]
37+
color: Union[Tuple[int, int, int], int]
1738

1839
@property
1940
def geometry(self):
@@ -31,8 +52,7 @@ def draw(self,
3152
canvas: Optional[np.ndarray] = None,
3253
color: Optional[Union[int, Tuple[int, int, int]]] = None,
3354
thickness=None) -> np.ndarray:
34-
"""
35-
Converts the Mask object into a numpy array
55+
"""Converts the Mask object into a numpy array
3656
3757
Args:
3858
height (int): Optionally resize mask height before drawing.
@@ -43,6 +63,7 @@ def draw(self,
4363
int will return the mask as a 1d array
4464
tuple[int,int,int] will return the mask as a 3d array
4565
thickness (None): Unused, exists for a consistent interface.
66+
4667
Returns:
4768
np.ndarray representing only this object
4869
as opposed to the mask that this object references which might have multiple objects determined by colors
@@ -79,8 +100,9 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
79100

80101
@validator('color')
81102
def is_valid_color(cls, color):
82-
#Does the dtype matter? Can it be a float?
83103
if isinstance(color, (tuple, list)):
104+
if len(color) == 1:
105+
color = [color[0]] * 3
84106
if len(color) != 3:
85107
raise ValueError(
86108
"Segmentation colors must be either a (r,g,b) tuple or a single grayscale value"

labelbox/data/annotation_types/geometry/point.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,21 @@
88

99

1010
class Point(Geometry):
11+
"""Point geometry
12+
13+
>>> Point(x=0, y=0)
14+
15+
Args:
16+
x (float)
17+
y (float)
18+
19+
"""
1120
x: float
1221
y: float
1322

1423
@property
1524
def geometry(self) -> geojson.Point:
16-
return geojson.Point([self.x, self.y])
25+
return geojson.Point((self.x, self.y))
1726

1827
def draw(self,
1928
height: Optional[int] = None,

labelbox/data/annotation_types/geometry/polygon.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
from typing import List, Optional, Union, Tuple
22

3-
import numpy as np
4-
import geojson
53
import cv2
4+
import geojson
5+
import numpy as np
66
from pydantic import validator
77

8-
from .point import Point
98
from .geometry import Geometry
9+
from .point import Point
1010

1111

1212
class Polygon(Geometry):
13+
"""Polygon geometry
14+
15+
A polygon is created from a collection of points
16+
17+
>>> Polygon(points=[Point(x=0, y=0), Point(x=1, y=0), Point(x=1, y=1), Point(x=0, y=0)])
18+
19+
Args:
20+
points (List[Point]): List of `Points`, minimum of three points. If you do not
21+
close the polygon (the last point and first point are the same) an additional
22+
point is added to close it.
23+
24+
"""
1325
points: List[Point]
1426

1527
@property
16-
def geometry(self) -> geojson.MultiPolygon:
28+
def geometry(self) -> geojson.Polygon:
1729
if self.points[0] != self.points[-1]:
1830
self.points.append(self.points[0])
19-
return geojson.Polygon([[[point.x, point.y] for point in self.points]])
31+
return geojson.Polygon([[(point.x, point.y) for point in self.points]])
2032

2133
def draw(self,
2234
height: Optional[int] = None,
@@ -48,4 +60,7 @@ def is_geom_valid(cls, points):
4860
raise ValueError(
4961
f"A polygon must have at least 3 points to be valid. Found {points}"
5062
)
63+
if points[0] != points[-1]:
64+
points.append(points[0])
65+
5166
return points

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99

1010

1111
class Rectangle(Geometry):
12-
"""
13-
Represents a 2d rectangle. Also known as a bounding box.
12+
"""Represents a 2d rectangle. Also known as a bounding box
13+
14+
>>> Rectangle(start=Point(x=0, y=0), end=Point(x=1, y=1))
1415
15-
start: Top left coordinate of the rectangle
16-
end: Bottom right coordinate of the rectangle
16+
Args:
17+
start (Point): Top left coordinate of the rectangle
18+
end (Point): Bottom right coordinate of the rectangle
1719
"""
1820
start: Point
1921
end: Point
@@ -51,3 +53,8 @@ def draw(self,
5153
if thickness == -1:
5254
return cv2.fillPoly(canvas, pts, color)
5355
return cv2.polylines(canvas, pts, True, color, thickness)
56+
57+
@classmethod
58+
def from_xyhw(cls, x: float, y: float, h: float, w: float):
59+
"""Create Rectangle from x,y, height width format"""
60+
return cls(start=Point(x=x, y=y), end=Point(x=x + w, y=y + h))

0 commit comments

Comments
 (0)