Skip to content

Commit ef4ca13

Browse files
committed
added test for export_latency_log, made mock of thrift backend with retry policy
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 14433c4 commit ef4ca13

File tree

5 files changed

+86
-18
lines changed

5 files changed

+86
-18
lines changed

src/databricks/sql/client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,6 @@ def columns(
11851185
)
11861186
return self
11871187

1188-
@log_latency()
11891188
def fetchall(self) -> List[Row]:
11901189
"""
11911190
Fetch all (remaining) rows of a query result, returning them as a sequence of sequences.
@@ -1219,7 +1218,6 @@ def fetchone(self) -> Optional[Row]:
12191218
session_id_hex=self.connection.get_session_id_hex(),
12201219
)
12211220

1222-
@log_latency()
12231221
def fetchmany(self, size: int) -> List[Row]:
12241222
"""
12251223
Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
@@ -1245,7 +1243,6 @@ def fetchmany(self, size: int) -> List[Row]:
12451243
session_id_hex=self.connection.get_session_id_hex(),
12461244
)
12471245

1248-
@log_latency()
12491246
def fetchall_arrow(self) -> "pyarrow.Table":
12501247
self._check_not_closed()
12511248
if self.active_result_set:
@@ -1256,7 +1253,6 @@ def fetchall_arrow(self) -> "pyarrow.Table":
12561253
session_id_hex=self.connection.get_session_id_hex(),
12571254
)
12581255

1259-
@log_latency()
12601256
def fetchmany_arrow(self, size) -> "pyarrow.Table":
12611257
self._check_not_closed()
12621258
if self.active_result_set:
@@ -1380,7 +1376,11 @@ def get_execution_result(self) -> ExecutionResultFormat:
13801376
return ExecutionResultFormat.FORMAT_UNSPECIFIED
13811377

13821378
def get_retry_count(self) -> int:
1383-
# return len(self.thrift_backend.retry_policy.history)
1379+
if (
1380+
hasattr(self.thrift_backend, "retry_policy")
1381+
and self.thrift_backend.retry_policy
1382+
):
1383+
return len(self.thrift_backend.retry_policy.history)
13841384
return 0
13851385

13861386
def get_statement_type(self, func_name: str) -> StatementType:
@@ -1712,5 +1712,9 @@ def get_statement_type(self, func_name: str) -> StatementType:
17121712
return StatementType.SQL
17131713

17141714
def get_retry_count(self) -> int:
1715-
# return len(self.thrift_backend.retry_policy.history)
1715+
if (
1716+
hasattr(self.thrift_backend, "retry_policy")
1717+
and self.thrift_backend.retry_policy
1718+
):
1719+
return len(self.thrift_backend.retry_policy.history)
17161720
return 0

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,7 @@ def export_failure_log(self, error_name, error_message):
113113
raise NotImplementedError("Subclasses must implement export_failure_log")
114114

115115
@abstractmethod
116-
def export_latency_log(
117-
self, latency_ms, sql_execution_event, sql_statement_id=None
118-
):
116+
def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
119117
raise NotImplementedError("Subclasses must implement export_latency_log")
120118

121119
@abstractmethod
@@ -310,9 +308,7 @@ def export_failure_log(self, error_name, error_message):
310308
except Exception as e:
311309
logger.debug("Failed to export failure log: %s", e)
312310

313-
def export_latency_log(
314-
self, latency_ms, sql_execution_event, sql_statement_id=None
315-
):
311+
def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
316312
logger.debug("Exporting latency log for connection %s", self._session_id_hex)
317313
try:
318314
telemetry_frontend_log = TelemetryFrontendLog(

tests/unit/test_client.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def new(cls):
3939
cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None)
4040
MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp())
4141

42+
# Mock retry_policy with history attribute
43+
mock_retry_policy = Mock()
44+
mock_retry_policy.history = []
45+
cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy)
46+
4247
cls.apply_property_to_mock(
4348
MockTExecuteStatementResp,
4449
description=None,
@@ -70,6 +75,15 @@ def apply_property_to_mock(self, mock_obj, **kwargs):
7075
prop = PropertyMock(**kwargs)
7176
setattr(type(mock_obj), key, prop)
7277

78+
@classmethod
79+
def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator
80+
"""Create a simple thrift_backend mock with retry_policy for basic tests."""
81+
mock_thrift_backend = Mock()
82+
mock_retry_policy = Mock()
83+
mock_retry_policy.history = []
84+
mock_thrift_backend.retry_policy = mock_retry_policy
85+
return mock_thrift_backend
86+
7387

7488
class ClientTestSuite(unittest.TestCase):
7589
"""
@@ -319,7 +333,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command(
319333
mock_result_sets[1].fetchall.assert_called_once_with()
320334

321335
def test_closed_cursor_doesnt_allow_operations(self):
322-
cursor = client.Cursor(Mock(), Mock())
336+
cursor = client.Cursor(Mock(), ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy())
323337
cursor.close()
324338

325339
with self.assertRaises(Error) as e:
@@ -399,7 +413,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe
399413
for req_args in req_args_combinations:
400414
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
401415
with self.subTest(req_args=req_args):
402-
mock_thrift_backend = Mock()
416+
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
403417

404418
cursor = client.Cursor(Mock(), mock_thrift_backend)
405419
cursor.schemas(**req_args)
@@ -422,7 +436,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen
422436
for req_args in req_args_combinations:
423437
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
424438
with self.subTest(req_args=req_args):
425-
mock_thrift_backend = Mock()
439+
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
426440

427441
cursor = client.Cursor(Mock(), mock_thrift_backend)
428442
cursor.tables(**req_args)
@@ -445,7 +459,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe
445459
for req_args in req_args_combinations:
446460
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
447461
with self.subTest(req_args=req_args):
448-
mock_thrift_backend = Mock()
462+
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
449463

450464
cursor = client.Cursor(Mock(), mock_thrift_backend)
451465
cursor.columns(**req_args)

tests/unit/test_fetches.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ def make_arrow_queue(batch):
3131
_, table = FetchTests.make_arrow_table(batch)
3232
queue = ArrowQueue(table, len(batch))
3333
return queue
34+
35+
@classmethod
36+
def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator
37+
"""Create a simple thrift_backend mock with retry_policy for basic tests."""
38+
mock_thrift_backend = Mock()
39+
mock_retry_policy = Mock()
40+
mock_retry_policy.history = []
41+
mock_thrift_backend.retry_policy = mock_retry_policy
42+
return mock_thrift_backend
3443

3544
@staticmethod
3645
def make_dummy_result_set_from_initial_results(initial_results):
@@ -39,7 +48,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
3948
arrow_queue = ArrowQueue(arrow_table, len(initial_results), 0)
4049
rs = client.ResultSet(
4150
connection=Mock(),
42-
thrift_backend=None,
51+
thrift_backend=FetchTests.mock_thrift_backend_with_retry_policy(),
4352
execute_response=ExecuteResponse(
4453
status=None,
4554
has_been_closed_server_side=True,
@@ -79,7 +88,7 @@ def fetch_results(
7988

8089
return results, batch_index < len(batch_list)
8190

82-
mock_thrift_backend = Mock()
91+
mock_thrift_backend = FetchTests.mock_thrift_backend_with_retry_policy()
8392
mock_thrift_backend.fetch_results = fetch_results
8493
num_cols = len(batch_list[0][0]) if batch_list and batch_list[0] else 0
8594

tests/unit/test_telemetry.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def test_export_failure_log(self, noop_telemetry_client):
9696
error_name="TestError", error_message="Test error message"
9797
)
9898

99+
def test_export_latency_log(self, noop_telemetry_client):
100+
"""Test that export_latency_log does nothing."""
101+
noop_telemetry_client.export_latency_log(
102+
latency_ms=100, sql_execution_event="EXECUTE_STATEMENT", sql_statement_id="test-id"
103+
)
104+
99105
def test_close(self, noop_telemetry_client):
100106
"""Test that close does nothing."""
101107
noop_telemetry_client.close()
@@ -181,6 +187,40 @@ def test_export_failure_log(
181187

182188
client._export_event.assert_called_once_with(mock_frontend_log.return_value)
183189

190+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog")
191+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.get_driver_system_configuration")
192+
@patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4")
193+
@patch("databricks.sql.telemetry.telemetry_client.time.time")
194+
def test_export_latency_log(
195+
self,
196+
mock_time,
197+
mock_uuid4,
198+
mock_get_driver_config,
199+
mock_frontend_log,
200+
telemetry_client_setup
201+
):
202+
"""Test exporting latency telemetry log."""
203+
mock_time.return_value = 3000
204+
mock_uuid4.return_value = "test-latency-uuid"
205+
mock_get_driver_config.return_value = "test-driver-config"
206+
mock_frontend_log.return_value = MagicMock()
207+
208+
client = telemetry_client_setup["client"]
209+
client._export_event = MagicMock()
210+
211+
client._driver_connection_params = "test-connection-params"
212+
client._user_agent = "test-user-agent"
213+
214+
latency_ms = 150
215+
sql_execution_event = "test-execution-event"
216+
sql_statement_id = "test-statement-id"
217+
218+
client.export_latency_log(latency_ms, sql_execution_event, sql_statement_id)
219+
220+
mock_frontend_log.assert_called_once()
221+
222+
client._export_event.assert_called_once_with(mock_frontend_log.return_value)
223+
184224
def test_export_event(self, telemetry_client_setup):
185225
"""Test exporting an event."""
186226
client = telemetry_client_setup["client"]
@@ -311,6 +351,11 @@ def test_telemetry_client_exception_handling(self, telemetry_client_setup):
311351
# Should not raise exception
312352
client.export_failure_log("TestError", "Test error message")
313353

354+
# Test export_latency_log with exception
355+
with patch.object(client, '_export_event', side_effect=Exception("Test error")):
356+
# Should not raise exception
357+
client.export_latency_log(100, "EXECUTE_STATEMENT", "test-statement-id")
358+
314359
# Test _send_telemetry with exception
315360
with patch.object(client._executor, 'submit', side_effect=Exception("Test error")):
316361
# Should not raise exception

0 commit comments

Comments
 (0)