Skip to content

Commit 0465072

Browse files
introduce models for requests and responses
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent fd62aa8 commit 0465072

File tree

5 files changed

+104
-14
lines changed

5 files changed

+104
-14
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Models for the SEA (Statement Execution API) backend.
3+
4+
This package contains data models for SEA API requests and responses.
5+
"""
6+
7+
from databricks.sql.backend.models.requests import (
8+
CreateSessionRequest,
9+
DeleteSessionRequest,
10+
)
11+
12+
from databricks.sql.backend.models.responses import (
13+
CreateSessionResponse,
14+
)
15+
16+
__all__ = [
17+
# Request models
18+
"CreateSessionRequest",
19+
"DeleteSessionRequest",
20+
# Response models
21+
"CreateSessionResponse",
22+
]
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Dict, Any, Optional
2+
from dataclasses import dataclass
3+
4+
5+
@dataclass
6+
class CreateSessionRequest:
7+
"""Request to create a new session."""
8+
9+
warehouse_id: str
10+
session_confs: Optional[Dict[str, str]] = None
11+
catalog: Optional[str] = None
12+
schema: Optional[str] = None
13+
14+
def to_dict(self) -> Dict[str, Any]:
15+
"""Convert the request to a dictionary for JSON serialization."""
16+
result: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
17+
18+
if self.session_confs:
19+
result["session_confs"] = self.session_confs
20+
21+
if self.catalog:
22+
result["catalog"] = self.catalog
23+
24+
if self.schema:
25+
result["schema"] = self.schema
26+
27+
return result
28+
29+
30+
@dataclass
31+
class DeleteSessionRequest:
32+
"""Request to delete a session."""
33+
34+
warehouse_id: str
35+
session_id: str
36+
37+
def to_dict(self) -> Dict[str, str]:
38+
"""Convert the request to a dictionary for JSON serialization."""
39+
return {"warehouse_id": self.warehouse_id, "session_id": self.session_id}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Dict, Any
2+
from dataclasses import dataclass
3+
4+
5+
@dataclass
6+
class CreateSessionResponse:
7+
"""Response from creating a new session."""
8+
9+
session_id: str
10+
11+
@classmethod
12+
def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
13+
"""Create a CreateSessionResponse from a dictionary."""
14+
return cls(session_id=data.get("session_id", ""))

src/databricks/sql/backend/sea_backend.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77

88
from databricks.sql.backend.databricks_client import DatabricksClient
99
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
10-
from databricks.sql.exc import Error, NotSupportedError
10+
from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
1111
from databricks.sql.backend.utils.http_client import CustomHttpClient
1212
from databricks.sql.thrift_api.TCLIService import ttypes
1313
from databricks.sql.types import SSLOptions
1414

15+
from databricks.sql.backend.models import (
16+
CreateSessionRequest,
17+
DeleteSessionRequest,
18+
CreateSessionResponse,
19+
)
20+
1521
logger = logging.getLogger(__name__)
1622

1723

@@ -163,21 +169,27 @@ def open_session(
163169
schema,
164170
)
165171

166-
request_data: Dict[str, Any] = {"warehouse_id": self.warehouse_id}
167-
if session_configuration:
168-
request_data["session_confs"] = session_configuration
169-
if catalog:
170-
request_data["catalog"] = catalog
171-
if schema:
172-
request_data["schema"] = schema
172+
request_data = CreateSessionRequest(
173+
warehouse_id=self.warehouse_id,
174+
session_confs=session_configuration,
175+
catalog=catalog,
176+
schema=schema,
177+
)
173178

174179
response = self.http_client._make_request(
175-
method="POST", path=self.SESSION_PATH, data=request_data
180+
method="POST", path=self.SESSION_PATH, data=request_data.to_dict()
176181
)
177182

178-
session_id = response.get("session_id")
183+
session_response = CreateSessionResponse.from_dict(response)
184+
session_id = session_response.session_id
179185
if not session_id:
180-
raise Error("Failed to create session: No session ID returned")
186+
raise ServerOperationError(
187+
"Failed to create session: No session ID returned",
188+
{
189+
"operation-id": None,
190+
"diagnostic-info": None,
191+
},
192+
)
181193

182194
return SessionId.from_sea_session_id(session_id)
183195

@@ -199,12 +211,15 @@ def close_session(self, session_id: SessionId) -> None:
199211
raise ValueError("Not a valid SEA session ID")
200212
sea_session_id = session_id.to_sea_session_id()
201213

202-
request_data = {"warehouse_id": self.warehouse_id}
214+
request_data = DeleteSessionRequest(
215+
warehouse_id=self.warehouse_id,
216+
session_id=sea_session_id,
217+
)
203218

204219
self.http_client._make_request(
205220
method="DELETE",
206221
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
207-
data=request_data,
222+
data=request_data.to_dict(),
208223
)
209224

210225
# == Not Implemented Operations ==

tests/unit/test_sea_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_close_session_valid_id(self, sea_client, mock_http_client):
150150
mock_http_client._make_request.assert_called_once_with(
151151
method="DELETE",
152152
path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"),
153-
data={"warehouse_id": "abc123"},
153+
data={"session_id": "test-session-789", "warehouse_id": "abc123"},
154154
)
155155

156156
def test_close_session_invalid_id_type(self, sea_client):

0 commit comments

Comments
 (0)