Skip to content

Commit 0ea69d9

Browse files
committed
formatting
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent facd588 commit 0ea69d9

File tree

2 files changed

+63
-66
lines changed

2 files changed

+63
-66
lines changed

src/databricks/sql/client.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
TSparkParameter,
5050
TOperationState,
5151
)
52-
from databricks.sql.telemetry.telemetry_client import telemetry_manager
52+
from databricks.sql.telemetry.telemetry_client import telemetry_client_factory
5353

5454

5555
logger = logging.getLogger(__name__)
@@ -297,29 +297,26 @@ def read(self) -> Optional[OAuthToken]:
297297
self.use_inline_params = self._set_use_inline_params_with_warning(
298298
kwargs.get("use_inline_params", False)
299299
)
300-
300+
301301
telemetry_kwargs = {
302302
"auth_provider": auth_provider,
303-
"is_authenticated": True, # TODO: Add authentication logic later
303+
"is_authenticated": True, # TODO: Add authentication logic later
304304
"user_agent": useragent_header,
305-
"host_url": server_hostname
305+
"host_url": server_hostname,
306306
}
307-
telemetry_manager.initialize_telemetry_client(
307+
self.telemetry_client = telemetry_client_factory.get_telemetry_client(
308308
telemetry_enabled=self.telemetry_enabled,
309309
batch_size=telemetry_batch_size,
310310
connection_uuid=self.get_session_id_hex(),
311-
**telemetry_kwargs
311+
**telemetry_kwargs,
312312
)
313313

314314
intial_telmetry_kwargs = {
315315
"http_path": http_path,
316316
"port": self.port,
317317
"socket_timeout": kwargs.get("_socket_timeout", None),
318318
}
319-
telemetry_manager.export_initial_telemetry_log(
320-
connection_uuid=self.get_session_id_hex(),
321-
**intial_telmetry_kwargs
322-
)
319+
self.telemetry_client.export_initial_telemetry_log(**intial_telmetry_kwargs)
323320

324321
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
325322
"""Valid values are True, False, and "silent"
@@ -457,7 +454,7 @@ def _close(self, close_cursors=True) -> None:
457454

458455
self.open = False
459456

460-
telemetry_manager.close_telemetry_client(self.get_session_id_hex())
457+
self.telemetry_client.close()
461458

462459
def commit(self):
463460
"""No-op because Databricks does not support transactions"""

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,36 @@
2020
import platform
2121
import uuid
2222
import locale
23+
from abc import ABC, abstractmethod
2324

2425

25-
class TelemetryClient:
26-
def __init__(
27-
self,
28-
telemetry_enabled,
29-
batch_size,
30-
connection_uuid,
31-
**kwargs
32-
):
26+
class BaseTelemetryClient(ABC):
27+
@abstractmethod
28+
def export_initial_telemetry_log(self, **kwargs):
29+
pass
30+
31+
@abstractmethod
32+
def close(self):
33+
pass
34+
35+
36+
class NoopTelemetryClient(BaseTelemetryClient):
37+
_instance = None
38+
39+
def __new__(cls):
40+
if cls._instance is None:
41+
cls._instance = super(NoopTelemetryClient, cls).__new__(cls)
42+
return cls._instance
43+
44+
def export_initial_telemetry_log(self, **kwargs):
45+
pass
46+
47+
def close(self):
48+
pass
49+
50+
51+
class TelemetryClient(BaseTelemetryClient):
52+
def __init__(self, telemetry_enabled, batch_size, connection_uuid, **kwargs):
3353
self.telemetry_enabled = telemetry_enabled
3454
self.batch_size = batch_size
3555
self.connection_uuid = connection_uuid
@@ -55,11 +75,12 @@ def flush(self):
5575
self.events_batch = []
5676

5777
if events_to_flush:
58-
telemetry_manager._send_telemetry(events_to_flush, self.host_url, self.is_authenticated, self.auth_provider)
59-
60-
def close(self):
61-
"""Flush remaining events before closing"""
62-
self.flush()
78+
telemetry_client_factory._send_telemetry(
79+
events_to_flush,
80+
self.host_url,
81+
self.is_authenticated,
82+
self.auth_provider,
83+
)
6384

6485
def export_initial_telemetry_log(self, **kwargs):
6586
http_path = kwargs.get("http_path", None)
@@ -94,35 +115,28 @@ def export_initial_telemetry_log(self, **kwargs):
94115
entry=FrontendLogEntry(
95116
sql_driver_log=TelemetryEvent(
96117
session_id=self.connection_uuid,
97-
system_configuration=TelemetryManager.getDriverSystemConfiguration(),
118+
system_configuration=telemetry_client_factory.getDriverSystemConfiguration(),
98119
driver_connection_params=self.DriverConnectionParameters,
99120
)
100121
),
101122
)
102123

103124
self.export_event(telemetry_frontend_log)
104125

126+
def close(self):
127+
"""Flush remaining events before closing"""
128+
self.flush()
129+
telemetry_client_factory.close(self.connection_uuid)
105130

106-
class TelemetryManager:
107-
"""A singleton manager class that handles telemetry operations for SQL connections.
108-
109-
This class maintains a map of connection_uuid to TelemetryClient instances. The initialize()
110-
method is only called from the connection class when telemetry is enabled for that connection.
111-
All telemetry operations (initial logs, failure logs, latency logs) first check if the
112-
connection_uuid exists in the map. If it doesn't exist (meaning telemetry was not enabled
113-
for that connection), the operation is skipped. If it exists, the operation is delegated
114-
to the corresponding TelemetryClient instance.
115131

116-
This design ensures that telemetry operations are only performed for connections where
117-
telemetry was explicitly enabled during initialization.
118-
"""
132+
class TelemetryClientFactory:
119133

120134
_instance = None
121135
_DRIVER_SYSTEM_CONFIGURATION = None
122136

123137
def __new__(cls):
124138
if cls._instance is None:
125-
cls._instance = super(TelemetryManager, cls).__new__(cls)
139+
cls._instance = super(TelemetryClientFactory, cls).__new__(cls)
126140
cls._instance._initialized = False
127141
return cls._instance
128142

@@ -131,15 +145,13 @@ def __init__(self):
131145
return
132146

133147
self._clients = {} # Map of connection_uuid -> TelemetryClient
134-
self.executor = ThreadPoolExecutor(max_workers=10) # Thread pool for async operations TODO: Decide on max workers
148+
self.executor = ThreadPoolExecutor(
149+
max_workers=10
150+
) # Thread pool for async operations TODO: Decide on max workers
135151
self._initialized = True
136152

137-
def initialize_telemetry_client(
138-
self,
139-
telemetry_enabled,
140-
batch_size,
141-
connection_uuid,
142-
**kwargs
153+
def get_telemetry_client(
154+
self, telemetry_enabled, batch_size, connection_uuid, **kwargs
143155
):
144156
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
145157
if telemetry_enabled:
@@ -148,8 +160,11 @@ def initialize_telemetry_client(
148160
telemetry_enabled=telemetry_enabled,
149161
batch_size=batch_size,
150162
connection_uuid=connection_uuid,
151-
**kwargs
163+
**kwargs,
152164
)
165+
return self._clients[connection_uuid]
166+
else:
167+
return NoopTelemetryClient()
153168

154169
def _send_telemetry(self, events, host_url, is_authenticated, auth_provider):
155170
"""Send telemetry events to the server"""
@@ -168,20 +183,9 @@ def _send_telemetry(self, events, host_url, is_authenticated, auth_provider):
168183
auth_provider.add_headers(headers)
169184

170185
self.executor.submit(
171-
requests.post,
172-
url,
173-
data=json.dumps(request),
174-
headers=headers,
175-
timeout=10
186+
requests.post, url, data=json.dumps(request), headers=headers, timeout=10
176187
)
177188

178-
def export_initial_telemetry_log(
179-
self, connection_uuid, **kwargs
180-
):
181-
"""Export initial telemetry for a specific connection"""
182-
if connection_uuid in self._clients:
183-
self._clients[connection_uuid].export_initial_telemetry_log(**kwargs)
184-
185189
@classmethod
186190
def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration:
187191
if cls._DRIVER_SYSTEM_CONFIGURATION is None:
@@ -202,17 +206,13 @@ def getDriverSystemConfiguration(cls) -> DriverSystemConfiguration:
202206
)
203207
return cls._DRIVER_SYSTEM_CONFIGURATION
204208

205-
def close_telemetry_client(self, connection_uuid):
206-
"""Close telemetry client"""
207-
if connection_uuid:
208-
if connection_uuid in self._clients:
209-
self._clients[connection_uuid].close()
210-
del self._clients[connection_uuid]
211-
209+
def close(self, connection_uuid):
210+
del self._clients[connection_uuid]
211+
212212
# Shutdown executor if no more clients
213213
if not self._clients:
214214
self.executor.shutdown(wait=True)
215215

216216

217217
# Create a global instance
218-
telemetry_manager = TelemetryManager()
218+
telemetry_client_factory = TelemetryClientFactory()

0 commit comments

Comments
 (0)