Skip to content

Commit 9625229

Browse files
Introduce Sea HTTP Client and test script (#583)
* introduce http client (temp) and sea test file
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* reduce verbosity
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* redundant comment
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* reduce redundancy, params and data separate
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* rename client
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* fix type issues
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* reduce repetition in request calls
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* remove un-necessary elifs
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
* add newline at EOF
  Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
---------
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 3c78ed7 commit 9625229

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
import sys
3+
import logging
4+
from databricks.sql.client import Connection
5+
6+
# Script-level logging setup: configure the root logger at DEBUG level so every
# message emitted during the session round trip is shown.
logging.basicConfig(level=logging.DEBUG)
7+
# Module-scoped logger used by test_sea_session for its progress and error messages.
logger = logging.getLogger(__name__)
8+
9+
def test_sea_session():
    """
    Open and then close a SEA session through the connector.

    Connection settings are read from environment variables.

    Required:
        - DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
        - DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
        - DATABRICKS_TOKEN: Personal access token for authentication

    Optional:
        - DATABRICKS_CATALOG: catalog passed to the connection when set

    The process exits with status 1 (via ``sys.exit``) when a required
    variable is missing or when any step of the round trip raises. If the
    failure happens after the connection was created but before a close was
    attempted, the connection is closed on a best-effort basis first so the
    session is not left open.
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        sys.exit(1)

    logger.info(f"Connecting to {server_hostname}")
    logger.info(f"HTTP Path: {http_path}")
    if catalog:
        logger.info(f"Using catalog: {catalog}")

    connection = None
    close_attempted = False
    try:
        logger.info("Creating connection with SEA backend...")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",  # add custom user agent
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )
        logger.info(f"backend type: {type(connection.session.backend)}")

        # Close the connection
        logger.info("Closing the SEA session...")
        close_attempted = True
        connection.close()
        logger.info("Successfully closed SEA session")

    except Exception as e:
        # logger.exception records the message at ERROR level together with
        # the active traceback, replacing the manual traceback.format_exc().
        logger.exception(f"Error testing SEA session: {str(e)}")

        # Best-effort cleanup: only when the session was opened and the
        # failure occurred before close() was tried (a failing close() is not
        # retried).
        if connection is not None and not close_attempted:
            try:
                connection.close()
            except Exception:
                logger.warning(
                    "Failed to close SEA session during cleanup", exc_info=True
                )
        sys.exit(1)

    logger.info("SEA session test completed successfully")
64+
65+
# Run the session round trip only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
66+
test_sea_session()
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import json
2+
import logging
3+
import requests
4+
from typing import Callable, Dict, Any, Optional, Union, List, Tuple
5+
from urllib.parse import urljoin
6+
7+
from databricks.sql.auth.authenticators import AuthProvider
8+
from databricks.sql.types import SSLOptions
9+
10+
# Module-scoped logger used by SeaHttpClient for request/response diagnostics.
logger = logging.getLogger(__name__)
11+
12+
13+
class SeaHttpClient:
    """
    HTTP client for Statement Execution API (SEA).

    This client handles the HTTP communication with the SEA endpoints,
    including authentication, request formatting, and response parsing.
    All requests go through one ``requests.Session`` so connections are
    pooled; call ``close()`` to release the pool when the client is no
    longer needed.
    """

    def __init__(
        self,
        server_hostname: str,
        port: int,
        http_path: str,
        http_headers: List[Tuple[str, str]],
        auth_provider: AuthProvider,
        ssl_options: SSLOptions,
        **kwargs,
    ):
        """
        Initialize the SEA HTTP client.

        Args:
            server_hostname: Hostname of the Databricks server
            port: Port number for the connection
            http_path: HTTP path for the connection
            http_headers: List of HTTP headers to include in requests
            auth_provider: Authentication provider
            ssl_options: SSL configuration options
            **kwargs: Additional keyword arguments; only
                ``_retry_stop_after_attempts_count`` is read here.
        """

        self.server_hostname = server_hostname
        self.port = port
        self.http_path = http_path
        self.auth_provider = auth_provider
        self.ssl_options = ssl_options

        self.base_url = f"https://{server_hostname}:{port}"

        self.headers: Dict[str, str] = dict(http_headers)
        self.headers.update({"Content-Type": "application/json"})

        # NOTE(review): the retry count is stored but nothing in this class
        # applies it to outgoing requests — confirm whether a caller uses it.
        self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30)

        # Create a session for connection pooling
        self.session = requests.Session()

        # Configure SSL verification: a custom CA bundle when one is given,
        # the default trust store otherwise, or no verification at all.
        if ssl_options.tls_verify:
            self.session.verify = ssl_options.tls_trusted_ca_file or True
        else:
            self.session.verify = False

        # Configure client certificates if provided
        if ssl_options.tls_client_cert_file:
            client_cert = ssl_options.tls_client_cert_file
            client_key = ssl_options.tls_client_cert_key_file
            client_key_password = ssl_options.tls_client_cert_key_password

            if client_key:
                self.session.cert = (client_cert, client_key)
            else:
                self.session.cert = client_cert

            if client_key_password:
                # Note: requests doesn't directly support key passwords
                # This would require more complex handling with libraries like pyOpenSSL
                logger.warning(
                    "Client key password provided but not supported by requests library"
                )

    def close(self) -> None:
        """Close the underlying ``requests.Session`` and release its pooled connections."""
        self.session.close()

    def _get_auth_headers(self) -> Dict[str, str]:
        """Get authentication headers from the auth provider."""
        headers: Dict[str, str] = {}
        self.auth_provider.add_headers(headers)
        return headers

    def _get_call(self, method: str) -> Callable:
        """
        Get the session function for an HTTP method name (case-insensitive).

        Raises:
            ValueError: If the method is not GET, POST or DELETE
        """
        method = method.upper()
        if method == "GET":
            return self.session.get
        if method == "POST":
            return self.session.post
        if method == "DELETE":
            return self.session.delete
        raise ValueError(f"Unsupported HTTP method: {method}")

    @staticmethod
    def _describe_failure(error: Exception) -> str:
        """
        Log a failed request and build the message for the raised RequestError.

        When the exception carries a response, the response body is logged
        (parsed JSON when possible, raw text otherwise). The server's
        ``message`` field is appended to the returned text only when the
        body is a JSON object with a non-empty ``message``.
        """
        error_message = f"SEA HTTP request failed: {str(error)}"

        response = getattr(error, "response", None)
        if response is None:
            logger.error(error_message)
            return error_message

        status_code = response.status_code
        try:
            error_details = response.json()
        except ValueError:
            # If we can't parse JSON, log raw content
            raw = response.content
            content = (
                raw.decode("utf-8", errors="replace")
                if isinstance(raw, bytes)
                else str(raw)
            )
            logger.error(f"Request failed (status {status_code}): {content}")
            return error_message

        logger.error(f"Request failed (status {status_code}): {error_details}")

        # A JSON error body is not guaranteed to be an object; calling .get()
        # on a list or string would raise AttributeError and mask the real
        # failure from the caller.
        server_message = (
            error_details.get("message") if isinstance(error_details, dict) else None
        )
        if server_message:
            error_message = f"{error_message}: {server_message}"
        return error_message

    def _make_request(
        self,
        method: str,
        path: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the SEA endpoint.

        Args:
            method: HTTP method (GET, POST, DELETE)
            path: API endpoint path
            data: Request payload data, sent as the JSON body
            params: Query parameters

        Returns:
            Dict[str, Any]: Response data parsed from JSON, or an empty dict
            when the response has no body

        Raises:
            RequestError: If the request fails, the server returns an HTTP
                error status, or the response body is not valid JSON
            ValueError: If the HTTP method is not supported
        """
        # Imported at call time, as in the original code, not at module import.
        from databricks.sql.exc import RequestError

        url = urljoin(self.base_url, path)
        headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()}

        logger.debug(f"making {method} request to {url}")

        # Resolved outside the try: an unsupported method is a caller error
        # (ValueError), not a transport failure.
        call = self._get_call(method)

        try:
            response = call(
                url=url,
                headers=headers,
                json=data,
                params=params,
            )

            # Check for HTTP errors
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            # Re-raise as a RequestError, with details from the response if available
            raise RequestError(self._describe_failure(e), e) from e

        # Log response details
        logger.debug(f"Response status: {response.status_code}")

        if not response.content:
            return {}

        # Parse JSON response. Catching ValueError covers both the plain
        # JSON decode error and requests' own JSONDecodeError (a ValueError
        # subclass), so a malformed body always surfaces as RequestError.
        try:
            result = response.json()
        except ValueError as e:
            error_message = f"SEA HTTP request failed: {str(e)}"
            logger.error(error_message)
            raise RequestError(error_message, e) from e

        # Log response content (but limit it for large responses). The
        # serialisation is skipped entirely unless DEBUG logging is enabled.
        if logger.isEnabledFor(logging.DEBUG):
            content_str = json.dumps(result)
            if len(content_str) > 1000:
                logger.debug(f"Response content (truncated): {content_str[:1000]}...")
            else:
                logger.debug(f"Response content: {content_str}")

        return result

0 commit comments

Comments
 (0)