Skip to content

Commit d32ccee

Browse files
committed
update
1 parent f85bc0d commit d32ccee

File tree

1 file changed

+69
-49
lines changed

1 file changed

+69
-49
lines changed

labelbox/data/annotation_types/data/tiled_image.py

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,25 @@
33
from enum import Enum
44
from typing import Optional, List, Tuple
55
from concurrent.futures import ThreadPoolExecutor
6+
from io import BytesIO
67

78
import requests
89
import numpy as np
9-
import tensorflow as tf
10+
1011
from retry import retry
12+
import tensorflow as f
13+
from PIL import Image
1114
from pyproj import Transformer
1215
from pydantic import BaseModel, validator, conlist
1316
from pydantic.class_validators import root_validator
1417

1518
from ..geometry import Point
1619
from .base_data import BaseData
1720
from .raster import RasterData
18-
"""TODO: consider how to swap lat,lng to lng,lt when version = 2...
19-
should the bounds validator be inside the TiledImageData class then,
20-
since we need to check on Version?
21-
"""
2221

2322
VALID_LAT_RANGE = range(-90, 90)
2423
VALID_LNG_RANGE = range(-180, 180)
25-
TMS_TILE_SIZE = 256
26-
MAX_TILES = 300 #TODO: thinking of how to choose an appropriate max tiles number. 18 seems too small, but over 1000 likely seems too large
24+
DEFAULT_TMS_TILE_SIZE = 256
2725
TILE_DOWNLOAD_CONCURRENCY = 4
2826

2927
logging.basicConfig(level=logging.INFO)
@@ -40,7 +38,6 @@ class EPSG(Enum):
4038
SIMPLEPIXEL = 1
4139
EPSG4326 = 4326
4240
EPSG3857 = 3857
43-
EPSG3395 = 3395
4441

4542

4643
class TiledBounds(BaseModel):
@@ -73,16 +70,15 @@ def validate_bounds_not_equal(cls, bounds):
7370
def validate_bounds_lat_lng(cls, values):
7471
epsg = values.get('epsg')
7572
bounds = values.get('bounds')
76-
#TODO: look into messaging that we only support 4326 right now. raise exception, not implemented
7773

7874
if epsg != EPSG.SIMPLEPIXEL:
7975
for bound in bounds:
8076
lat, lng = bound.y, bound.x
8177
if int(lng) not in VALID_LNG_RANGE or int(
8278
lat) not in VALID_LAT_RANGE:
8379
raise ValueError(f"Invalid lat/lng bounds. Found {bounds}. "
84-
"lat must be in {VALID_LAT_RANGE}. "
85-
"lng must be in {VALID_LNG_RANGE}.")
80+
f"lat must be in {VALID_LAT_RANGE}. "
81+
f"lng must be in {VALID_LNG_RANGE}.")
8682
return values
8783

8884

@@ -103,7 +99,7 @@ class TileLayer(BaseModel):
10399
def validate_url(cls, url):
104100
xyz_format = "/{z}/{x}/{y}"
105101
if xyz_format not in url:
106-
raise AssertionError(f"{url} needs to contain {xyz_format}")
102+
raise ValueError(f"{url} needs to contain {xyz_format}")
107103
return url
108104

109105

@@ -129,12 +125,16 @@ class TiledImageData(BaseData):
129125
tile_layer: TileLayer
130126
tile_bounds: TiledBounds
131127
alternative_layers: List[TileLayer] = None
132-
zoom_levels: conlist(item_type=int, min_items=2, max_items=2)
133-
max_native_zoom: int = None
134-
tile_size: Optional[int] = TMS_TILE_SIZE
135-
version: int = 2
136-
137-
def _as_raster(self, zoom=0) -> RasterData:
128+
zoom_levels: Tuple[int, int]
129+
max_native_zoom: Optional[int] = None
130+
tile_size: Optional[int] = DEFAULT_TMS_TILE_SIZE
131+
version: Optional[int] = 2
132+
multithread: bool = True
133+
134+
def as_raster_data(self,
135+
zoom: int = 0,
136+
max_tiles: int = 32,
137+
multithread=True) -> RasterData:
138138
"""Converts the tiled image asset into a RasterData object containing an
139139
np.ndarray.
140140
@@ -146,34 +146,29 @@ def _as_raster(self, zoom=0) -> RasterData:
146146
# Currently our editor doesn't support anything other than 3857.
147147
# Since the user provided projection is ignored by the editor
148148
# we only accept 3857 here and raise a ValueError for any other projection.
149-
else:
150-
if self.tile_bounds.epsg != EPSG.EPSG3857:
151-
logger.info(
152-
f"User provided EPSG is being ignored {self.tile_bounds.epsg}."
153-
)
149+
elif self.tile_bounds.epsg == EPSG.EPSG3857:
154150
xstart, ystart, xend, yend = self._get_3857_image_params(zoom)
151+
else:
152+
raise ValueError(
153+
f"Unsupported epsg found...{self.tile_bounds.epsg}")
155154

156-
total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1)
157-
if total_n_tiles > MAX_TILES:
158-
logger.info(
159-
f"Too many tiles requested. Total tiles attempted {total_n_tiles}."
160-
)
161-
return None
155+
self._validate_num_tiles(xstart, ystart, xend, yend, max_tiles)
162156

163157
rounded_tiles, pixel_offsets = list(
164158
zip(*[
165159
self._tile_to_pixel(pt) for pt in [xstart, ystart, xend, yend]
166160
]))
167161

168-
image = self._fetch_image_for_bounds(*rounded_tiles, zoom)
162+
image = self._fetch_image_for_bounds(*rounded_tiles, zoom, multithread)
169163
arr = self._crop_to_bounds(image, *pixel_offsets)
170164
return RasterData(arr=arr)
171165

172166
@property
173167
def value(self) -> np.ndarray:
174168
"""Returns the value of a generated RasterData object.
175169
"""
176-
return self._as_raster(self.zoom_levels[0]).value
170+
return self.as_raster_data(self.zoom_levels[0],
171+
multithread=self.multithread).value
177172

178173
def _get_simple_image_params(self,
179174
zoom) -> Tuple[float, float, float, float]:
@@ -234,29 +229,43 @@ def _tile_to_pixel(self, tile: float) -> Tuple[int, int]:
234229
pixel_offset = int(self.tile_size * remainder)
235230
return rounded_tile, pixel_offset
236231

237-
def _fetch_image_for_bounds(
238-
self,
239-
x_tile_start: int,
240-
y_tile_start: int,
241-
x_tile_end: int,
242-
y_tile_end: int,
243-
zoom: int,
244-
) -> np.ndarray:
232+
def _fetch_image_for_bounds(self,
233+
x_tile_start: int,
234+
y_tile_start: int,
235+
x_tile_end: int,
236+
y_tile_end: int,
237+
zoom: int,
238+
multithread=True) -> np.ndarray:
245239
"""Fetches the tiles and combines them into a single image
246240
"""
247241
tiles = {}
248-
with ThreadPoolExecutor(max_workers=TILE_DOWNLOAD_CONCURRENCY) as exc:
242+
if multithread:
243+
with ThreadPoolExecutor(
244+
max_workers=TILE_DOWNLOAD_CONCURRENCY) as exc:
245+
for x in range(x_tile_start, x_tile_end + 1):
246+
for y in range(y_tile_start, y_tile_end + 1):
247+
tiles[(x, y)] = exc.submit(self._fetch_tile, x, y, zoom)
248+
249+
rows = []
250+
for y in range(y_tile_start, y_tile_end + 1):
251+
rows.append(
252+
np.hstack([
253+
tiles[(x, y)].result()
254+
for x in range(x_tile_start, x_tile_end + 1)
255+
]))
256+
# No multithreading: fetch each tile sequentially in the calling thread.
257+
else:
249258
for x in range(x_tile_start, x_tile_end + 1):
250259
for y in range(y_tile_start, y_tile_end + 1):
251-
tiles[(x, y)] = exc.submit(self._fetch_tile, x, y, zoom)
260+
tiles[(x, y)] = self._fetch_tile(x, y, zoom)
252261

253-
rows = []
254-
for y in range(y_tile_start, y_tile_end + 1):
255-
rows.append(
256-
np.hstack([
257-
tiles[(x, y)].result()
258-
for x in range(x_tile_start, x_tile_end + 1)
259-
]))
262+
rows = []
263+
for y in range(y_tile_start, y_tile_end + 1):
264+
rows.append(
265+
np.hstack([
266+
tiles[(x, y)]
267+
for x in range(x_tile_start, x_tile_end + 1)
268+
]))
260269

261270
return np.vstack(rows)
262271

@@ -269,7 +278,7 @@ def _fetch_tile(self, x: int, y: int, z: int) -> np.ndarray:
269278
try:
270279
data = requests.get(self.tile_layer.url.format(x=x, y=y, z=z))
271280
data.raise_for_status()
272-
decoded = tf.image.decode_image(data.content, channels=3).numpy()
281+
decoded = np.array(Image.open(BytesIO(data.content)))[..., :3]
273282
if decoded.shape[:2] != (self.tile_size, self.tile_size):
274283
logger.warning(
275284
f"Unexpected tile size {decoded.shape}. Results aren't guarenteed to be correct."
@@ -305,6 +314,17 @@ def invert_point(pt):
305314
x_px_end, y_px_end = invert_point(x_px_end), invert_point(y_px_end)
306315
return image[y_px_start:y_px_end, x_px_start:x_px_end, :]
307316

317+
def _validate_num_tiles(self, xstart: float, ystart: float, xend: float,
318+
yend: float, max_tiles: int):
319+
"""Calculates the number of expected tiles we would fetch.
320+
321+
If this is greater than the number of max tiles, raise an error.
322+
"""
323+
total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1)
324+
if total_n_tiles > max_tiles:
325+
raise ValueError(f"Requested zoom results in {total_n_tiles} tiles."
326+
f"Max allowed tiles are {max_tiles}")
327+
308328
@validator('zoom_levels')
309329
def validate_zoom_levels(cls, zoom_levels):
310330
if zoom_levels[0] > zoom_levels[1]:

0 commit comments

Comments
 (0)