import sqlalchemy
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import TypeDecorator, UserDefinedType

from databricks.sql.utils import ParamEscaper

@@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any):
    return value


+ def identity_processor(value):
+     """Return the value itself when no other processor is provided."""
+     return value
+
+
@compiles(sqlalchemy.types.Enum, "databricks")
@compiles(sqlalchemy.types.String, "databricks")
@compiles(sqlalchemy.types.Text, "databricks")
@@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
@compiles(TINYINT, "databricks")
def compile_tinyint(type_, compiler, **kw):
    return "TINYINT"
+
+
+ class DatabricksArray(UserDefinedType):
+     """
+     A custom array type that can wrap any other SQLAlchemy type.
+
+     Examples:
+         DatabricksArray(String) -> ARRAY<STRING>
+         DatabricksArray(Integer) -> ARRAY<INT>
+         DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
+     """
+
+     def __init__(self, item_type):
+         self.item_type = item_type() if isinstance(item_type, type) else item_type
+
+     def bind_processor(self, dialect):
+         item_processor = self.item_type.bind_processor(dialect)
+         if item_processor is None:
+             item_processor = identity_processor
+
+         def process(value):
+             return [item_processor(val) for val in value]
+
+         return process
+
+
+ @compiles(DatabricksArray, "databricks")
+ def compile_databricks_array(type_, compiler, **kw):
+     inner = compiler.process(type_.item_type, **kw)
+
+     return f"ARRAY<{inner}>"
+
+
+ class DatabricksMap(UserDefinedType):
+     """
+     A custom map type whose keys and values can be any other SQLAlchemy type.
+
+     Examples:
+         DatabricksMap(String, String) -> MAP<STRING,STRING>
+         DatabricksMap(Integer, String) -> MAP<INT,STRING>
+         DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
+     """
+
+     def __init__(self, key_type, value_type):
+         self.key_type = key_type() if isinstance(key_type, type) else key_type
+         self.value_type = value_type() if isinstance(value_type, type) else value_type
+
+     def bind_processor(self, dialect):
+         key_processor = self.key_type.bind_processor(dialect)
+         value_processor = self.value_type.bind_processor(dialect)
+
+         if key_processor is None:
+             key_processor = identity_processor
+         if value_processor is None:
+             value_processor = identity_processor
+
+         def process(value):
+             # Use a distinct loop variable so the mapping argument
+             # is not shadowed inside the comprehension.
+             return {
+                 key_processor(key): value_processor(val)
+                 for key, val in value.items()
+             }
+
+         return process
+
+
+ @compiles(DatabricksMap, "databricks")
+ def compile_databricks_map(type_, compiler, **kw):
+     key_type = compiler.process(type_.key_type, **kw)
+     value_type = compiler.process(type_.value_type, **kw)
+     return f"MAP<{key_type},{value_type}>"