|
1 |
| -from typing import Callable, Optional |
2 |
| -from io import BytesIO |
3 | 1 | from abc import ABC
|
| 2 | +from io import BytesIO |
| 3 | +from typing import Callable, Optional |
4 | 4 |
|
5 | 5 | import numpy as np
|
6 |
| -from pydantic import BaseModel |
7 | 6 | import requests
|
| 7 | +from PIL import Image |
8 | 8 | from google.api_core import retry
|
9 |
| -from typing_extensions import Literal |
| 9 | +from pydantic import BaseModel |
10 | 10 | from pydantic import root_validator
|
11 |
| -from PIL import Image |
| 11 | +from typing_extensions import Literal |
12 | 12 |
|
13 | 13 | from .base_data import BaseData
|
14 | 14 | from ..types import TypedArray
|
15 | 15 |
|
16 | 16 |
|
17 | 17 | class RasterData(BaseModel, ABC):
|
18 | 18 | """Represents an image or segmentation mask.
|
19 |
| -
|
20 | 19 | """
|
21 | 20 | im_bytes: Optional[bytes] = None
|
22 | 21 | file_path: Optional[str] = None
|
23 | 22 | url: Optional[str] = None
|
24 | 23 | arr: Optional[TypedArray[Literal['uint8']]] = None
|
25 | 24 |
|
26 |
| - |
27 | 25 | @classmethod
|
28 |
| - def from_2D_arr(cls, arr: TypedArray[Literal['uint8']], **kwargs): |
29 |
| - """Construct |
| 26 | + def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], TypedArray[Literal['int']]], **kwargs): |
| 27 | + """Construct from a 2D numpy array |
30 | 28 |
|
31 | 29 | Args:
|
32 |
| - arr: |
33 |
| - **kwargs: |
| 30 | + arr: uint8 compatible numpy array |
34 | 31 |
|
35 | 32 | Returns:
|
36 |
| -
|
| 33 | + RasterData |
37 | 34 | """
|
38 | 35 |
|
39 | 36 | if len(arr.shape) != 2:
|
40 | 37 | raise ValueError(
|
41 | 38 | f"Found array with shape {arr.shape}. Expected two dimensions [H, W]"
|
42 | 39 | )
|
| 40 | + |
| 41 | + if not np.issubdtype(arr.dtype, np.integer): |
| 42 | + raise ValueError("Array must be an integer subtype") |
| 43 | + |
| 44 | + if np.can_cast(arr, np.uint8): |
| 45 | + arr = arr.astype(np.uint8) |
| 46 | + else: |
| 47 | + raise ValueError("Could not cast array to uint8, check that values are between 0 and 255") |
| 48 | + |
43 | 49 | arr = np.stack((arr,) * 3, axis=-1)
|
44 | 50 | return cls(arr=arr, **kwargs)
|
45 | 51 |
|
@@ -164,10 +170,10 @@ def validate_args(cls, values):
|
164 | 170 |
|
165 | 171 | def __repr__(self) -> str:
|
166 | 172 | symbol_or_none = lambda data: '...' if data is not None else None
|
167 |
| - return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ |
168 |
| - f"file_path={self.file_path}," \ |
169 |
| - f"url={self.url}," \ |
170 |
| - f"arr={symbol_or_none(self.arr)})" |
| 173 | + return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ |
| 174 | + f"file_path={self.file_path}," \ |
| 175 | + f"url={self.url}," \ |
| 176 | + f"arr={symbol_or_none(self.arr)})" |
171 | 177 |
|
172 | 178 | class Config:
|
173 | 179 | # Required for sharing references
|
@@ -198,6 +204,5 @@ class MaskData(RasterData):
|
198 | 204 | """
|
199 | 205 |
|
200 | 206 |
|
201 |
| - |
202 | 207 | class ImageData(RasterData, BaseData):
|
203 | 208 | ...
|
0 commit comments