Skip to content

Commit 944865a

Browse files
authored
Merge pull request #204 from Labelbox/ms/annotation-updates
annotation clean up
2 parents ed16d64 + a2856fb commit 944865a

26 files changed

+296
-114
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Dict, List, Union
22

3+
from pydantic.main import BaseModel
4+
35
from .classification import Checklist, Dropdown, Radio, Text
46
from .feature import FeatureSchema
57
from .geometry import Geometry
@@ -9,7 +11,6 @@
911
class BaseAnnotation(FeatureSchema):
1012
""" Base annotation class. Shouldn't be directly instantiated
1113
"""
12-
1314
extra: Dict[str, Any] = {}
1415

1516

labelbox/data/annotation_types/collection.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
from tqdm import tqdm
77

8-
from labelbox import OntologyBuilder
8+
from labelbox.schema import ontology
99
from labelbox.orm.model import Entity
10+
from ..ontology import get_classifications, get_tools
1011
from ..generator import PrefetchGenerator
1112
from .label import Label
1213

@@ -24,8 +25,8 @@ def __init__(self, data: Iterable[Label]):
2425
self._data = data
2526
self._index = 0
2627

27-
def assign_schema_ids(self,
28-
ontology_builder: OntologyBuilder) -> "LabelList":
28+
def assign_schema_ids(
29+
self, ontology_builder: "ontology.OntologyBuilder") -> "LabelList":
2930
"""
3031
Adds schema ids to all FeatureSchema objects in the Labels.
3132
This is necessary for MAL.
@@ -110,6 +111,16 @@ def add_url_to_data(self, signer, max_concurrency=20) -> "LabelList":
110111
...
111112
return self
112113

114+
def get_ontology(self) -> ontology.OntologyBuilder:
115+
classifications = []
116+
tools = []
117+
for label in self._data:
118+
tools = get_tools(label.object_annotations(), tools)
119+
classifications = get_classifications(
120+
label.classification_annotations(), classifications)
121+
return ontology.OntologyBuilder(tools=tools,
122+
classifications=classifications)
123+
113124
def _ensure_unique_external_ids(self) -> None:
114125
external_ids = set()
115126
for label in self._data:
@@ -122,6 +133,9 @@ def _ensure_unique_external_ids(self) -> None:
122133
)
123134
external_ids.add(label.data.external_id)
124135

136+
def append(self, label: Label):
137+
self._data.append(label)
138+
125139
def __iter__(self) -> "LabelList":
126140
self._index = 0
127141
return self
@@ -166,7 +180,8 @@ def as_list(self) -> "LabelList":
166180
return LabelList(data=list(self))
167181

168182
def assign_schema_ids(
169-
self, ontology_builder: OntologyBuilder) -> "LabelGenerator":
183+
self,
184+
ontology_builder: "ontology.OntologyBuilder") -> "LabelGenerator":
170185

171186
def _assign_ids(label: Label):
172187
label.assign_schema_ids(ontology_builder)

labelbox/data/annotation_types/data/raster.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ class RasterData(BaseData):
2020
url: Optional[str] = None
2121
arr: Optional[TypedArray[Literal['uint8']]] = None
2222

23+
@classmethod
24+
def from_2D_arr(cls, arr: TypedArray[Literal['uint8']], **kwargs):
25+
if len(arr.shape):
26+
raise ValueError(
27+
f"Found array with shape {arr.shape}. Expected two dimensions ([W,H])"
28+
)
29+
arr = np.stack((arr,) * 3, axis=-1)
30+
return cls(arr=arr, **kwargs)
31+
2332
def bytes_to_np(self, image_bytes: bytes) -> np.ndarray:
2433
"""
2534
Converts image bytes to a numpy array
@@ -38,9 +47,9 @@ def np_to_bytes(self, arr: np.ndarray) -> bytes:
3847
Returns:
3948
png encoded bytes
4049
"""
41-
if len(arr.shape) not in [2, 3]:
42-
raise ValueError("unsupported image format")
43-
50+
if len(arr.shape) != 3:
51+
raise ValueError("unsupported image format. Must be 3D ([H,W,C])."
52+
"Use RasterData.from_2D_arr to construct from 2D")
4453
if arr.dtype != np.uint8:
4554
raise TypeError(f"image data type must be uint8. Found {arr.dtype}")
4655

@@ -72,6 +81,9 @@ def data(self) -> np.ndarray:
7281
else:
7382
raise ValueError("Must set either url, file_path or im_bytes")
7483

84+
def set_fetch_fn(self, fn):
85+
object.__setattr__(self, 'fetch_remote', lambda: fn(self))
86+
7587
def fetch_remote(self) -> bytes:
7688
"""
7789
Method for accessing url.
@@ -122,12 +134,19 @@ def validate_args(cls, values):
122134
raise TypeError(
123135
"Numpy array representing segmentation mask must be np.uint8"
124136
)
125-
elif len(arr.shape) not in [2, 3]:
126-
raise TypeError(
127-
f"Numpy array must have 2 or 3 dims. Found shape {arr.shape}"
128-
)
137+
elif len(arr.shape) != 3:
138+
raise ValueError(
139+
"unsupported image format. Must be 3D ([H,W,C])."
140+
"Use RasterData.from_2D_arr to construct from 2D")
129141
return values
130142

143+
def __repr__(self) -> str:
144+
symbol_or_none = lambda data: '...' if data is not None else None
145+
return f"RasterData(im_bytes={symbol_or_none(self.im_bytes)}," \
146+
f"file_path={self.file_path}," \
147+
f"url={self.url}," \
148+
f"arr={symbol_or_none(self.arr)})"
149+
131150
class Config:
132151
# Required for sharing references
133152
copy_on_model_validation = False

labelbox/data/annotation_types/data/text.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def data(self) -> str:
3636
else:
3737
raise ValueError("Must set either url, file_path or im_bytes")
3838

39+
def set_fetch_fn(self, fn):
40+
object.__setattr__(self, 'fetch_remote', lambda: fn(self))
41+
3942
def fetch_remote(self) -> str:
4043
"""
4144
Method for accessing url.
@@ -79,6 +82,11 @@ def validate_date(cls, values):
7982
"One of `file_path`, `text`, `uid`, or `url` required.")
8083
return values
8184

85+
def __repr__(self) -> str:
86+
return f"TextData(file_path={self.file_path}," \
87+
f"text={self.text[:30] + '...' if self.text is not None else None}," \
88+
f"url={self.url})"
89+
8290
class config:
8391
# Required for discriminating between data types
8492
extra = 'forbid'

labelbox/data/annotation_types/data/video.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def __getitem__(self, idx: int) -> np.ndarray:
8686
)
8787
return self.frames[idx]
8888

89+
def set_fetch_fn(self, fn):
90+
object.__setattr__(self, 'fetch_remote', lambda: fn(self))
91+
8992
def fetch_remote(self, local_path) -> None:
9093
"""
9194
Method for downloading data from self.url
@@ -153,6 +156,11 @@ def validate_data(cls, values):
153156
"One of `file_path`, `frames`, `uid`, or `url` required.")
154157
return values
155158

159+
def __repr__(self) -> str:
160+
return f"TextData(file_path={self.file_path}," \
161+
f"frames={'...' if self.frames is not None else None}," \
162+
f"url={self.url})"
163+
156164
class Config:
157165
# Required for discriminating between data types
158166
extra = 'forbid'

labelbox/data/annotation_types/geometry/line.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ def raster(self,
2020
height: int,
2121
width: int,
2222
thickness=1,
23-
color=255) -> np.ndarray:
23+
color=(255, 255, 255)) -> np.ndarray:
2424
"""
25-
Draw the line onto a 2d mask
25+
Draw the line onto a 3d mask
2626
2727
Args:
2828
height (int): height of the mask
@@ -32,8 +32,7 @@ def raster(self,
3232
Returns:
3333
numpy array representing the mask with the line drawn on it.
3434
"""
35-
36-
canvas = np.zeros((height, width), dtype=np.uint8)
35+
canvas = np.zeros((height, width, 3), dtype=np.uint8)
3736
pts = np.array(self.geometry['coordinates']).astype(np.int32)
3837
return cv2.polylines(canvas,
3938
pts,

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Callable, Tuple, Union
1+
from typing import Callable, Optional, Tuple, Union
22

33
import numpy as np
44
from pydantic.class_validators import validator
55
from rasterio.features import shapes
66
from shapely.geometry import MultiPolygon, shape
7+
import cv2
78

89
from ..data.raster import RasterData
910
from .geometry import Geometry
@@ -13,7 +14,7 @@ class Mask(Geometry):
1314
# Raster data can be shared across multiple masks... or not
1415
mask: RasterData
1516
# RGB or Grayscale
16-
color: Union[int, Tuple[int, int, int]]
17+
color: Tuple[int, int, int]
1718

1819
@property
1920
def geometry(self):
@@ -25,21 +26,26 @@ def geometry(self):
2526
if val >= 1)
2627
return MultiPolygon(polygons).__geo_interface__
2728

28-
def raster(self, binary=False) -> np.ndarray:
29+
def raster(self,
30+
height: Optional[int] = None,
31+
width: Optional[int] = None,
32+
binary=False) -> np.ndarray:
2933
"""
3034
Removes all pixels from the segmentation mask that do not equal self.color
3135
36+
Args:
37+
height:
38+
3239
Returns:
3340
np.ndarray representing only this object
3441
"""
3542
mask = self.mask.data
36-
if len(mask.shape) == 2:
37-
mask = np.expand_dims(mask, axis=-1)
3843
mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8)
44+
if height is not None or width is not None:
45+
mask = cv2.resize(mask,
46+
(width or mask.shape[1], height or mask.shape[0]))
3947
if binary:
4048
return mask
41-
elif isinstance(self.color, int):
42-
return mask * self.color
4349
else:
4450
color_image = np.zeros((mask.shape[0], mask.shape[1], 3),
4551
dtype=np.uint8)

labelbox/data/annotation_types/geometry/point.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ def raster(self,
1717
height: int,
1818
width: int,
1919
thickness: int = 1,
20-
color=255) -> np.ndarray:
20+
color=(255, 255, 255)) -> np.ndarray:
2121
"""
22-
Draw the point onto a 2d mask
22+
Draw the point onto a 3d mask
2323
2424
Args:
2525
height (int): height of the mask
@@ -29,7 +29,7 @@ def raster(self,
2929
Returns:
3030
numpy array representing the mask with the point drawn on it.
3131
"""
32-
canvas = np.zeros((height, width), dtype=np.uint8)
32+
canvas = np.zeros((height, width, 3), dtype=np.uint8)
3333
return cv2.circle(canvas, (int(self.x), int(self.y)),
3434
radius=thickness,
3535
color=color,

labelbox/data/annotation_types/geometry/polygon.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ def geometry(self) -> geojson.MultiPolygon:
1818
self.points.append(self.points[0])
1919
return geojson.Polygon([[[point.x, point.y] for point in self.points]])
2020

21-
def raster(self, height: int, width: int, color=255) -> np.ndarray:
21+
def raster(self, height: int, width: int,
22+
color=(255, 255, 255)) -> np.ndarray:
2223
"""
23-
Draw the polygon onto a 2d mask
24+
Draw the polygon onto a 3d mask
2425
2526
Args:
2627
height (int): height of the mask
@@ -29,7 +30,7 @@ def raster(self, height: int, width: int, color=255) -> np.ndarray:
2930
Returns:
3031
numpy array representing the mask with the polygon drawn on it.
3132
"""
32-
canvas = np.zeros((height, width), dtype=np.uint8)
33+
canvas = np.zeros((height, width, 3), dtype=np.uint8)
3334
pts = np.array(self.geometry['coordinates']).astype(np.int32)
3435
return cv2.fillPoly(canvas, pts=pts, color=color)
3536

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def geometry(self) -> geojson.geometry.Geometry:
2626
[self.start.x, self.start.y],
2727
]])
2828

29-
def raster(self, height: int, width: int, color: int = 255) -> np.ndarray:
29+
def raster(self, height: int, width: int,
30+
color=(255, 255, 255)) -> np.ndarray:
3031
"""
31-
Draw the rectangle onto a 2d mask
32+
Draw the rectangle onto a 3d mask
3233
3334
Args:
3435
height (int): height of the mask
@@ -37,6 +38,6 @@ def raster(self, height: int, width: int, color: int = 255) -> np.ndarray:
3738
Returns:
3839
numpy array representing the mask with the rectangle drawn on it.
3940
"""
40-
canvas = np.zeros((height, width), dtype=np.uint8)
41+
canvas = np.zeros((height, width, 3), dtype=np.uint8)
4142
pts = np.array(self.geometry['coordinates']).astype(np.int32)
4243
return cv2.fillPoly(canvas, pts=pts, color=color)

0 commit comments

Comments
 (0)