Skip to content

Commit fa2c9fd

Browse files
author
gdj0nes
committed
Resolve comments and handle uint8 conversion
1 parent 8421933 commit fa2c9fd

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

labelbox/data/annotation_types/data/raster.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,51 @@
1-
from typing import Callable, Optional
2-
from io import BytesIO
31
from abc import ABC
2+
from io import BytesIO
3+
from typing import Callable, Optional
44

55
import numpy as np
6-
from pydantic import BaseModel
76
import requests
7+
from PIL import Image
88
from google.api_core import retry
9-
from typing_extensions import Literal
9+
from pydantic import BaseModel
1010
from pydantic import root_validator
11-
from PIL import Image
11+
from typing_extensions import Literal
1212

1313
from .base_data import BaseData
1414
from ..types import TypedArray
1515

1616

1717
class RasterData(BaseModel, ABC):
1818
"""Represents an image or segmentation mask.
19-
2019
"""
2120
im_bytes: Optional[bytes] = None
2221
file_path: Optional[str] = None
2322
url: Optional[str] = None
2423
arr: Optional[TypedArray[Literal['uint8']]] = None
2524

26-
2725
@classmethod
28-
def from_2D_arr(cls, arr: TypedArray[Literal['uint8']], **kwargs):
29-
"""Construct
26+
def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], TypedArray[Literal['int']]], **kwargs):
27+
"""Construct from a 2D numpy array
3028
3129
Args:
32-
arr:
33-
**kwargs:
30+
arr: uint8 compatible numpy array
3431
3532
Returns:
36-
33+
RasterData
3734
"""
3835

3936
if len(arr.shape) != 2:
4037
raise ValueError(
4138
f"Found array with shape {arr.shape}. Expected two dimensions [H, W]"
4239
)
40+
41+
if not np.issubdtype(arr.dtype, np.integer):
42+
raise ValueError("Array must be an integer subtype")
43+
44+
if np.can_cast(arr, np.uint8):
45+
arr = arr.astype(np.uint8)
46+
else:
47+
raise ValueError("Could not cast array to uint8, check that values are between 0 and 255")
48+
4349
arr = np.stack((arr,) * 3, axis=-1)
4450
return cls(arr=arr, **kwargs)
4551

@@ -164,10 +170,10 @@ def validate_args(cls, values):
164170

165171
def __repr__(self) -> str:
166172
symbol_or_none = lambda data: '...' if data is not None else None
167-
return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \
168-
f"file_path={self.file_path}," \
169-
f"url={self.url}," \
170-
f"arr={symbol_or_none(self.arr)})"
173+
return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \
174+
f"file_path={self.file_path}," \
175+
f"url={self.url}," \
176+
f"arr={symbol_or_none(self.arr)})"
171177

172178
class Config:
173179
# Required for sharing references
@@ -198,6 +204,5 @@ class MaskData(RasterData):
198204
"""
199205

200206

201-
202207
class ImageData(RasterData, BaseData):
203208
...

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ class Mask(Geometry):
2323
>>> annotations = [
2424
>>> ObjectAnnotation(value=Mask(mask=arr, color=1), name="dog"),
2525
>>> ObjectAnnotation(value=Mask(mask=arr, color=2), name="cat"),
26-
>>> ]
26+
>>>]
2727
2828
Args:
29-
mask (MaskData): A object containing the actual mask, `MaskData` can
29+
mask (MaskData): An object containing the actual mask, `MaskData` can
3030
be shared across multiple `Masks` to more efficiently store data
3131
for mutually exclusive segmentations.
3232
color (Tuple[uint8, uint8, uint8]): RGB color or a single value

0 commit comments

Comments
 (0)