1
1
import time
2
2
import functools
3
+ from typing import Optional
3
4
from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
4
5
from databricks .sql .telemetry .models .event import (
5
6
SqlExecutionEvent ,
6
7
)
8
+ from databricks .sql .telemetry .models .enums import ExecutionResultFormat , StatementType
9
+ from databricks .sql .utils import ColumnQueue , CloudFetchQueue , ArrowQueue
10
+ from uuid import UUID
11
+
12
+
13
+ class TelemetryExtractor :
14
+ def __init__ (self , obj ):
15
+ self ._obj = obj
16
+
17
+ def __getattr__ (self , name ):
18
+ return getattr (self ._obj , name )
19
+
20
+ def get_session_id_hex (self ): pass
21
+ def get_statement_id (self ): pass
22
+ def get_statement_type (self ): pass
23
+ def get_is_compressed (self ): pass
24
+ def get_execution_result (self ): pass
25
+ def get_retry_count (self ): pass
26
+
27
+
28
+ class CursorExtractor (TelemetryExtractor ):
29
+ def get_statement_id (self ) -> Optional [str ]:
30
+ return self .query_id
31
+
32
+ def get_session_id_hex (self ) -> Optional [str ]:
33
+ return self .connection .get_session_id_hex ()
34
+
35
+ def get_is_compressed (self ) -> bool :
36
+ return self .connection .lz4_compression
37
+
38
+ def get_execution_result (self ) -> ExecutionResultFormat :
39
+ if self .active_result_set is None :
40
+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
41
+
42
+ if isinstance (self .active_result_set .results , ColumnQueue ):
43
+ return ExecutionResultFormat .COLUMNAR_INLINE
44
+ elif isinstance (self .active_result_set .results , CloudFetchQueue ):
45
+ return ExecutionResultFormat .EXTERNAL_LINKS
46
+ elif isinstance (self .active_result_set .results , ArrowQueue ):
47
+ return ExecutionResultFormat .INLINE_ARROW
48
+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
49
+
50
+ def get_retry_count (self ) -> int :
51
+ if (
52
+ hasattr (self .thrift_backend , "retry_policy" )
53
+ and self .thrift_backend .retry_policy
54
+ ):
55
+ return len (self .thrift_backend .retry_policy .history )
56
+ return 0
57
+
58
+ def get_statement_type (self : str ) -> StatementType :
59
+ # TODO: Implement this
60
+ return StatementType .SQL
61
+
62
+
63
+ class ResultSetExtractor (TelemetryExtractor ):
64
+ def get_statement_id (self ) -> Optional [str ]:
65
+ if self .command_id :
66
+ return str (UUID (bytes = self .command_id .operationId .guid ))
67
+ return None
68
+
69
+ def get_session_id_hex (self ) -> Optional [str ]:
70
+ return self .connection .get_session_id_hex ()
71
+
72
+ def get_is_compressed (self ) -> bool :
73
+ return self .lz4_compressed
74
+
75
+ def get_execution_result (self ) -> ExecutionResultFormat :
76
+ if isinstance (self .results , ColumnQueue ):
77
+ return ExecutionResultFormat .COLUMNAR_INLINE
78
+ elif isinstance (self .results , CloudFetchQueue ):
79
+ return ExecutionResultFormat .EXTERNAL_LINKS
80
+ elif isinstance (self .results , ArrowQueue ):
81
+ return ExecutionResultFormat .INLINE_ARROW
82
+ return ExecutionResultFormat .FORMAT_UNSPECIFIED
83
+
84
+ def get_statement_type (self : str ) -> StatementType :
85
+ # TODO: Implement this
86
+ return StatementType .SQL
87
+
88
+ def get_retry_count (self ) -> int :
89
+ if (
90
+ hasattr (self .thrift_backend , "retry_policy" )
91
+ and self .thrift_backend .retry_policy
92
+ ):
93
+ return len (self .thrift_backend .retry_policy .history )
94
+ return 0
95
+
96
+
97
+ def get_extractor (obj ):
98
+ if obj .__class__ .__name__ == 'Cursor' :
99
+ return CursorExtractor (obj )
100
+ elif obj .__class__ .__name__ == 'ResultSet' :
101
+ return ResultSetExtractor (obj )
102
+ else :
103
+ return TelemetryExtractor (obj )
7
104
8
105
9
106
def log_latency ():
@@ -19,14 +116,15 @@ def wrapper(self, *args, **kwargs):
19
116
end_time = time .perf_counter ()
20
117
duration_ms = int ((end_time - start_time ) * 1000 )
21
118
22
- session_id_hex = self .get_session_id_hex ()
23
- statement_id = self .get_statement_id ()
119
+ extractor = get_extractor (self )
120
+ session_id_hex = extractor .get_session_id_hex ()
121
+ statement_id = extractor .get_statement_id ()
24
122
25
123
sql_exec_event = SqlExecutionEvent (
26
- statement_type = self .get_statement_type (func . __name__ ),
27
- is_compressed = self .get_is_compressed (),
28
- execution_result = self .get_execution_result (),
29
- retry_count = self .get_retry_count (),
124
+ statement_type = extractor .get_statement_type (),
125
+ is_compressed = extractor .get_is_compressed (),
126
+ execution_result = extractor .get_execution_result (),
127
+ retry_count = extractor .get_retry_count (),
30
128
)
31
129
32
130
telemetry_client = TelemetryClientFactory .get_telemetry_client (
0 commit comments