Skip to content

Commit 27295c2

Browse files
committed
statement type, unit test fix
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 4d56141 commit 27295c2

File tree

3 files changed

+34
-43
lines changed

3 files changed

+34
-43
lines changed

src/databricks/sql/client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
HostDetails,
6363
)
6464
from databricks.sql.telemetry.latency_logger import log_latency
65+
from databricks.sql.telemetry.models.enums import StatementType
6566

6667
logger = logging.getLogger(__name__)
6768

@@ -827,7 +828,7 @@ def _handle_staging_remove(
827828
session_id_hex=self.connection.get_session_id_hex(),
828829
)
829830

830-
@log_latency()
831+
@log_latency(StatementType.SQL)
831832
def execute(
832833
self,
833834
operation: str,
@@ -918,7 +919,7 @@ def execute(
918919

919920
return self
920921

921-
@log_latency()
922+
@log_latency(StatementType.SQL)
922923
def execute_async(
923924
self,
924925
operation: str,
@@ -1044,7 +1045,7 @@ def executemany(self, operation, seq_of_parameters):
10441045
self.execute(operation, parameters)
10451046
return self
10461047

1047-
@log_latency()
1048+
@log_latency(StatementType.METADATA)
10481049
def catalogs(self) -> "Cursor":
10491050
"""
10501051
Get all available catalogs.
@@ -1068,7 +1069,7 @@ def catalogs(self) -> "Cursor":
10681069
)
10691070
return self
10701071

1071-
@log_latency()
1072+
@log_latency(StatementType.METADATA)
10721073
def schemas(
10731074
self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
10741075
) -> "Cursor":
@@ -1097,7 +1098,7 @@ def schemas(
10971098
)
10981099
return self
10991100

1100-
@log_latency()
1101+
@log_latency(StatementType.METADATA)
11011102
def tables(
11021103
self,
11031104
catalog_name: Optional[str] = None,
@@ -1133,7 +1134,7 @@ def tables(
11331134
)
11341135
return self
11351136

1136-
@log_latency()
1137+
@log_latency(StatementType.METADATA)
11371138
def columns(
11381139
self,
11391140
catalog_name: Optional[str] = None,
@@ -1444,6 +1445,7 @@ def _convert_arrow_table(self, table):
14441445
def rownumber(self):
14451446
return self._next_row_index
14461447

1448+
@log_latency()
14471449
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
14481450
"""
14491451
Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1486,6 +1488,7 @@ def merge_columnar(self, result1, result2):
14861488
]
14871489
return ColumnTable(merged_result, result1.column_names)
14881490

1491+
@log_latency()
14891492
def fetchmany_columnar(self, size: int):
14901493
"""
14911494
Fetch the next set of rows of a query result, returning a Columnar Table.
@@ -1511,6 +1514,7 @@ def fetchmany_columnar(self, size: int):
15111514

15121515
return results
15131516

1517+
@log_latency()
15141518
def fetchall_arrow(self) -> "pyarrow.Table":
15151519
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
15161520
results = self.results.remaining_rows()
@@ -1537,6 +1541,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
15371541
return pyarrow.Table.from_pydict(data)
15381542
return results
15391543

1544+
@log_latency()
15401545
def fetchall_columnar(self):
15411546
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
15421547
results = self.results.remaining_rows()

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ def get_session_id_hex(self):
4545
def get_statement_id(self):
4646
pass
4747

48-
def get_statement_type(self):
49-
pass
50-
5148
def get_is_compressed(self):
5249
pass
5350

@@ -95,10 +92,6 @@ def get_retry_count(self) -> int:
9592
return len(self.thrift_backend.retry_policy.history)
9693
return 0
9794

98-
def get_statement_type(self) -> StatementType:
99-
# TODO: Implement this
100-
return StatementType.SQL
101-
10295

10396
class ResultSetExtractor(TelemetryExtractor):
10497
"""
@@ -128,10 +121,6 @@ def get_execution_result(self) -> ExecutionResultFormat:
128121
return ExecutionResultFormat.INLINE_ARROW
129122
return ExecutionResultFormat.FORMAT_UNSPECIFIED
130123

131-
def get_statement_type(self) -> StatementType:
132-
# TODO: Implement this
133-
return StatementType.SQL
134-
135124
def get_retry_count(self) -> int:
136125
if (
137126
hasattr(self.thrift_backend, "retry_policy")
@@ -166,7 +155,7 @@ def get_extractor(obj):
166155
raise NotImplementedError(f"No extractor found for {obj.__class__.__name__}")
167156

168157

169-
def log_latency():
158+
def log_latency(statement_type: StatementType = StatementType.NONE):
170159
"""
171160
Decorator for logging execution latency and telemetry information.
172161
@@ -180,17 +169,15 @@ def log_latency():
180169
- Creates a SqlExecutionEvent with execution details
181170
- Sends the telemetry data asynchronously via TelemetryClient
182171
172+
Args:
173+
statement_type (StatementType): The type of SQL statement being executed.
174+
183175
Usage:
184-
@log_latency()
176+
@log_latency(StatementType.SQL)
185177
def execute(self, query):
186178
# Method implementation
187179
pass
188180
189-
@log_latency()
190-
def fetchall(self):
191-
# Method implementation
192-
pass
193-
194181
Returns:
195182
function: A decorator that wraps methods to add latency logging.
196183
@@ -216,7 +203,7 @@ def wrapper(self, *args, **kwargs):
216203
statement_id = extractor.get_statement_id()
217204

218205
sql_exec_event = SqlExecutionEvent(
219-
statement_type=extractor.get_statement_type(),
206+
statement_type=statement_type,
220207
is_compressed=extractor.get_is_compressed(),
221208
execution_result=extractor.get_execution_result(),
222209
retry_count=extractor.get_retry_count(),

tests/unit/test_client.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@
88
from datetime import datetime, date
99
from uuid import UUID
1010

11+
def noop_log_latency_decorator(*args, **kwargs):
12+
"""
13+
This is a no-op decorator. It is used to patch the log_latency decorator
14+
during tests, so that the tests for client logic are not affected by the
15+
telemetry logging logic. It accepts any arguments and returns a decorator
16+
that returns the original function unmodified.
17+
"""
18+
def decorator(func):
19+
return func
20+
return decorator
21+
22+
patch('databricks.sql.telemetry.latency_logger.log_latency', new=noop_log_latency_decorator).start()
23+
1124
from databricks.sql.thrift_api.TCLIService.ttypes import (
1225
TOpenSessionResp,
1326
TExecuteStatementResp,
@@ -38,11 +51,6 @@ def new(cls):
3851
cls.apply_property_to_mock(ThriftBackendMock, staging_allowed_local_path=None)
3952
MockTExecuteStatementResp = MagicMock(spec=TExecuteStatementResp())
4053

41-
# Mock retry_policy with history attribute
42-
mock_retry_policy = Mock()
43-
mock_retry_policy.history = []
44-
cls.apply_property_to_mock(ThriftBackendMock, retry_policy=mock_retry_policy)
45-
4654
cls.apply_property_to_mock(
4755
MockTExecuteStatementResp,
4856
description=None,
@@ -74,15 +82,6 @@ def apply_property_to_mock(self, mock_obj, **kwargs):
7482
prop = PropertyMock(**kwargs)
7583
setattr(type(mock_obj), key, prop)
7684

77-
@classmethod
78-
def mock_thrift_backend_with_retry_policy(cls): # Required for log_latency() decorator
79-
"""Create a simple thrift_backend mock with retry_policy for basic tests."""
80-
mock_thrift_backend = Mock()
81-
mock_retry_policy = Mock()
82-
mock_retry_policy.history = []
83-
mock_thrift_backend.retry_policy = mock_retry_policy
84-
return mock_thrift_backend
85-
8685

8786
class ClientTestSuite(unittest.TestCase):
8887
"""
@@ -332,7 +331,7 @@ def test_executing_multiple_commands_uses_the_most_recent_command(
332331
mock_result_sets[1].fetchall.assert_called_once_with()
333332

334333
def test_closed_cursor_doesnt_allow_operations(self):
335-
cursor = client.Cursor(Mock(), ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy())
334+
cursor = client.Cursor(Mock(), Mock())
336335
cursor.close()
337336

338337
with self.assertRaises(Error) as e:
@@ -394,7 +393,7 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe
394393
for req_args in req_args_combinations:
395394
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
396395
with self.subTest(req_args=req_args):
397-
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
396+
mock_thrift_backend = Mock()
398397

399398
cursor = client.Cursor(Mock(), mock_thrift_backend)
400399
cursor.schemas(**req_args)
@@ -417,7 +416,7 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen
417416
for req_args in req_args_combinations:
418417
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
419418
with self.subTest(req_args=req_args):
420-
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
419+
mock_thrift_backend = Mock()
421420

422421
cursor = client.Cursor(Mock(), mock_thrift_backend)
423422
cursor.tables(**req_args)
@@ -440,7 +439,7 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe
440439
for req_args in req_args_combinations:
441440
req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"}
442441
with self.subTest(req_args=req_args):
443-
mock_thrift_backend = ThriftBackendMockFactory.mock_thrift_backend_with_retry_policy()
442+
mock_thrift_backend = Mock()
444443

445444
cursor = client.Cursor(Mock(), mock_thrift_backend)
446445
cursor.columns(**req_args)

0 commit comments

Comments
 (0)