1
1
import logging
2
- from typing import Dict , Tuple , List , Optional , Any
2
+ from typing import Dict , Tuple , List , Optional , Any , Type
3
3
4
4
from databricks .sql .thrift_api .TCLIService import ttypes
5
5
from databricks .sql .types import SSLOptions
@@ -62,6 +62,7 @@ def __init__(
62
62
useragent_header = "{}/{}" .format (USER_AGENT_NAME , __version__ )
63
63
64
64
base_headers = [("User-Agent" , useragent_header )]
65
+ all_headers = (http_headers or []) + base_headers
65
66
66
67
self ._ssl_options = SSLOptions (
67
68
# Double negation is generally a bad thing, but we have to keep backward compatibility
@@ -75,33 +76,49 @@ def __init__(
75
76
tls_client_cert_key_password = kwargs .get ("_tls_client_cert_key_password" ),
76
77
)
77
78
78
- # Determine which backend to use
79
+ self .backend = self ._create_backend (
80
+ server_hostname ,
81
+ http_path ,
82
+ all_headers ,
83
+ auth_provider ,
84
+ _use_arrow_native_complex_types ,
85
+ kwargs ,
86
+ )
87
+
88
+ self .protocol_version = None
89
+
90
+ def _create_backend (
91
+ self ,
92
+ server_hostname : str ,
93
+ http_path : str ,
94
+ all_headers : List [Tuple [str , str ]],
95
+ auth_provider ,
96
+ _use_arrow_native_complex_types : bool ,
97
+ kwargs : dict ,
98
+ ) -> DatabricksClient :
99
+ """Create and return the appropriate backend client."""
79
100
use_sea = kwargs .get ("use_sea" , False )
80
101
81
102
if use_sea :
82
- self .backend : DatabricksClient = SeaDatabricksClient (
83
- self .host ,
84
- self .port ,
85
- http_path ,
86
- (http_headers or []) + base_headers ,
87
- auth_provider ,
88
- ssl_options = self ._ssl_options ,
89
- _use_arrow_native_complex_types = _use_arrow_native_complex_types ,
90
- ** kwargs ,
91
- )
103
+ logger .debug ("Creating SEA backend client" )
104
+ databricks_client_class = SeaDatabricksClient
92
105
else :
93
- self .backend = ThriftDatabricksClient (
94
- self .host ,
95
- self .port ,
96
- http_path ,
97
- (http_headers or []) + base_headers ,
98
- auth_provider ,
99
- ssl_options = self ._ssl_options ,
100
- _use_arrow_native_complex_types = _use_arrow_native_complex_types ,
101
- ** kwargs ,
102
- )
103
-
104
- self .protocol_version = None
106
+ logger .debug ("Creating Thrift backend client" )
107
+ databricks_client_class = ThriftDatabricksClient
108
+
109
+ # Prepare common arguments
110
+ common_args = {
111
+ "server_hostname" : server_hostname ,
112
+ "port" : self .port ,
113
+ "http_path" : http_path ,
114
+ "http_headers" : all_headers ,
115
+ "auth_provider" : auth_provider ,
116
+ "ssl_options" : self ._ssl_options ,
117
+ "_use_arrow_native_complex_types" : _use_arrow_native_complex_types ,
118
+ ** kwargs ,
119
+ }
120
+
121
+ return databricks_client_class (** common_args )
105
122
106
123
def open (self ):
107
124
self ._session_id = self .backend .open_session (
0 commit comments