Skip to content

Commit 4474d5d

Browse files
authored
Support for complex params (#30)
* Support for complex params * More tests * Added better parsing * nit * e2e tests * nit * Updated poetry.lock * poetry changes * more tests
1 parent ca58b31 commit 4474d5d

File tree

8 files changed

+534
-13
lines changed

8 files changed

+534
-13
lines changed

poetry.lock

Lines changed: 95 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/databricks/sqlalchemy/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from databricks.sqlalchemy.base import DatabricksDialect
2-
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
2+
from databricks.sqlalchemy._types import (
3+
TINYINT,
4+
TIMESTAMP,
5+
TIMESTAMP_NTZ,
6+
DatabricksArray,
7+
DatabricksMap,
8+
)
39

4-
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"]
10+
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]

src/databricks/sqlalchemy/_types.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sqlalchemy
66
from sqlalchemy.engine.interfaces import Dialect
77
from sqlalchemy.ext.compiler import compiles
8+
from sqlalchemy.types import TypeDecorator, UserDefinedType
89

910
from databricks.sql.utils import ParamEscaper
1011

@@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any):
2627
return value
2728

2829

30+
def identity_processor(value):
    """Pass-through bind processor.

    Used as the fallback when a wrapped SQLAlchemy type supplies no
    ``bind_processor`` of its own: the value is forwarded unchanged.
    """
    return value
33+
34+
2935
@compiles(sqlalchemy.types.Enum, "databricks")
3036
@compiles(sqlalchemy.types.String, "databricks")
3137
@compiles(sqlalchemy.types.Text, "databricks")
@@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
321327
@compiles(TINYINT, "databricks")
322328
def compile_tinyint(type_, compiler, **kw):
323329
return "TINYINT"
330+
331+
332+
class DatabricksArray(UserDefinedType):
    """
    A custom array type that can wrap any other SQLAlchemy type.

    Examples:
        DatabricksArray(String) -> ARRAY<STRING>
        DatabricksArray(Integer) -> ARRAY<INT>
        DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
    """

    def __init__(self, item_type):
        """Store the element type.

        ``item_type`` may be either a type class (``String``) or an
        instance (``String(50)``); a class is instantiated here so the
        rest of the code can rely on having an instance.
        """
        self.item_type = item_type() if isinstance(item_type, type) else item_type

    def bind_processor(self, dialect):
        """Return a bind processor that maps the element type's processor
        over every item of the bound list.

        Falls back to ``identity_processor`` when the element type defines
        no processor of its own.
        """
        item_processor = self.item_type.bind_processor(dialect)
        if item_processor is None:
            item_processor = identity_processor

        def process(value):
            # SQLAlchemy convention: a None bind value means SQL NULL and
            # must pass through untouched rather than raise a TypeError.
            if value is None:
                return None
            return [item_processor(item) for item in value]

        return process
354+
355+
356+
@compiles(DatabricksArray, "databricks")
def compile_databricks_array(type_, compiler, **kw):
    """Render a ``DatabricksArray`` as the Databricks SQL ``ARRAY<...>`` DDL string."""
    element_sql = compiler.process(type_.item_type, **kw)
    return "ARRAY<%s>" % element_sql
361+
362+
363+
class DatabricksMap(UserDefinedType):
    """
    A custom map type that can wrap any other SQLAlchemy types for both key and value.

    Examples:
        DatabricksMap(String, String) -> MAP<STRING,STRING>
        DatabricksMap(Integer, String) -> MAP<INT,STRING>
        DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
    """

    def __init__(self, key_type, value_type):
        """Store the key and value types.

        Either argument may be a type class or an instance; classes are
        instantiated here so downstream code always sees instances.
        """
        self.key_type = key_type() if isinstance(key_type, type) else key_type
        self.value_type = value_type() if isinstance(value_type, type) else value_type

    def bind_processor(self, dialect):
        """Return a bind processor that applies the key and value types'
        processors entry-wise to the bound dict.

        Each side falls back to ``identity_processor`` when its type defines
        no processor of its own.
        """
        key_processor = self.key_type.bind_processor(dialect)
        value_processor = self.value_type.bind_processor(dialect)

        if key_processor is None:
            key_processor = identity_processor
        if value_processor is None:
            value_processor = identity_processor

        def process(value):
            # SQLAlchemy convention: a None bind value means SQL NULL and
            # must pass through untouched rather than raise an AttributeError.
            if value is None:
                return None
            # Loop variables are named k/v to avoid shadowing the `value`
            # parameter inside the comprehension.
            return {key_processor(k): value_processor(v) for k, v in value.items()}

        return process
393+
394+
395+
@compiles(DatabricksMap, "databricks")
def compile_databricks_map(type_, compiler, **kw):
    """Render a ``DatabricksMap`` as the Databricks SQL ``MAP<key,value>`` DDL string."""
    rendered = [
        compiler.process(member, **kw)
        for member in (type_.key_type, type_.value_type)
    ]
    return "MAP<{},{}>".format(*rendered)

tests/test_local/e2e/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)