Skip to content

Commit 06bd616

Browse files
Jessesaishreeeee
authored and committed
Native Parameters: reintroduce INLINE approach with tests (#267)
Signed-off-by: Jesse Whitehouse <jesse.whitehouse@databricks.com> Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 06267ba commit 06bd616

File tree

10 files changed

+671
-267
lines changed

10 files changed

+671
-267
lines changed

docs/parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`<placeholder>`

src/databricks/sql/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
# PEP 249 module globals
66
apilevel = "2.0"
77
threadsafety = 1 # Threads may share the module, but not connections.
8-
paramstyle = "named" # Python extended format codes, e.g. ...WHERE name=%(name)s
8+
9+
# Python extended format codes, e.g. ...WHERE name=%(name)s
10+
# Note that when we switch to ParameterApproach.NATIVE, paramstyle will be `named`
11+
paramstyle = "pyformat"
912

1013

1114
class DBAPITypeObject(object):

src/databricks/sql/client.py

Lines changed: 151 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,25 @@
1919
ExecuteResponse,
2020
ParamEscaper,
2121
named_parameters_to_tsparkparams,
22+
inject_parameters,
23+
ParameterApproach,
2224
)
2325
from databricks.sql.types import Row
2426
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
2527
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
2628

29+
from databricks.sql.thrift_api.TCLIService.ttypes import (
30+
TSparkParameter,
31+
)
32+
33+
2734
logger = logging.getLogger(__name__)
2835

2936
DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600
3037
DEFAULT_ARRAY_SIZE = 100000
3138

39+
NO_NATIVE_PARAMS: List = []
40+
3241

3342
class Connection:
3443
def __init__(
@@ -65,6 +74,12 @@ def __init__(
6574
:param schema: An optional initial schema to use. Requires DBR version 9.0+
6675
6776
Other Parameters:
77+
use_inline_params: `boolean`, optional (default is True)
78+
When True, parameterized calls to cursor.execute() will try to render parameter values inline with the
79+
query text instead of using native bound parameters supported in DBR 14.1 and above. This connector will attempt to
80+
sanitise parameterized inputs to prevent SQL injection. Before you can switch this to False, you must
81+
update your queries to use the PEP-249 `named` paramstyle instead of the `pyformat` paramstyle used
82+
in INLINE mode.
6883
auth_type: `str`, optional
6984
`databricks-oauth` : to use oauth with fine-grained permission scopes, set to `databricks-oauth`.
7085
This is currently in private preview for Databricks accounts on AWS.
@@ -207,6 +222,9 @@ def read(self) -> Optional[OAuthToken]:
207222
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
208223
self._cursors = [] # type: List[Cursor]
209224

225+
self._suppress_inline_warning = "use_inline_params" in kwargs
226+
self.use_inline_params = kwargs.get("use_inline_params", True)
227+
210228
def __enter__(self):
211229
return self
212230

@@ -358,6 +376,100 @@ def __iter__(self):
358376
else:
359377
raise Error("There is no active result set")
360378

379+
def _determine_parameter_approach(
    self, params: Optional[Union[List, Dict[str, Any]]] = None
) -> ParameterApproach:
    """Decide whether parameters are sent inline, natively, or not at all.

    :params:
        The parameters passed to ``execute()``; ``None`` means the query is
        not parameterized.

    Returns:
        ``ParameterApproach.NONE`` when ``params`` is None.
        ``ParameterApproach.INLINE`` when ``self.connection.use_inline_params`` is truthy.
        ``ParameterApproach.NATIVE`` when inline mode is off and the server
        reports support for parameterized queries.

    Raises:
        NotSupportedError: inline mode is off and the server does not support
        native parameters.

    A warning is logged when inline mode is used although the server supports
    the native approach, unless the connection asked for it to be suppressed
    via ``self.connection._suppress_inline_warning``.
    """

    if params is None:
        return ParameterApproach.NONE

    server_supports_native_approach = (
        self.connection.server_parameterized_queries_enabled(
            self.connection.protocol_version
        )
    )

    if self.connection.use_inline_params:
        if (
            server_supports_native_approach
            and not self.connection._suppress_inline_warning
        ):
            # Adjacent string literals are concatenated verbatim, so every
            # fragment but the last ends with a space. Without it the
            # sentences run together and the URL fuses with the next word.
            logger.warning(
                "This query will be executed with inline parameters. "
                "Consider using native parameters. "
                "Learn more: https://github.com/databricks/databricks-sql-python/tree/main/docs/parameters.md "
                "To suppress this warning, pass use_inline_params=True when creating the connection."
            )
        return ParameterApproach.INLINE

    if server_supports_native_approach:
        return ParameterApproach.NATIVE

    raise NotSupportedError(
        "Parameterized operations are not supported by this server. DBR 14.1 is required."
    )
422+
423+
def _prepare_inline_parameters(
    self, stmt: str, params: Optional[Union[List, Dict[str, Any]]]
) -> Tuple[str, List]:
    """Render parameter values into the query text for the INLINE approach.

    :stmt:
        A SQL query string whose parameter markers use the PEP-249 `pyformat`
        paramstyle, e.g. `%(param)s`.

    :params:
        The values to render inline. As a Dict, its keys must match the marker
        names in :stmt:. As a List, its length must equal the number of
        markers in :stmt:.

    Returns a tuple of:
        stmt: :stmt: after the escaped values have been injected in place of
            the markers
        params: the shared empty NO_NATIVE_PARAMS list, because nothing is
            bound natively under the inline approach
    """

    # Values are escaped first (self.escaper.escape_args) and only then
    # substituted into the statement text (inject_parameters).
    return (
        inject_parameters(stmt, self.escaper.escape_args(params)),
        NO_NATIVE_PARAMS,
    )
447+
448+
def _prepare_native_parameters(
    self, stmt: str, params: Optional[Union[List[Any], Dict[str, Any]]]
) -> Tuple[str, List[TSparkParameter]]:
    """Return a statement and the native parameters to pass to thrift_backend.

    :stmt:
        A SQL query string whose parameter markers use the PEP-249 `named`
        paramstyle, e.g. `:param`.

    :params:
        The values to send natively. As a Dict, its keys must match the marker
        names in :stmt:. As a List, its length must equal the number of
        markers in :stmt:. In list form, any member of the list can be wrapped
        in a DbsqlParameter class.

    Returns a tuple of:
        stmt: the passed statement, unchanged. Nothing is rendered into the
            query text under the native approach.
        params: the list produced by named_parameters_to_tsparkparams, to be
            passed in native mode
    """

    # A separate name keeps the raw caller input and the converted list apart
    # instead of rebinding `params` to a different type.
    native_params = named_parameters_to_tsparkparams(params)  # type: ignore

    return stmt, native_params
472+
361473
def _close_and_clear_active_result_set(self):
362474
try:
363475
if self.active_result_set:
@@ -515,40 +627,62 @@ def _handle_staging_remove(self, presigned_url: str, headers: dict = None):
515627
def execute(
516628
self,
517629
operation: str,
518-
parameters: Optional[Union[List[Any], Dict[str, str]]] = None,
630+
parameters: Optional[Union[List[Any], Dict[str, Any]]] = None,
519631
) -> "Cursor":
520632
"""
521633
Execute a query and wait for execution to complete.
522-
Parameters should be given in extended param format style: %(...)<s|d|f>.
523-
For example:
524-
operation = "SELECT * FROM table WHERE field = %(some_value)s"
525-
parameters = {"some_value": "foo"}
526-
Will result in the query "SELECT * FROM table WHERE field = 'foo' being sent to the server
634+
635+
The parameterisation behaviour of this method depends on which parameter approach is used:
636+
- With INLINE mode (default), parameters are rendered inline with the query text
637+
- With NATIVE mode, parameters are sent to the server separately for binding
638+
639+
This behaviour is controlled by the `use_inline_params` argument passed when building a connection.
640+
641+
The syntax for these approaches is different:
642+
643+
If the connection was instantiated with use_inline_params=False, then parameters
644+
should be given in PEP-249 `named` paramstyle like :param_name
645+
646+
If the connection was instantiated with use_inline_params=True (default), then parameters
647+
should be given in PEP-249 `pyformat` paramstyle like %(param_name)s
648+
649+
```python
650+
inline_operation = "SELECT * FROM table WHERE field = %(some_value)s"
651+
native_operation = "SELECT * FROM table WHERE field = :some_value"
652+
parameters = {"some_value": "foo"}
653+
```
654+
655+
Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo'"
656+
being sent to the server
657+
527658
:returns self
528659
"""
529-
if parameters is None:
530-
parameters = []
531660

532-
elif not Connection.server_parameterized_queries_enabled(
533-
self.connection.protocol_version
534-
):
535-
raise NotSupportedError(
536-
"Parameterized operations are not supported by this server. DBR 14.1 is required."
661+
param_approach = self._determine_parameter_approach(parameters)
662+
if param_approach == ParameterApproach.NONE:
663+
prepared_params = NO_NATIVE_PARAMS
664+
prepared_operation = operation
665+
666+
elif param_approach == ParameterApproach.INLINE:
667+
prepared_operation, prepared_params = self._prepare_inline_parameters(
668+
operation, parameters
669+
)
670+
elif param_approach == ParameterApproach.NATIVE:
671+
prepared_operation, prepared_params = self._prepare_native_parameters(
672+
operation, parameters
537673
)
538-
else:
539-
parameters = named_parameters_to_tsparkparams(parameters)
540674

541675
self._check_not_closed()
542676
self._close_and_clear_active_result_set()
543677
execute_response = self.thrift_backend.execute_command(
544-
operation=operation,
678+
operation=prepared_operation,
545679
session_handle=self.connection._session_handle,
546680
max_rows=self.arraysize,
547681
max_bytes=self.buffer_size_bytes,
548682
lz4_compression=self.connection.lz4_compression,
549683
cursor=self,
550684
use_cloud_fetch=self.connection.use_cloud_fetch,
551-
parameters=parameters,
685+
parameters=prepared_params,
552686
)
553687
self.active_result_set = ResultSet(
554688
self.connection,

src/databricks/sql/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
2626

2727

28+
class ParameterApproach(Enum):
    """The ways a parameterized query's values can reach the server."""

    # Values are rendered into the query text before it is sent.
    INLINE = 1
    # Values are sent to the server separately from the query text for binding.
    NATIVE = 2
    # The query was executed without any parameters.
    NONE = 3
32+
33+
2834
class ResultSetQueue(ABC):
2935
@abstractmethod
3036
def next_n_rows(self, num_rows: int) -> pyarrow.Table:
@@ -627,7 +633,9 @@ def calculate_decimal_cast_string(input: Decimal) -> str:
627633
return f"DECIMAL({overall},{after})"
628634

629635

630-
def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
636+
def named_parameters_to_tsparkparams(
637+
parameters: Union[List[Any], Dict[str, str]]
638+
) -> List[TSparkParameter]:
631639
tspark_params = []
632640
if isinstance(parameters, dict):
633641
dbsql_params = named_parameters_to_dbsqlparams_v1(parameters)

src/databricks/sqlalchemy/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class DatabricksDialect(default.DefaultDialect):
5959
non_native_boolean_check_constraint: bool = False
6060
supports_identity_columns: bool = True
6161
supports_schemas: bool = True
62-
paramstyle: str = "named"
62+
default_paramstyle: str = "named"
6363
div_is_floordiv: bool = False
6464
supports_default_values: bool = False
6565
supports_server_side_cursors: bool = False
@@ -85,6 +85,21 @@ class DatabricksDialect(default.DefaultDialect):
8585
def dbapi(cls):
8686
return sql
8787

88+
def _force_paramstyle_to_native_mode(self):
89+
"""This method can be removed after databricks-sql-connector wholly switches to NATIVE ParamApproach.
90+
91+
This is a hack to trick SQLAlchemy into using a different paramstyle
92+
than the one declared by this module in src/databricks/sql/__init__.py
93+
94+
This method is called _after_ the dialect has been initialised, which is important because otherwise
95+
our users would need to include a `paramstyle` argument in their SQLAlchemy connection string.
96+
97+
This dialect is written to support NATIVE queries. Although the INLINE approach can technically work,
98+
the same behaviour can be achieved within SQLAlchemy itself using its literal_processor methods.
99+
"""
100+
101+
self.paramstyle = self.default_paramstyle
102+
88103
def create_connect_args(self, url):
89104
# TODO: can schema be provided after HOST?
90105
# Expected URI format is: databricks+thrift://token:dapi***@***.cloud.databricks.com?http_path=/sql/***
@@ -95,11 +110,14 @@ def create_connect_args(self, url):
95110
"http_path": url.query.get("http_path"),
96111
"catalog": url.query.get("catalog"),
97112
"schema": url.query.get("schema"),
113+
"use_inline_params": False,
98114
}
99115

100116
self.schema = kwargs["schema"]
101117
self.catalog = kwargs["catalog"]
102118

119+
self._force_paramstyle_to_native_mode()
120+
103121
return [], kwargs
104122

105123
def get_columns(

src/databricks/sqlalchemy/test/test_suite.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
then are overridden by our local skip markers in _regression, _unsupported, and _future.
44
"""
55

6+
7+
def start_protocol_patch():
    """Patch Connection.server_parameterized_queries_enabled to return True.

    The patch is started and never stopped here, so it stays active after this
    function returns. See tests/test_parameterized_queries.py for more
    information about this patch.
    """
    from unittest.mock import patch

    patch(
        "databricks.sql.client.Connection.server_parameterized_queries_enabled",
        return_value=True,
    ).start()
16+
17+
18+
start_protocol_patch()
19+
620
# type: ignore
721
# fmt: off
822
from sqlalchemy.testing.suite import *

src/databricks/sqlalchemy/test_local/e2e/test_basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
except ImportError:
2323
from sqlalchemy.ext.declarative import declarative_base
2424

25+
from databricks.sqlalchemy.test.test_suite import start_protocol_patch
26+
27+
start_protocol_patch()
28+
2529

2630
USER_AGENT_TOKEN = "PySQL e2e Tests"
2731

0 commit comments

Comments
 (0)