
Commit 716304b

remove redundant queue init
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent c038d5a commit 716304b

5 files changed: +457, -111 lines

src/databricks/sql/backend/thrift_backend.py

Lines changed: 2 additions & 2 deletions

@@ -1224,9 +1224,9 @@ def fetch_results(
                 )
             )

-        from databricks.sql.utils import ResultSetQueueFactory
+        from databricks.sql.utils import ThriftResultSetQueueFactory

-        queue = ResultSetQueueFactory.build_queue(
+        queue = ThriftResultSetQueueFactory.build_queue(
             row_set_type=resp.resultSetMetadata.resultFormat,
             t_row_set=resp.results,
             arrow_schema_bytes=arrow_schema_bytes,
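For orientation, this hunk only renames the imported factory. The sketch below is a rough illustration of how a build_queue-style factory might dispatch on the reported result format; the class name, the format strings, and the fallback behavior are assumptions for illustration, not the connector's actual ThriftResultSetQueueFactory.

from typing import Any, Optional

class IllustrativeQueueFactory:
    """Hedged sketch of a format-dispatching queue factory (illustrative only)."""

    @staticmethod
    def build_queue(
        row_set_type: str,
        t_row_set: Any,
        arrow_schema_bytes: Optional[bytes] = b"",
        lz4_compressed: bool = False,
        **kwargs: Any,
    ) -> Any:
        # Pick a queue implementation based on the wire format the server
        # reported; the format names here are invented for the sketch.
        if row_set_type == "COLUMN_BASED_SET":
            return list(t_row_set)  # stand-in for a JSON/columnar row queue
        if row_set_type == "ARROW_BASED_SET":
            raise NotImplementedError("Arrow queue construction omitted in this sketch")
        raise ValueError(f"Unknown row set type: {row_set_type}")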

src/databricks/sql/result_set.py

Lines changed: 123 additions & 59 deletions

@@ -51,7 +51,7 @@ def __init__(
         description=None,
         is_staging_operation: bool = False,
         lz4_compressed: bool = False,
-        arrow_schema_bytes: bytes = b"",
+        arrow_schema_bytes: Optional[bytes] = b"",
     ):
         """
         A ResultSet manages the results of a single command.
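The widened Optional[bytes] annotation mostly matters to callers that may not have schema bytes at all; downstream code already coalesces missing values with `or b""` (see the block removed in the next hunk). A minimal sketch of that idiom, with a hypothetical helper name:

from typing import Optional

def coalesce_schema_bytes(arrow_schema_bytes: Optional[bytes]) -> bytes:
    # Treat None the same as "no schema bytes", mirroring the existing
    # `execute_response.arrow_schema_bytes or b""` pattern.
    return arrow_schema_bytes or b""

assert coalesce_schema_bytes(None) == b""
assert coalesce_schema_bytes(b"\x01\x02") == b"\x01\x02"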
@@ -205,22 +205,6 @@ def __init__(
             ssl_options=ssl_options,
         )

-        # Build the results queue if t_row_set is provided
-        results_queue = None
-        if t_row_set and execute_response.result_format is not None:
-            from databricks.sql.utils import ResultSetQueueFactory
-
-            # Create the results queue using the provided format
-            results_queue = ResultSetQueueFactory.build_queue(
-                row_set_type=execute_response.result_format,
-                t_row_set=t_row_set,
-                arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
-                max_download_threads=max_download_threads,
-                lz4_compressed=execute_response.lz4_compressed,
-                description=execute_response.description,
-                ssl_options=ssl_options,
-            )
-
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
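This removal is the "redundant queue init" from the commit message: the constructor no longer rebuilds a queue, presumably because one is already built elsewhere (for example, fetch_results in the thrift_backend.py hunk above builds one). The sketch below only illustrates the general shape of building the queue once and passing it down; the class and parameter names are made up for illustration.

from typing import Any, Optional

class BaseResultSet:
    def __init__(self, results_queue: Optional[Any] = None) -> None:
        # Single owner of the queue reference.
        self.results = results_queue

class IllustrativeThriftResultSet(BaseResultSet):
    def __init__(self, results_queue: Optional[Any] = None) -> None:
        # No per-subclass rebuild of the queue: whatever the caller already
        # constructed is passed straight through to the shared base class.
        super().__init__(results_queue=results_queue)

rs = IllustrativeThriftResultSet(results_queue=[["row1"], ["row2"]])
assert rs.results == [["row1"], ["row2"]]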
@@ -543,16 +527,13 @@ def fetchone(self) -> Optional[Row]:
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
-        if isinstance(self.results, JsonQueue):
-            rows = self.results.next_n_rows(1)
-            if not rows:
-                return None
+        rows = self.results.next_n_rows(1)
+        if not rows:
+            return None

-            # Convert to Row object
-            converted_rows = self._convert_to_row_objects(rows)
-            return converted_rows[0] if converted_rows else None
-        else:
-            raise NotImplementedError("Unsupported queue type")
+        # Convert to Row object
+        converted_rows = self._convert_to_row_objects(rows)
+        return converted_rows[0] if converted_rows else None

     def fetchmany(self, size: Optional[int] = None) -> List[Row]:
         """
@@ -566,58 +547,141 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

-        # Note: We check for the specific queue type to maintain consistency with ThriftResultSet
-        if isinstance(self.results, JsonQueue):
-            rows = self.results.next_n_rows(size)
-            self._next_row_index += len(rows)
+        rows = self.results.next_n_rows(size)
+        self._next_row_index += len(rows)

-            # Convert to Row objects
-            return self._convert_to_row_objects(rows)
-        else:
-            raise NotImplementedError("Unsupported queue type")
+        # Convert to Row objects
+        return self._convert_to_row_objects(rows)

     def fetchall(self) -> List[Row]:
         """
         Fetch all (remaining) rows of a query result, returning them as a list of rows.
         """
-        # Note: We check for the specific queue type to maintain consistency with ThriftResultSet
-        if isinstance(self.results, JsonQueue):
-            rows = self.results.remaining_rows()
-            self._next_row_index += len(rows)

-            # Convert to Row objects
-            return self._convert_to_row_objects(rows)
-        else:
-            raise NotImplementedError("Unsupported queue type")
+        rows = self.results.remaining_rows()
+        self._next_row_index += len(rows)
+
+        # Convert to Row objects
+        return self._convert_to_row_objects(rows)
+
+    def _create_empty_arrow_table(self) -> Any:
+        """
+        Create an empty PyArrow table with the schema from the result set.
+
+        Returns:
+            An empty PyArrow table with the correct schema.
+        """
+        import pyarrow
+
+        # Try to use schema bytes if available
+        if self._arrow_schema_bytes:
+            schema = pyarrow.ipc.read_schema(
+                pyarrow.BufferReader(self._arrow_schema_bytes)
+            )
+            return pyarrow.Table.from_pydict(
+                {name: [] for name in schema.names}, schema=schema
+            )
+
+        # Fall back to creating schema from description
+        if self.description:
+            # Map SQL types to PyArrow types
+            type_map = {
+                "boolean": pyarrow.bool_(),
+                "tinyint": pyarrow.int8(),
+                "smallint": pyarrow.int16(),
+                "int": pyarrow.int32(),
+                "bigint": pyarrow.int64(),
+                "float": pyarrow.float32(),
+                "double": pyarrow.float64(),
+                "string": pyarrow.string(),
+                "binary": pyarrow.binary(),
+                "timestamp": pyarrow.timestamp("us"),
+                "date": pyarrow.date32(),
+                "decimal": pyarrow.decimal128(38, 18),  # Default precision and scale
+            }
+
+            fields = []
+            for col_desc in self.description:
+                col_name = col_desc[0]
+                col_type = col_desc[1].lower() if col_desc[1] else "string"
+
+                # Handle decimal with precision and scale
+                if (
+                    col_type == "decimal"
+                    and col_desc[4] is not None
+                    and col_desc[5] is not None
+                ):
+                    arrow_type = pyarrow.decimal128(col_desc[4], col_desc[5])
+                else:
+                    arrow_type = type_map.get(col_type, pyarrow.string())
+
+                fields.append(pyarrow.field(col_name, arrow_type))
+
+            schema = pyarrow.schema(fields)
+            return pyarrow.Table.from_pydict(
+                {name: [] for name in schema.names}, schema=schema
+            )
+
+        # If no schema information is available, return an empty table
+        return pyarrow.Table.from_pydict({})
+
+    def _convert_rows_to_arrow_table(self, rows: List[Row]) -> Any:
+        """
+        Convert a list of Row objects to a PyArrow table.
+
+        Args:
+            rows: List of Row objects to convert.
+
+        Returns:
+            PyArrow table containing the data from the rows.
+        """
+        import pyarrow
+
+        if not rows:
+            return self._create_empty_arrow_table()
+
+        # Extract column names from description
+        if self.description:
+            column_names = [col[0] for col in self.description]
+        else:
+            # If no description, use the attribute names from the first row
+            column_names = rows[0]._fields
+
+        # Convert rows to columns
+        columns: dict[str, list] = {name: [] for name in column_names}
+
+        for row in rows:
+            for i, name in enumerate(column_names):
+                if hasattr(row, "_asdict"):  # If it's a Row object
+                    columns[name].append(row[i])
+                else:  # If it's a raw list
+                    columns[name].append(row[i])
+
+        # Create PyArrow table
+        return pyarrow.Table.from_pydict(columns)

     def fetchmany_arrow(self, size: int) -> Any:
         """Fetch the next set of rows as an Arrow table."""
         if not pyarrow:
             raise ImportError("PyArrow is required for Arrow support")

-        if isinstance(self.results, JsonQueue):
-            rows = self.fetchmany(size)
-            if not rows:
-                # Return empty Arrow table with schema
-                return self._create_empty_arrow_table()
+        rows = self.fetchmany(size)
+        if not rows:
+            # Return empty Arrow table with schema
+            return self._create_empty_arrow_table()

-            # Convert rows to Arrow table
-            return self._convert_rows_to_arrow_table(rows)
-        else:
-            raise NotImplementedError("Unsupported queue type")
+        # Convert rows to Arrow table
+        return self._convert_rows_to_arrow_table(rows)

     def fetchall_arrow(self) -> Any:
         """Fetch all remaining rows as an Arrow table."""
         if not pyarrow:
             raise ImportError("PyArrow is required for Arrow support")

-        if isinstance(self.results, JsonQueue):
-            rows = self.fetchall()
-            if not rows:
-                # Return empty Arrow table with schema
-                return self._create_empty_arrow_table()
+        rows = self.fetchall()
+        if not rows:
+            # Return empty Arrow table with schema
+            return self._create_empty_arrow_table()

-            # Convert rows to Arrow table
-            return self._convert_rows_to_arrow_table(rows)
-        else:
-            raise NotImplementedError("Unsupported queue type")
+        # Convert rows to Arrow table
+        return self._convert_rows_to_arrow_table(rows)
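The two helpers added above rely only on standard PyArrow calls. A small standalone sketch of the same pattern, outside the class; the column names and data here are invented for illustration:

import pyarrow

# Build a schema the way _create_empty_arrow_table's fallback does,
# then materialize an empty table with from_pydict.
schema = pyarrow.schema(
    [
        pyarrow.field("id", pyarrow.int64()),
        pyarrow.field("name", pyarrow.string()),
    ]
)

empty_table = pyarrow.Table.from_pydict(
    {name: [] for name in schema.names}, schema=schema
)
assert empty_table.num_rows == 0

# Pivot row-oriented data into columns before handing it to PyArrow,
# mirroring _convert_rows_to_arrow_table.
rows = [(1, "alpha"), (2, "beta")]
columns = {name: [] for name in schema.names}
for row in rows:
    for i, name in enumerate(schema.names):
        columns[name].append(row[i])

table = pyarrow.Table.from_pydict(columns)
assert table.num_rows == 2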

tests/unit/test_sea_backend.py

Lines changed: 1 addition & 1 deletion

@@ -536,7 +536,7 @@ def test_get_execution_result(
         print(result)

         # Verify basic properties of the result
-        assert result.statement_id == "test-statement-123"
+        assert result.command_id.to_sea_statement_id() == "test-statement-123"
         assert result.status == CommandState.SUCCEEDED

         # Verify the HTTP request
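The assertion now goes through the command-id wrapper instead of a raw statement_id attribute. A rough sketch of what such a wrapper could look like; the class name and field are assumptions for illustration, and only to_sea_statement_id appears in the diff:

from dataclasses import dataclass

@dataclass
class CommandId:
    """Hypothetical wrapper around a backend-specific statement identifier."""
    guid: str

    def to_sea_statement_id(self) -> str:
        # For a SEA-backed command the identifier is already the SEA statement id.
        return self.guid

command_id = CommandId(guid="test-statement-123")
assert command_id.to_sea_statement_id() == "test-statement-123"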
