13
13
SessionAlreadyClosedError ,
14
14
CursorAlreadyClosedError ,
15
15
)
16
+ from databricks .sql .thrift_api .TCLIService import ttypes
16
17
from databricks .sql .thrift_backend import ThriftBackend
17
18
from databricks .sql .utils import (
18
19
ExecuteResponse ,
@@ -196,9 +197,11 @@ def read(self) -> Optional[OAuthToken]:
196
197
** kwargs ,
197
198
)
198
199
199
- self ._session_handle = self .thrift_backend .open_session (
200
+ self ._open_session_resp = self .thrift_backend .open_session (
200
201
session_configuration , catalog , schema
201
202
)
203
+ self ._session_handle = self ._open_session_resp .sessionHandle
204
+ self .protocol_version = self .get_protocol_version (self ._open_session_resp )
202
205
self .use_cloud_fetch = kwargs .get ("use_cloud_fetch" , True )
203
206
self .open = True
204
207
logger .info ("Successfully opened session " + str (self .get_session_id_hex ()))
@@ -225,6 +228,30 @@ def __del__(self):
225
228
def get_session_id (self ):
226
229
return self .thrift_backend .handle_to_id (self ._session_handle )
227
230
231
+ @staticmethod
232
+ def get_protocol_version (openSessionResp ):
233
+ """
234
+ Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
235
+ precedence over the serverProtocolVersion defined in the OpenSessionResponse.
236
+ """
237
+ if (
238
+ openSessionResp .sessionHandle
239
+ and hasattr (openSessionResp .sessionHandle , "serverProtocolVersion" )
240
+ and openSessionResp .sessionHandle .serverProtocolVersion
241
+ ):
242
+ return openSessionResp .sessionHandle .serverProtocolVersion
243
+ return openSessionResp .serverProtocolVersion
244
+
245
+ @staticmethod
246
+ def server_parameterized_queries_enabled (protocolVersion ):
247
+ if (
248
+ protocolVersion
249
+ and protocolVersion >= ttypes .TProtocolVersion .SPARK_CLI_SERVICE_PROTOCOL_V8
250
+ ):
251
+ return True
252
+ else :
253
+ return False
254
+
228
255
def get_session_id_hex (self ):
229
256
return self .thrift_backend .handle_to_hex_id (self ._session_handle )
230
257
@@ -501,6 +528,13 @@ def execute(
501
528
"""
502
529
if parameters is None :
503
530
parameters = []
531
+
532
+ elif not Connection .server_parameterized_queries_enabled (
533
+ self .connection .protocol_version
534
+ ):
535
+ raise NotSupportedError (
536
+ "Parameterized operations are not supported by this server. DBR 14.1 is required."
537
+ )
504
538
else :
505
539
parameters = named_parameters_to_tsparkparams (parameters )
506
540
0 commit comments