Skip to content

Commit 0fd46d4

Browse files
committed
added TelemetryExtractor, removed multithreaded tests
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 307a8cc commit 0fd46d4

File tree

3 files changed

+105
-466
lines changed

3 files changed

+105
-466
lines changed

src/databricks/sql/client.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131
transform_paramstyle,
3232
ColumnTable,
3333
ColumnQueue,
34-
ArrowQueue,
35-
CloudFetchQueue,
3634
)
3735
from databricks.sql.parameters.native import (
3836
DbsqlParameterBase,
@@ -64,7 +62,6 @@
6462
HostDetails,
6563
)
6664
from databricks.sql.telemetry.latency_logger import log_latency
67-
from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
6865

6966
logger = logging.getLogger(__name__)
7067

@@ -1354,39 +1351,6 @@ def setoutputsize(self, size, column=None):
13541351
"""Does nothing by default"""
13551352
pass
13561353

1357-
def get_statement_id(self) -> Optional[str]:
1358-
return self.query_id
1359-
1360-
def get_session_id_hex(self) -> Optional[str]:
1361-
return self.connection.get_session_id_hex()
1362-
1363-
def get_is_compressed(self) -> bool:
1364-
return self.connection.lz4_compression
1365-
1366-
def get_execution_result(self) -> ExecutionResultFormat:
1367-
if self.active_result_set is None:
1368-
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1369-
1370-
if isinstance(self.active_result_set.results, ColumnQueue):
1371-
return ExecutionResultFormat.COLUMNAR_INLINE
1372-
elif isinstance(self.active_result_set.results, CloudFetchQueue):
1373-
return ExecutionResultFormat.EXTERNAL_LINKS
1374-
elif isinstance(self.active_result_set.results, ArrowQueue):
1375-
return ExecutionResultFormat.INLINE_ARROW
1376-
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1377-
1378-
def get_retry_count(self) -> int:
1379-
if (
1380-
hasattr(self.thrift_backend, "retry_policy")
1381-
and self.thrift_backend.retry_policy
1382-
):
1383-
return len(self.thrift_backend.retry_policy.history)
1384-
return 0
1385-
1386-
def get_statement_type(self, func_name: str) -> StatementType:
1387-
# TODO: Implement this
1388-
return StatementType.SQL
1389-
13901354

13911355
class ResultSet:
13921356
def __init__(
@@ -1687,34 +1651,3 @@ def map_col_type(type_):
16871651
for column in table_schema_message.columns
16881652
]
16891653

1690-
def get_statement_id(self) -> Optional[str]:
1691-
if self.command_id:
1692-
return str(UUID(bytes=self.command_id.operationId.guid))
1693-
return None
1694-
1695-
def get_session_id_hex(self) -> Optional[str]:
1696-
return self.connection.get_session_id_hex()
1697-
1698-
def get_is_compressed(self) -> bool:
1699-
return self.lz4_compressed
1700-
1701-
def get_execution_result(self) -> ExecutionResultFormat:
1702-
if isinstance(self.results, ColumnQueue):
1703-
return ExecutionResultFormat.COLUMNAR_INLINE
1704-
elif isinstance(self.results, CloudFetchQueue):
1705-
return ExecutionResultFormat.EXTERNAL_LINKS
1706-
elif isinstance(self.results, ArrowQueue):
1707-
return ExecutionResultFormat.INLINE_ARROW
1708-
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1709-
1710-
def get_statement_type(self, func_name: str) -> StatementType:
1711-
# TODO: Implement this
1712-
return StatementType.SQL
1713-
1714-
def get_retry_count(self) -> int:
1715-
if (
1716-
hasattr(self.thrift_backend, "retry_policy")
1717-
and self.thrift_backend.retry_policy
1718-
):
1719-
return len(self.thrift_backend.retry_policy.history)
1720-
return 0

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,106 @@
11
import time
22
import functools
3+
from typing import Optional
34
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
45
from databricks.sql.telemetry.models.event import (
56
SqlExecutionEvent,
67
)
8+
from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
9+
from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
10+
from uuid import UUID
11+
12+
13+
class TelemetryExtractor:
14+
def __init__(self, obj):
15+
self._obj = obj
16+
17+
def __getattr__(self, name):
18+
return getattr(self._obj, name)
19+
20+
def get_session_id_hex(self): pass
21+
def get_statement_id(self): pass
22+
def get_statement_type(self): pass
23+
def get_is_compressed(self): pass
24+
def get_execution_result(self): pass
25+
def get_retry_count(self): pass
26+
27+
28+
class CursorExtractor(TelemetryExtractor):
29+
def get_statement_id(self) -> Optional[str]:
30+
return self.query_id
31+
32+
def get_session_id_hex(self) -> Optional[str]:
33+
return self.connection.get_session_id_hex()
34+
35+
def get_is_compressed(self) -> bool:
36+
return self.connection.lz4_compression
37+
38+
def get_execution_result(self) -> ExecutionResultFormat:
39+
if self.active_result_set is None:
40+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
41+
42+
if isinstance(self.active_result_set.results, ColumnQueue):
43+
return ExecutionResultFormat.COLUMNAR_INLINE
44+
elif isinstance(self.active_result_set.results, CloudFetchQueue):
45+
return ExecutionResultFormat.EXTERNAL_LINKS
46+
elif isinstance(self.active_result_set.results, ArrowQueue):
47+
return ExecutionResultFormat.INLINE_ARROW
48+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
49+
50+
def get_retry_count(self) -> int:
51+
if (
52+
hasattr(self.thrift_backend, "retry_policy")
53+
and self.thrift_backend.retry_policy
54+
):
55+
return len(self.thrift_backend.retry_policy.history)
56+
return 0
57+
58+
def get_statement_type(self: str) -> StatementType:
59+
# TODO: Implement this
60+
return StatementType.SQL
61+
62+
63+
class ResultSetExtractor(TelemetryExtractor):
64+
def get_statement_id(self) -> Optional[str]:
65+
if self.command_id:
66+
return str(UUID(bytes=self.command_id.operationId.guid))
67+
return None
68+
69+
def get_session_id_hex(self) -> Optional[str]:
70+
return self.connection.get_session_id_hex()
71+
72+
def get_is_compressed(self) -> bool:
73+
return self.lz4_compressed
74+
75+
def get_execution_result(self) -> ExecutionResultFormat:
76+
if isinstance(self.results, ColumnQueue):
77+
return ExecutionResultFormat.COLUMNAR_INLINE
78+
elif isinstance(self.results, CloudFetchQueue):
79+
return ExecutionResultFormat.EXTERNAL_LINKS
80+
elif isinstance(self.results, ArrowQueue):
81+
return ExecutionResultFormat.INLINE_ARROW
82+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
83+
84+
def get_statement_type(self: str) -> StatementType:
85+
# TODO: Implement this
86+
return StatementType.SQL
87+
88+
def get_retry_count(self) -> int:
89+
if (
90+
hasattr(self.thrift_backend, "retry_policy")
91+
and self.thrift_backend.retry_policy
92+
):
93+
return len(self.thrift_backend.retry_policy.history)
94+
return 0
95+
96+
97+
def get_extractor(obj):
98+
if obj.__class__.__name__ == 'Cursor':
99+
return CursorExtractor(obj)
100+
elif obj.__class__.__name__ == 'ResultSet':
101+
return ResultSetExtractor(obj)
102+
else:
103+
return TelemetryExtractor(obj)
7104

8105

9106
def log_latency():
@@ -19,14 +116,15 @@ def wrapper(self, *args, **kwargs):
19116
end_time = time.perf_counter()
20117
duration_ms = int((end_time - start_time) * 1000)
21118

22-
session_id_hex = self.get_session_id_hex()
23-
statement_id = self.get_statement_id()
119+
extractor = get_extractor(self)
120+
session_id_hex = extractor.get_session_id_hex()
121+
statement_id = extractor.get_statement_id()
24122

25123
sql_exec_event = SqlExecutionEvent(
26-
statement_type=self.get_statement_type(func.__name__),
27-
is_compressed=self.get_is_compressed(),
28-
execution_result=self.get_execution_result(),
29-
retry_count=self.get_retry_count(),
124+
statement_type=extractor.get_statement_type(),
125+
is_compressed=extractor.get_is_compressed(),
126+
execution_result=extractor.get_execution_result(),
127+
retry_count=extractor.get_retry_count(),
30128
)
31129

32130
telemetry_client = TelemetryClientFactory.get_telemetry_client(

0 commit comments

Comments
 (0)