Skip to content

Commit 206eb0a

Browse files
use factory for backend instantiation
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent ed14ef7 commit 206eb0a

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

src/databricks/sql/session.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Dict, Tuple, List, Optional, Any
2+
from typing import Dict, Tuple, List, Optional, Any, Type
33

44
from databricks.sql.thrift_api.TCLIService import ttypes
55
from databricks.sql.types import SSLOptions
@@ -62,6 +62,7 @@ def __init__(
6262
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
6363

6464
base_headers = [("User-Agent", useragent_header)]
65+
all_headers = (http_headers or []) + base_headers
6566

6667
self._ssl_options = SSLOptions(
6768
# Double negation is generally a bad thing, but we have to keep backward compatibility
@@ -75,33 +76,49 @@ def __init__(
7576
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
7677
)
7778

78-
# Determine which backend to use
79+
self.backend = self._create_backend(
80+
server_hostname,
81+
http_path,
82+
all_headers,
83+
auth_provider,
84+
_use_arrow_native_complex_types,
85+
kwargs,
86+
)
87+
88+
self.protocol_version = None
89+
90+
def _create_backend(
91+
self,
92+
server_hostname: str,
93+
http_path: str,
94+
all_headers: List[Tuple[str, str]],
95+
auth_provider,
96+
_use_arrow_native_complex_types: bool,
97+
kwargs: dict,
98+
) -> DatabricksClient:
99+
"""Create and return the appropriate backend client."""
79100
use_sea = kwargs.get("use_sea", False)
80101

81102
if use_sea:
82-
self.backend: DatabricksClient = SeaDatabricksClient(
83-
self.host,
84-
self.port,
85-
http_path,
86-
(http_headers or []) + base_headers,
87-
auth_provider,
88-
ssl_options=self._ssl_options,
89-
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
90-
**kwargs,
91-
)
103+
logger.debug("Creating SEA backend client")
104+
databricks_client_class = SeaDatabricksClient
92105
else:
93-
self.backend = ThriftDatabricksClient(
94-
self.host,
95-
self.port,
96-
http_path,
97-
(http_headers or []) + base_headers,
98-
auth_provider,
99-
ssl_options=self._ssl_options,
100-
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
101-
**kwargs,
102-
)
103-
104-
self.protocol_version = None
106+
logger.debug("Creating Thrift backend client")
107+
databricks_client_class = ThriftDatabricksClient
108+
109+
# Prepare common arguments
110+
common_args = {
111+
"server_hostname": server_hostname,
112+
"port": self.port,
113+
"http_path": http_path,
114+
"http_headers": all_headers,
115+
"auth_provider": auth_provider,
116+
"ssl_options": self._ssl_options,
117+
"_use_arrow_native_complex_types": _use_arrow_native_complex_types,
118+
**kwargs,
119+
}
120+
121+
return databricks_client_class(**common_args)
105122

106123
def open(self):
107124
self._session_id = self.backend.open_session(

tests/unit/test_session.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ def test_auth_args(self, mock_client_class):
6262

6363
for args in connection_args:
6464
connection = databricks.sql.connect(**args)
65-
host, port, http_path, *_ = mock_client_class.call_args[0]
66-
self.assertEqual(args["server_hostname"], host)
67-
self.assertEqual(args["http_path"], http_path)
65+
call_kwargs = mock_client_class.call_args[1]
66+
self.assertEqual(args["server_hostname"], call_kwargs["server_hostname"])
67+
self.assertEqual(args["http_path"], call_kwargs["http_path"])
6868
connection.close()
6969

7070
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7171
def test_http_header_passthrough(self, mock_client_class):
7272
http_headers = [("foo", "bar")]
7373
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
7474

75-
call_args = mock_client_class.call_args[0][3]
76-
self.assertIn(("foo", "bar"), call_args)
75+
call_kwargs = mock_client_class.call_args[1]
76+
self.assertIn(("foo", "bar"), call_kwargs["http_headers"])
7777

7878
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7979
def test_tls_arg_passthrough(self, mock_client_class):
@@ -95,7 +95,8 @@ def test_tls_arg_passthrough(self, mock_client_class):
9595
def test_useragent_header(self, mock_client_class):
9696
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
9797

98-
http_headers = mock_client_class.call_args[0][3]
98+
call_kwargs = mock_client_class.call_args[1]
99+
http_headers = call_kwargs["http_headers"]
99100
user_agent_header = (
100101
"User-Agent",
101102
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
@@ -109,7 +110,8 @@ def test_useragent_header(self, mock_client_class):
109110
databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
110111
),
111112
)
112-
http_headers = mock_client_class.call_args[0][3]
113+
call_kwargs = mock_client_class.call_args[1]
114+
http_headers = call_kwargs["http_headers"]
113115
self.assertIn(user_agent_header_with_entry, http_headers)
114116

115117
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)

0 commit comments

Comments
 (0)