3
3
from enum import Enum
4
4
from typing import Optional , List , Tuple
5
5
from concurrent .futures import ThreadPoolExecutor
6
+ from io import BytesIO
6
7
7
8
import requests
8
9
import numpy as np
9
- import tensorflow as tf
10
+
10
11
from retry import retry
12
+ import tensorflow as f
13
+ from PIL import Image
11
14
from pyproj import Transformer
12
15
from pydantic import BaseModel , validator , conlist
13
16
from pydantic .class_validators import root_validator
14
17
15
18
from ..geometry import Point
16
19
from .base_data import BaseData
17
20
from .raster import RasterData
18
- """TODO: consider how to swap lat,lng to lng,lt when version = 2...
19
- should the bounds validator be inside the TiledImageData class then,
20
- since we need to check on Version?
21
- """
22
21
23
22
VALID_LAT_RANGE = range (- 90 , 90 )
24
23
VALID_LNG_RANGE = range (- 180 , 180 )
25
- TMS_TILE_SIZE = 256
26
- MAX_TILES = 300 #TODO: thinking of how to choose an appropriate max tiles number. 18 seems too small, but over 1000 likely seems too large
24
+ DEFAULT_TMS_TILE_SIZE = 256
27
25
TILE_DOWNLOAD_CONCURRENCY = 4
28
26
29
27
logging .basicConfig (level = logging .INFO )
@@ -40,7 +38,6 @@ class EPSG(Enum):
40
38
SIMPLEPIXEL = 1
41
39
EPSG4326 = 4326
42
40
EPSG3857 = 3857
43
- EPSG3395 = 3395
44
41
45
42
46
43
class TiledBounds (BaseModel ):
@@ -73,16 +70,15 @@ def validate_bounds_not_equal(cls, bounds):
73
70
def validate_bounds_lat_lng (cls , values ):
74
71
epsg = values .get ('epsg' )
75
72
bounds = values .get ('bounds' )
76
- #TODO: look into messaging that we only support 4326 right now. raise exception, not implemented
77
73
78
74
if epsg != EPSG .SIMPLEPIXEL :
79
75
for bound in bounds :
80
76
lat , lng = bound .y , bound .x
81
77
if int (lng ) not in VALID_LNG_RANGE or int (
82
78
lat ) not in VALID_LAT_RANGE :
83
79
raise ValueError (f"Invalid lat/lng bounds. Found { bounds } . "
84
- "lat must be in {VALID_LAT_RANGE}. "
85
- "lng must be in {VALID_LNG_RANGE}." )
80
+ f "lat must be in { VALID_LAT_RANGE } . "
81
+ f "lng must be in { VALID_LNG_RANGE } ." )
86
82
return values
87
83
88
84
@@ -103,7 +99,7 @@ class TileLayer(BaseModel):
103
99
def validate_url (cls , url ):
104
100
xyz_format = "/{z}/{x}/{y}"
105
101
if xyz_format not in url :
106
- raise AssertionError (f"{ url } needs to contain { xyz_format } " )
102
+ raise ValueError (f"{ url } needs to contain { xyz_format } " )
107
103
return url
108
104
109
105
@@ -129,12 +125,16 @@ class TiledImageData(BaseData):
129
125
tile_layer : TileLayer
130
126
tile_bounds : TiledBounds
131
127
alternative_layers : List [TileLayer ] = None
132
- zoom_levels : conlist (item_type = int , min_items = 2 , max_items = 2 )
133
- max_native_zoom : int = None
134
- tile_size : Optional [int ] = TMS_TILE_SIZE
135
- version : int = 2
136
-
137
- def _as_raster (self , zoom = 0 ) -> RasterData :
128
+ zoom_levels : Tuple [int , int ]
129
+ max_native_zoom : Optional [int ] = None
130
+ tile_size : Optional [int ] = DEFAULT_TMS_TILE_SIZE
131
+ version : Optional [int ] = 2
132
+ multithread : bool = True
133
+
134
+ def as_raster_data (self ,
135
+ zoom : int = 0 ,
136
+ max_tiles : int = 32 ,
137
+ multithread = True ) -> RasterData :
138
138
"""Converts the tiled image asset into a RasterData object containing an
139
139
np.ndarray.
140
140
@@ -146,34 +146,29 @@ def _as_raster(self, zoom=0) -> RasterData:
146
146
# Currently our editor doesn't support anything other than 3857.
147
147
# Since the user provided projection is ignored by the editor
148
148
# we will ignore it here and assume that the projection is 3857.
149
- else :
150
- if self .tile_bounds .epsg != EPSG .EPSG3857 :
151
- logger .info (
152
- f"User provided EPSG is being ignored { self .tile_bounds .epsg } ."
153
- )
149
+ elif self .tile_bounds .epsg == EPSG .EPSG3857 :
154
150
xstart , ystart , xend , yend = self ._get_3857_image_params (zoom )
151
+ else :
152
+ raise ValueError (
153
+ f"Unsupported epsg found...{ self .tile_bounds .epsg } " )
155
154
156
- total_n_tiles = (yend - ystart + 1 ) * (xend - xstart + 1 )
157
- if total_n_tiles > MAX_TILES :
158
- logger .info (
159
- f"Too many tiles requested. Total tiles attempted { total_n_tiles } ."
160
- )
161
- return None
155
+ self ._validate_num_tiles (xstart , ystart , xend , yend , max_tiles )
162
156
163
157
rounded_tiles , pixel_offsets = list (
164
158
zip (* [
165
159
self ._tile_to_pixel (pt ) for pt in [xstart , ystart , xend , yend ]
166
160
]))
167
161
168
- image = self ._fetch_image_for_bounds (* rounded_tiles , zoom )
162
+ image = self ._fetch_image_for_bounds (* rounded_tiles , zoom , multithread )
169
163
arr = self ._crop_to_bounds (image , * pixel_offsets )
170
164
return RasterData (arr = arr )
171
165
172
166
@property
173
167
def value (self ) -> np .ndarray :
174
168
"""Returns the value of a generated RasterData object.
175
169
"""
176
- return self ._as_raster (self .zoom_levels [0 ]).value
170
+ return self .as_raster_data (self .zoom_levels [0 ],
171
+ multithread = self .multithread ).value
177
172
178
173
def _get_simple_image_params (self ,
179
174
zoom ) -> Tuple [float , float , float , float ]:
@@ -234,29 +229,43 @@ def _tile_to_pixel(self, tile: float) -> Tuple[int, int]:
234
229
pixel_offset = int (self .tile_size * remainder )
235
230
return rounded_tile , pixel_offset
236
231
237
- def _fetch_image_for_bounds (
238
- self ,
239
- x_tile_start : int ,
240
- y_tile_start : int ,
241
- x_tile_end : int ,
242
- y_tile_end : int ,
243
- zoom : int ,
244
- ) -> np .ndarray :
232
+ def _fetch_image_for_bounds (self ,
233
+ x_tile_start : int ,
234
+ y_tile_start : int ,
235
+ x_tile_end : int ,
236
+ y_tile_end : int ,
237
+ zoom : int ,
238
+ multithread = True ) -> np .ndarray :
245
239
"""Fetches the tiles and combines them into a single image
246
240
"""
247
241
tiles = {}
248
- with ThreadPoolExecutor (max_workers = TILE_DOWNLOAD_CONCURRENCY ) as exc :
242
+ if multithread :
243
+ with ThreadPoolExecutor (
244
+ max_workers = TILE_DOWNLOAD_CONCURRENCY ) as exc :
245
+ for x in range (x_tile_start , x_tile_end + 1 ):
246
+ for y in range (y_tile_start , y_tile_end + 1 ):
247
+ tiles [(x , y )] = exc .submit (self ._fetch_tile , x , y , zoom )
248
+
249
+ rows = []
250
+ for y in range (y_tile_start , y_tile_end + 1 ):
251
+ rows .append (
252
+ np .hstack ([
253
+ tiles [(x , y )].result ()
254
+ for x in range (x_tile_start , x_tile_end + 1 )
255
+ ]))
256
+ #no multithreading
257
+ else :
249
258
for x in range (x_tile_start , x_tile_end + 1 ):
250
259
for y in range (y_tile_start , y_tile_end + 1 ):
251
- tiles [(x , y )] = exc . submit ( self ._fetch_tile , x , y , zoom )
260
+ tiles [(x , y )] = self ._fetch_tile ( x , y , zoom )
252
261
253
- rows = []
254
- for y in range (y_tile_start , y_tile_end + 1 ):
255
- rows .append (
256
- np .hstack ([
257
- tiles [(x , y )]. result ()
258
- for x in range (x_tile_start , x_tile_end + 1 )
259
- ]))
262
+ rows = []
263
+ for y in range (y_tile_start , y_tile_end + 1 ):
264
+ rows .append (
265
+ np .hstack ([
266
+ tiles [(x , y )]
267
+ for x in range (x_tile_start , x_tile_end + 1 )
268
+ ]))
260
269
261
270
return np .vstack (rows )
262
271
@@ -269,7 +278,7 @@ def _fetch_tile(self, x: int, y: int, z: int) -> np.ndarray:
269
278
try :
270
279
data = requests .get (self .tile_layer .url .format (x = x , y = y , z = z ))
271
280
data .raise_for_status ()
272
- decoded = tf . image . decode_image ( data .content , channels = 3 ). numpy ()
281
+ decoded = np . array ( Image . open ( BytesIO ( data .content )))[..., : 3 ]
273
282
if decoded .shape [:2 ] != (self .tile_size , self .tile_size ):
274
283
logger .warning (
275
284
f"Unexpected tile size { decoded .shape } . Results aren't guarenteed to be correct."
@@ -305,6 +314,17 @@ def invert_point(pt):
305
314
x_px_end , y_px_end = invert_point (x_px_end ), invert_point (y_px_end )
306
315
return image [y_px_start :y_px_end , x_px_start :x_px_end , :]
307
316
317
+ def _validate_num_tiles (self , xstart : float , ystart : float , xend : float ,
318
+ yend : float , max_tiles : int ):
319
+ """Calculates the number of expected tiles we would fetch.
320
+
321
+ If this is greater than the number of max tiles, raise an error.
322
+ """
323
+ total_n_tiles = (yend - ystart + 1 ) * (xend - xstart + 1 )
324
+ if total_n_tiles > max_tiles :
325
+ raise ValueError (f"Requested zoom results in { total_n_tiles } tiles."
326
+ f"Max allowed tiles are { max_tiles } " )
327
+
308
328
@validator ('zoom_levels' )
309
329
def validate_zoom_levels (cls , zoom_levels ):
310
330
if zoom_levels [0 ] > zoom_levels [1 ]:
0 commit comments