Skip to content

Commit 67fd101

Browse files
remove more irrelevant changes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 24c6152 commit 67fd101

File tree

3 files changed

+44
-52
lines changed

3 files changed

+44
-52
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
from databricks.sql.thrift_api.TCLIService import ttypes
1818
from databricks.sql.backend.types import SessionId, CommandId, CommandState
19-
from databricks.sql.utils import ExecuteResponse
20-
from databricks.sql.types import SSLOptions
2119

2220
# Forward reference for type hints
2321
from typing import TYPE_CHECKING

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def make_request(self, method, request, retryable=True):
352352
Will stop retry attempts if total elapsed time + next retry delay would exceed
353353
_retry_stop_after_attempts_duration.
354354
"""
355+
355356
# basic strategy: build range iterator rep'ing number of available
356357
# retries. bounds can be computed from there. iterate over it with
357358
# retries until success or final failure achieved.
@@ -1241,7 +1242,7 @@ def cancel_command(self, command_id: CommandId) -> None:
12411242
if not thrift_handle:
12421243
raise ValueError("Not a valid Thrift command ID")
12431244

1244-
logger.debug("Cancelling command {}".format(command_id.guid))
1245+
logger.debug("Cancelling command {}".format(guid_to_hex_id(command_id.guid)))
12451246
req = ttypes.TCancelOperationReq(thrift_handle)
12461247
self.make_request(self._client.CancelOperation, req)
12471248

src/databricks/sql/result_set.py

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
"""
6565

6666
self.connection = connection
67-
self.backend = backend # Store the backend client directly
67+
self.backend = backend
6868
self.arraysize = arraysize
6969
self.buffer_size_bytes = buffer_size_bytes
7070
self._next_row_index = 0
@@ -115,12 +115,12 @@ def fetchall(self) -> List[Row]:
115115
pass
116116

117117
@abstractmethod
118-
def fetchmany_arrow(self, size: int) -> Any:
118+
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
119119
"""Fetch the next set of rows as an Arrow table."""
120120
pass
121121

122122
@abstractmethod
123-
def fetchall_arrow(self) -> Any:
123+
def fetchall_arrow(self) -> "pyarrow.Table":
124124
"""Fetch all remaining rows as an Arrow table."""
125125
pass
126126

@@ -207,7 +207,7 @@ def _fill_results_buffer(self):
207207
use_cloud_fetch=self._use_cloud_fetch,
208208
)
209209
self.results = results
210-
self._has_more_rows = has_more_rows
210+
self.has_more_rows = has_more_rows
211211

212212
def _convert_columnar_table(self, table):
213213
column_names = [c[0] for c in self.description]
@@ -291,7 +291,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
291291
while (
292292
n_remaining_rows > 0
293293
and not self.has_been_closed_server_side
294-
and self._has_more_rows
294+
and self.has_more_rows
295295
):
296296
self._fill_results_buffer()
297297
partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -316,7 +316,7 @@ def fetchmany_columnar(self, size: int):
316316
while (
317317
n_remaining_rows > 0
318318
and not self.has_been_closed_server_side
319-
and self._has_more_rows
319+
and self.has_more_rows
320320
):
321321
self._fill_results_buffer()
322322
partial_results = self.results.next_n_rows(n_remaining_rows)
@@ -331,7 +331,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
331331
results = self.results.remaining_rows()
332332
self._next_row_index += results.num_rows
333333

334-
while not self.has_been_closed_server_side and self._has_more_rows:
334+
while not self.has_been_closed_server_side and self.has_more_rows:
335335
self._fill_results_buffer()
336336
partial_results = self.results.remaining_rows()
337337
if isinstance(results, ColumnTable) and isinstance(
@@ -357,7 +357,7 @@ def fetchall_columnar(self):
357357
results = self.results.remaining_rows()
358358
self._next_row_index += results.num_rows
359359

360-
while not self.has_been_closed_server_side and self._has_more_rows:
360+
while not self.has_been_closed_server_side and self.has_more_rows:
361361
self._fill_results_buffer()
362362
partial_results = self.results.remaining_rows()
363363
results = self.merge_columnar(results, partial_results)
@@ -402,6 +402,33 @@ def fetchmany(self, size: int) -> List[Row]:
402402

403403
@staticmethod
404404
def _get_schema_description(table_schema_message):
405+
"""
406+
Takes a TableSchema message and returns a description 7-tuple as specified by PEP-249
407+
"""
408+
409+
def map_col_type(type_):
410+
if type_.startswith("decimal"):
411+
return "decimal"
412+
else:
413+
return type_
414+
415+
return [
416+
(column.name, map_col_type(column.datatype), None, None, None, None, None)
417+
for column in table_schema_message.columns
418+
]
419+
420+
421+
class SeaResultSet(ResultSet):
422+
"""ResultSet implementation for the SEA backend."""
423+
424+
def __init__(
425+
self,
426+
connection: "Connection",
427+
execute_response: "ExecuteResponse",
428+
sea_client: "SeaDatabricksClient",
429+
buffer_size_bytes: int = 104857600,
430+
arraysize: int = 10000,
431+
):
405432
"""
406433
Initialize a SeaResultSet with the response from a SEA query execution.
407434
@@ -413,53 +440,19 @@ def _get_schema_description(table_schema_message):
413440
execute_response: Response from the execute command (new style)
414441
sea_response: Direct SEA response (legacy style)
415442
"""
416-
# Handle both initialization styles
417-
if execute_response is not None:
418-
# New style with ExecuteResponse
419-
command_id = execute_response.command_id
420-
status = execute_response.status
421-
has_been_closed_server_side = execute_response.has_been_closed_server_side
422-
has_more_rows = execute_response.has_more_rows
423-
results_queue = execute_response.results_queue
424-
description = execute_response.description
425-
is_staging_operation = execute_response.is_staging_operation
426-
self._response = getattr(execute_response, "sea_response", {})
427-
self.statement_id = command_id.to_sea_statement_id() if command_id else None
428-
elif sea_response is not None:
429-
# Legacy style with direct sea_response
430-
self._response = sea_response
431-
# Extract values from sea_response
432-
command_id = CommandId.from_sea_statement_id(
433-
sea_response.get("statement_id", "")
434-
)
435-
self.statement_id = sea_response.get("statement_id", "")
436-
437-
# Extract status
438-
status_data = sea_response.get("status", {})
439-
status = CommandState.from_sea_state(status_data.get("state", "PENDING"))
440-
441-
# Set defaults for other fields
442-
has_been_closed_server_side = False
443-
has_more_rows = False
444-
results_queue = None
445-
description = None
446-
is_staging_operation = False
447-
else:
448-
raise ValueError("Either execute_response or sea_response must be provided")
449443

450-
# Call parent constructor with common attributes
451444
super().__init__(
452445
connection=connection,
453446
backend=sea_client,
454447
arraysize=arraysize,
455448
buffer_size_bytes=buffer_size_bytes,
456-
command_id=command_id,
457-
status=status,
458-
has_been_closed_server_side=has_been_closed_server_side,
459-
has_more_rows=has_more_rows,
460-
results_queue=results_queue,
461-
description=description,
462-
is_staging_operation=is_staging_operation,
449+
command_id=execute_response.command_id,
450+
status=execute_response.status,
451+
has_been_closed_server_side=execute_response.has_been_closed_server_side,
452+
has_more_rows=execute_response.has_more_rows,
453+
results_queue=execute_response.results_queue,
454+
description=execute_response.description,
455+
is_staging_operation=execute_response.is_staging_operation,
463456
)
464457

465458
def _fill_results_buffer(self):

0 commit comments

Comments
 (0)