Skip to content

Commit 9280fc2

Browse files
remove SeaHttpClient and integrate with THttpClient
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent bc467d1 commit 9280fc2

File tree

6 files changed

+330
-243
lines changed

6 files changed

+330
-243
lines changed

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 152 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import base64
2+
import json
23
import logging
34
import urllib.parse
4-
from typing import Dict, Union, Optional
5+
from typing import Dict, Union, Optional, Any
56

67
import six
7-
import thrift
8+
import thrift.transport.THttpClient
89

910
import ssl
1011
import warnings
1112
from http.client import HTTPResponse
1213
from io import BytesIO
1314

15+
import urllib3
1416
from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager
1517
from urllib3.util import make_headers
1618
from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy
@@ -222,3 +224,151 @@ def set_retry_command_type(self, value: CommandType):
222224
logger.warning(
223225
"DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set."
224226
)
227+
228+
def make_rest_request(
229+
self,
230+
method: str,
231+
endpoint_path: str,
232+
data: Optional[Dict[str, Any]] = None,
233+
params: Optional[Dict[str, Any]] = None,
234+
headers: Optional[Dict[str, str]] = None,
235+
) -> Dict[str, Any]:
236+
"""
237+
Make a REST API request using the existing connection pool.
238+
239+
Args:
240+
method (str): HTTP method (GET, POST, DELETE, etc.)
241+
endpoint_path (str): API endpoint path (e.g., "sessions" or "statements/123")
242+
data (dict, optional): Request payload data
243+
params (dict, optional): Query parameters
244+
headers (dict, optional): Additional headers
245+
246+
Returns:
247+
dict: Response data parsed from JSON
248+
249+
Raises:
250+
RequestError: If the request fails
251+
"""
252+
# Ensure the transport is open
253+
if not self.isOpen():
254+
self.open()
255+
256+
# Prepare headers
257+
request_headers = {
258+
"Content-Type": "application/json",
259+
}
260+
261+
# Add authentication headers
262+
auth_headers: Dict[str, str] = {}
263+
self.__auth_provider.add_headers(auth_headers)
264+
request_headers.update(auth_headers)
265+
266+
# Add custom headers if provided
267+
if headers:
268+
request_headers.update(headers)
269+
270+
# Prepare request body
271+
body = json.dumps(data).encode("utf-8") if data else None
272+
273+
# Build query string for params
274+
query_string = ""
275+
if params:
276+
query_string = "?" + urllib.parse.urlencode(params)
277+
278+
# Determine full path
279+
full_path = (
280+
self.path.rstrip("/") + "/" + endpoint_path.lstrip("/") + query_string
281+
)
282+
283+
# Log request details (debug level)
284+
logger.debug(f"Making {method} request to {full_path}")
285+
286+
try:
287+
# Make request using the connection pool
288+
self.__resp = self.__pool.request(
289+
method,
290+
url=full_path,
291+
body=body,
292+
headers=request_headers,
293+
preload_content=False,
294+
timeout=self.__timeout,
295+
retries=self.retry_policy,
296+
)
297+
298+
# Store response status and headers
299+
if self.__resp is not None:
300+
self.code = self.__resp.status
301+
self.message = self.__resp.reason
302+
self.headers = self.__resp.headers
303+
304+
# Log response status
305+
logger.debug(f"Response status: {self.code}, message: {self.message}")
306+
307+
# Read and parse response data
308+
# Note: urllib3's HTTPResponse has a data attribute, but it's not in the type stubs
309+
response_data = getattr(self.__resp, "data", None)
310+
311+
# Check for HTTP errors
312+
self._check_rest_response_for_error(self.code, response_data)
313+
314+
# Parse JSON response if there is content
315+
if response_data:
316+
result = json.loads(response_data.decode("utf-8"))
317+
318+
# Log response content (truncated for large responses)
319+
content_str = json.dumps(result)
320+
if len(content_str) > 1000:
321+
logger.debug(
322+
f"Response content (truncated): {content_str[:1000]}..."
323+
)
324+
else:
325+
logger.debug(f"Response content: {content_str}")
326+
327+
return result
328+
329+
return {}
330+
else:
331+
raise ValueError("No response received from server")
332+
333+
except urllib3.exceptions.HTTPError as e:
334+
error_message = f"REST HTTP request failed: {str(e)}"
335+
logger.error(error_message)
336+
from databricks.sql.exc import RequestError
337+
338+
raise RequestError(error_message, e)
339+
340+
def _check_rest_response_for_error(
341+
self, status_code: int, response_data: Optional[bytes]
342+
) -> None:
343+
"""
344+
Check if the REST response indicates an error and raise an appropriate exception.
345+
346+
Args:
347+
status_code: HTTP status code
348+
response_data: Raw response data
349+
350+
Raises:
351+
RequestError: If the response indicates an error
352+
"""
353+
if status_code >= 400:
354+
error_message = f"REST HTTP request failed with status {status_code}"
355+
356+
# Try to extract error details from JSON response
357+
if response_data:
358+
try:
359+
error_details = json.loads(response_data.decode("utf-8"))
360+
if isinstance(error_details, dict) and "message" in error_details:
361+
error_message = f"{error_message}: {error_details['message']}"
362+
logger.error(
363+
f"Request failed (status {status_code}): {error_details}"
364+
)
365+
except (ValueError, KeyError):
366+
# If we can't parse JSON, log raw content
367+
content = response_data.decode("utf-8", errors="replace")
368+
logger.error(f"Request failed (status {status_code}): {content}")
369+
else:
370+
logger.error(f"Request failed (status {status_code}): No response data")
371+
372+
from databricks.sql.exc import RequestError
373+
374+
raise RequestError(error_message)

src/databricks/sql/backend/sea/backend.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
ExecuteResponse,
2727
)
2828
from databricks.sql.exc import ServerOperationError
29-
from databricks.sql.backend.sea.utils.http_client import SeaHttpClient
29+
from databricks.sql.auth.thrift_http_client import THttpClient
30+
from databricks.sql.backend.sea.utils.http_client_adapter import SeaHttpClientAdapter
3031
from databricks.sql.thrift_api.TCLIService import ttypes
3132
from databricks.sql.types import SSLOptions
3233

@@ -129,17 +130,23 @@ def __init__(
129130
# Extract warehouse ID from http_path
130131
self.warehouse_id = self._extract_warehouse_id(http_path)
131132

132-
# Initialize HTTP client
133-
self.http_client = SeaHttpClient(
134-
server_hostname=server_hostname,
135-
port=port,
136-
http_path=http_path,
137-
http_headers=http_headers,
133+
# Initialize ThriftHttpClient
134+
thrift_client = THttpClient(
138135
auth_provider=auth_provider,
136+
uri_or_host=f"https://{server_hostname}:{port}",
137+
path=http_path,
139138
ssl_options=ssl_options,
140-
**kwargs,
139+
max_connections=kwargs.get("max_connections", 1),
140+
retry_policy=kwargs.get("_retry_stop_after_attempts_count", 30),
141141
)
142142

143+
# Set custom headers
144+
custom_headers = dict(http_headers)
145+
thrift_client.setCustomHeaders(custom_headers)
146+
147+
# Initialize HTTP client adapter
148+
self.http_client = SeaHttpClientAdapter(thrift_client=thrift_client)
149+
143150
def _extract_warehouse_id(self, http_path: str) -> str:
144151
"""
145152
Extract the warehouse ID from the HTTP path.
@@ -222,8 +229,8 @@ def open_session(
222229
schema=schema,
223230
)
224231

225-
response = self.http_client._make_request(
226-
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
232+
response = self.http_client.post(
233+
path=self.SESSION_PATH, data=request_data.to_dict()
227234
)
228235

229236
session_response = CreateSessionResponse.from_dict(response)
@@ -262,8 +269,7 @@ def close_session(self, session_id: SessionId) -> None:
262269
session_id=sea_session_id,
263270
)
264271

265-
self.http_client._make_request(
266-
method="DELETE",
272+
self.http_client.delete(
267273
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
268274
data=request_data.to_dict(),
269275
)
@@ -340,8 +346,7 @@ def get_chunk_link(self, statement_id: str, chunk_index: int) -> ExternalLink:
340346
ExternalLink: External link for the chunk
341347
"""
342348

343-
response_data = self.http_client._make_request(
344-
method="GET",
349+
response_data = self.http_client.get(
345350
path=self.CHUNK_PATH_WITH_ID_AND_INDEX.format(statement_id, chunk_index),
346351
)
347352
response = GetChunksResponse.from_dict(response_data)
@@ -470,8 +475,8 @@ def execute_command(
470475
result_compression=result_compression,
471476
)
472477

473-
response_data = self.http_client._make_request(
474-
method="POST", path=self.STATEMENT_PATH, data=request.to_dict()
478+
response_data = self.http_client.post(
479+
path=self.STATEMENT_PATH, data=request.to_dict()
475480
)
476481
response = ExecuteStatementResponse.from_dict(response_data)
477482
statement_id = response.statement_id
@@ -530,8 +535,7 @@ def cancel_command(self, command_id: CommandId) -> None:
530535
sea_statement_id = command_id.to_sea_statement_id()
531536

532537
request = CancelStatementRequest(statement_id=sea_statement_id)
533-
self.http_client._make_request(
534-
method="POST",
538+
self.http_client.post(
535539
path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id),
536540
data=request.to_dict(),
537541
)
@@ -553,8 +557,7 @@ def close_command(self, command_id: CommandId) -> None:
553557
sea_statement_id = command_id.to_sea_statement_id()
554558

555559
request = CloseStatementRequest(statement_id=sea_statement_id)
556-
self.http_client._make_request(
557-
method="DELETE",
560+
self.http_client.delete(
558561
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
559562
data=request.to_dict(),
560563
)
@@ -579,10 +582,8 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
579582
sea_statement_id = command_id.to_sea_statement_id()
580583

581584
request = GetStatementRequest(statement_id=sea_statement_id)
582-
response_data = self.http_client._make_request(
583-
method="GET",
585+
response_data = self.http_client.get(
584586
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
585-
data=request.to_dict(),
586587
)
587588

588589
# Parse the response
@@ -617,10 +618,8 @@ def get_execution_result(
617618
request = GetStatementRequest(statement_id=sea_statement_id)
618619

619620
# Get the statement result
620-
response_data = self.http_client._make_request(
621-
method="GET",
621+
response_data = self.http_client.get(
622622
path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id),
623-
data=request.to_dict(),
624623
)
625624

626625
# Create and return a SeaResultSet
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""
2+
Utility modules for the Statement Execution API (SEA) backend.
3+
"""
4+
5+
from databricks.sql.backend.sea.utils.http_client_adapter import SeaHttpClientAdapter
6+
from databricks.sql.backend.sea.utils.constants import (
7+
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
8+
ResultFormat,
9+
ResultDisposition,
10+
ResultCompression,
11+
WaitTimeout,
12+
)
13+
14+
__all__ = [
15+
"SeaHttpClientAdapter",
16+
"ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP",
17+
"ResultFormat",
18+
"ResultDisposition",
19+
"ResultCompression",
20+
"WaitTimeout",
21+
]

0 commit comments

Comments
 (0)