Skip to content

Commit 6b3436f

Browse files
generalise open session, fix session tests to consider positional args
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 160ba9f commit 6b3436f

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/databricks/sql/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def get_protocol_version(sessionId: SessionId):
9898
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
9999
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
100100
"""
101-
if sessionId.backend_type != BackendType.THRIFT:
101+
if session_id.backend_type != BackendType.THRIFT:
102102
return None
103-
session_handle = sessionId.to_thrift_handle()
103+
session_handle = session_id.to_thrift_handle()
104104
if (
105105
session_handle
106106
and hasattr(session_handle, "serverProtocolVersion")

tests/unit/test_session.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,18 @@ def test_socket_timeout_passthrough(self, mock_client_class):
146146
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
147147
def test_configuration_passthrough(self, mock_client_class):
148148
mock_session_config = Mock()
149-
149+
150150
# Create a mock SessionId that will be returned by open_session
151151
mock_session_id = SessionId(BackendType.THRIFT, b"\x22", b"\x33")
152152
mock_client_class.return_value.open_session.return_value = mock_session_id
153-
153+
154154
databricks.sql.connect(
155155
session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS
156156
)
157157

158-
# Check that open_session was called with the correct session_configuration
159-
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
160-
self.assertEqual(call_kwargs["session_configuration"], mock_session_config)
158+
# Check that open_session was called with the correct session_configuration as first positional argument
159+
call_args = mock_client_class.return_value.open_session.call_args[0]
160+
self.assertEqual(call_args[0], mock_session_config)
161161

162162
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
163163
def test_initial_namespace_passthrough(self, mock_client_class):
@@ -171,11 +171,11 @@ def test_initial_namespace_passthrough(self, mock_client_class):
171171
databricks.sql.connect(
172172
**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem
173173
)
174-
175-
# Check that open_session was called with the correct catalog and schema
176-
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
177-
self.assertEqual(call_kwargs["catalog"], mock_cat)
178-
self.assertEqual(call_kwargs["schema"], mock_schem)
174+
175+
# Check that open_session was called with the correct catalog and schema as positional arguments
176+
call_args = mock_client_class.return_value.open_session.call_args[0]
177+
self.assertEqual(call_args[1], mock_cat) # catalog is second positional argument
178+
self.assertEqual(call_args[2], mock_schem) # schema is third positional argument
179179

180180
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
181181
def test_finalizer_closes_abandoned_connection(self, mock_client_class):

0 commit comments

Comments
 (0)