
SEA: Cleanup #626


Draft: wants to merge 2 commits into base branch sea-migration
@@ -620,7 +620,6 @@ def get_execution_result(
         return SeaResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            sea_client=self,
             result_data=response.result,
             manifest=response.manifest,
             buffer_size_bytes=cursor.buffer_size_bytes,
src/databricks/sql/backend/sea/queue.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 from abc import ABC
 from typing import List, Optional, Tuple

-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.constants import ResultFormat
 from databricks.sql.exc import ProgrammingError
src/databricks/sql/backend/sea/result_set.py (40 changes: 25 additions & 15 deletions)

@@ -4,7 +4,7 @@

 import logging

-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest
 from databricks.sql.backend.sea.utils.conversion import SqlTypeConverter

@@ -31,7 +31,6 @@ def __init__(
         self,
         connection: Connection,
         execute_response: ExecuteResponse,
-        sea_client: SeaDatabricksClient,
         result_data: ResultData,
         manifest: ResultManifest,
         buffer_size_bytes: int = 104857600,
@@ -43,7 +42,6 @@
         Args:
             connection: The parent connection
             execute_response: Response from the execute command
-            sea_client: The SeaDatabricksClient instance for direct access
             buffer_size_bytes: Buffer size for fetching results
             arraysize: Default number of rows to fetch
             result_data: Result data from SEA response
@@ -56,32 +54,38 @@ def __init__(
         if statement_id is None:
             raise ValueError("Command ID is not a SEA statement ID")

-        results_queue = SeaResultSetQueueFactory.build_queue(
-            result_data,
-            self.manifest,
-            statement_id,
-            description=execute_response.description,
-            max_download_threads=sea_client.max_download_threads,
-            sea_client=sea_client,
-            lz4_compressed=execute_response.lz4_compressed,
-        )
-
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=sea_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
             has_been_closed_server_side=execute_response.has_been_closed_server_side,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )

+        # Assert that the backend is of the correct type
+        assert isinstance(
+            self.backend, SeaDatabricksClient
+        ), "Backend must be a SeaDatabricksClient"
+
+        results_queue = SeaResultSetQueueFactory.build_queue(
+            result_data,
+            self.manifest,
+            statement_id,
+            description=execute_response.description,
+            max_download_threads=self.backend.max_download_threads,
+            sea_client=self.backend,
+            lz4_compressed=execute_response.lz4_compressed,
+        )
+
+        # Set the results queue
+        self.results = results_queue
+
     def _convert_json_types(self, row: List[str]) -> List[Any]:
         """
         Convert string values in the row to appropriate Python types based on column metadata.
@@ -160,6 +164,9 @@ def fetchmany_json(self, size: int) -> List[List[str]]:
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         self._next_row_index += len(results)

@@ -173,6 +180,9 @@ def fetchall_json(self) -> List[List[str]]:
             Columnar table containing all remaining rows
         """

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += len(results)
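
Review note: the key reordering above is that `super().__init__()` now runs before the queue is built, because the base class (see `src/databricks/sql/result_set.py` below) is what populates `self.backend` from the connection's session. A minimal runnable sketch of that ordering constraint, using hypothetical stand-in classes rather than the real driver types:

```python
from typing import List, Optional


class _SeaClient:
    """Hypothetical stand-in for SeaDatabricksClient."""

    max_download_threads = 10


class _Session:
    def __init__(self, backend: _SeaClient) -> None:
        self.backend = backend


class _Connection:
    def __init__(self, session: _Session) -> None:
        self.session = session


class _BaseResultSet:
    def __init__(self, connection: _Connection) -> None:
        # Mirrors the new base class: the backend comes from the
        # connection, not from an injected sea_client argument.
        self.backend = connection.session.backend
        self.results: Optional[List[str]] = None


class _SeaResultSet(_BaseResultSet):
    def __init__(self, connection: _Connection) -> None:
        super().__init__(connection)  # must run first: it sets self.backend
        assert isinstance(self.backend, _SeaClient)
        # Only now can the queue be built from backend-owned settings.
        threads = self.backend.max_download_threads
        self.results = [f"queue built with {threads} download threads"]


rs = _SeaResultSet(_Connection(_Session(_SeaClient())))
print(rs.results)
```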
src/databricks/sql/backend/sea/utils/filters.py (9 changes: 5 additions & 4 deletions)

@@ -12,14 +12,13 @@
     Optional,
     Any,
     Callable,
-    cast,
     TYPE_CHECKING,
 )

 if TYPE_CHECKING:
     from databricks.sql.backend.sea.result_set import SeaResultSet

-from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.types import ExecuteResponse, CommandId, CommandState

 logger = logging.getLogger(__name__)

@@ -45,6 +44,9 @@ def _filter_sea_result_set(
     """

     # Get all remaining rows
+    if result_set.results is None:
+        raise RuntimeError("Results queue is not initialized")
+
     all_rows = result_set.results.remaining_rows()

     # Filter rows
@@ -69,7 +71,7 @@

     result_data = ResultData(data=filtered_rows, external_links=None)

-    from databricks.sql.backend.sea.backend import SeaDatabricksClient
+    from databricks.sql.backend.sea.client import SeaDatabricksClient
     from databricks.sql.backend.sea.result_set import SeaResultSet

     # Create a new SeaResultSet with the filtered data
@@ -79,7 +81,6 @@
     filtered_result_set = SeaResultSet(
         connection=result_set.connection,
         execute_response=execute_response,
-        sea_client=cast(SeaDatabricksClient, result_set.backend),
         result_data=result_data,
         manifest=manifest,
         buffer_size_bytes=result_set.buffer_size_bytes,
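
Review note: this file picks up the same `None` guard that the fetch methods gain elsewhere in this PR. Since `results` is now typed `Optional[ResultSetQueue]`, every consumer should fail fast with a clear message instead of hitting an `AttributeError` on `None`. A small runnable sketch of the guard pattern (the `_Queue` class is a hypothetical stand-in):

```python
from typing import List, Optional


class _Queue:
    """Hypothetical stand-in for a ResultSetQueue implementation."""

    def __init__(self, rows: List[List[object]]) -> None:
        self._rows = rows

    def remaining_rows(self) -> List[List[object]]:
        rows, self._rows = self._rows, []
        return rows


def drain(results: Optional[_Queue]) -> List[List[object]]:
    # Same guard the PR adds before every queue access: fail loudly if a
    # subclass never assigned the queue, rather than crashing later.
    if results is None:
        raise RuntimeError("Results queue is not initialized")
    return results.remaining_rows()


print(drain(_Queue([[1, "a"], [2, "b"]])))  # [[1, 'a'], [2, 'b']]
```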
src/databricks/sql/backend/thrift_backend.py (6 changes: 0 additions & 6 deletions)

@@ -856,7 +856,6 @@ def get_execution_result(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=cursor.buffer_size_bytes,
             arraysize=cursor.arraysize,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -987,7 +986,6 @@ def execute_command(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=use_cloud_fetch,
@@ -1027,7 +1025,6 @@ def get_catalogs(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1071,7 +1068,6 @@ def get_schemas(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1119,7 +1115,6 @@ def get_tables(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
@@ -1167,7 +1162,6 @@ def get_columns(
         return ThriftResultSet(
             connection=cursor.connection,
             execute_response=execute_response,
-            thrift_client=self,
             buffer_size_bytes=max_bytes,
             arraysize=max_rows,
             use_cloud_fetch=cursor.connection.use_cloud_fetch,
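
Review note: all six `ThriftResultSet` call sites in this file shrink identically, dropping `thrift_client=self`. The injection was arguably redundant once the base `ResultSet` can reach the same client via `connection.session.backend`, as the toy below illustrates (all names here are hypothetical stand-ins, not the real driver classes):

```python
class _Client:
    """Hypothetical stand-in for ThriftDatabricksClient."""


class _Session:
    def __init__(self, backend: _Client) -> None:
        self.backend = backend


class _Connection:
    def __init__(self, session: _Session) -> None:
        self.session = session


class _ResultSet:
    # Old shape: _ResultSet(connection, thrift_client=client) passed the
    # client twice, since connection.session.backend already held it.
    def __init__(self, connection: _Connection) -> None:
        self.backend = connection.session.backend


client = _Client()
conn = _Connection(_Session(client))
assert _ResultSet(conn).backend is client  # the injection was redundant
print("backend resolved through the connection")
```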
src/databricks/sql/result_set.py (63 changes: 41 additions & 22 deletions)

@@ -20,6 +20,7 @@
 from databricks.sql.utils import (
     ColumnTable,
     ColumnQueue,
+    ResultSetQueue,
 )
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse

@@ -36,14 +37,12 @@ class ResultSet(ABC):
     def __init__(
         self,
         connection: "Connection",
-        backend: "DatabricksClient",
         arraysize: int,
         buffer_size_bytes: int,
         command_id: CommandId,
         status: CommandState,
         has_been_closed_server_side: bool = False,
         is_direct_results: bool = False,
-        results_queue=None,
         description: List[Tuple] = [],
         is_staging_operation: bool = False,
         lz4_compressed: bool = False,
@@ -54,32 +53,30 @@ def __init__(

         Parameters:
             :param connection: The parent connection
-            :param backend: The backend client
             :param arraysize: The max number of rows to fetch at a time (PEP-249)
             :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
             :param command_id: The command ID
             :param status: The command status
             :param has_been_closed_server_side: Whether the command has been closed on the server
             :param is_direct_results: Whether the command has more rows
-            :param results_queue: The results queue
             :param description: column description of the results
             :param is_staging_operation: Whether the command is a staging operation
         """

-        self.connection = connection
-        self.backend = backend
-        self.arraysize = arraysize
-        self.buffer_size_bytes = buffer_size_bytes
-        self._next_row_index = 0
-        self.description = description
-        self.command_id = command_id
-        self.status = status
-        self.has_been_closed_server_side = has_been_closed_server_side
-        self.is_direct_results = is_direct_results
-        self.results = results_queue
-        self._is_staging_operation = is_staging_operation
-        self.lz4_compressed = lz4_compressed
-        self._arrow_schema_bytes = arrow_schema_bytes
+        self.connection: "Connection" = connection
+        self.backend: DatabricksClient = connection.session.backend
+        self.arraysize: int = arraysize
+        self.buffer_size_bytes: int = buffer_size_bytes
+        self._next_row_index: int = 0
+        self.description: List[Tuple] = description
+        self.command_id: CommandId = command_id
+        self.status: CommandState = status
+        self.has_been_closed_server_side: bool = has_been_closed_server_side
+        self.is_direct_results: bool = is_direct_results
+        self.results: Optional[ResultSetQueue] = None  # Children will set this
+        self._is_staging_operation: bool = is_staging_operation
+        self.lz4_compressed: bool = lz4_compressed
+        self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes

     def __iter__(self):
         while True:
@@ -190,7 +187,6 @@ def __init__(
         self,
         connection: "Connection",
         execute_response: "ExecuteResponse",
-        thrift_client: "ThriftDatabricksClient",
         buffer_size_bytes: int = 104857600,
         arraysize: int = 10000,
         use_cloud_fetch: bool = True,
@@ -205,7 +201,6 @@
         Parameters:
             :param connection: The parent connection
             :param execute_response: Response from the execute command
-            :param thrift_client: The ThriftDatabricksClient instance for direct access
             :param buffer_size_bytes: Buffer size for fetching results
             :param arraysize: Default number of rows to fetch
             :param use_cloud_fetch: Whether to use cloud fetch for retrieving results
@@ -238,20 +233,28 @@ def __init__(
         # Call parent constructor with common attributes
         super().__init__(
             connection=connection,
-            backend=thrift_client,
             arraysize=arraysize,
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
             has_been_closed_server_side=execute_response.has_been_closed_server_side,
             is_direct_results=is_direct_results,
-            results_queue=results_queue,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
             arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )

+        # Assert that the backend is of the correct type
+        from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
+
+        assert isinstance(
+            self.backend, ThriftDatabricksClient
+        ), "Backend must be a ThriftDatabricksClient"
+
+        # Set the results queue
+        self.results = results_queue
+
         # Initialize results queue if not provided
         if not self.results:
             self._fill_results_buffer()
@@ -307,6 +310,10 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
+
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
@@ -332,6 +339,9 @@ def fetchmany_columnar(self, size: int):
         if size < 0:
             raise ValueError("size argument for fetchmany is %s but must be >= 0", size)

+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
@@ -351,6 +361,9 @@

     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

@@ -377,6 +390,9 @@ def fetchall_arrow(self) -> "pyarrow.Table":

     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows

@@ -393,6 +409,9 @@ def fetchone(self) -> Optional[Row]:
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
+        if self.results is None:
+            raise RuntimeError("Results queue is not initialized")
+
         if isinstance(self.results, ColumnQueue):
             res = self._convert_columnar_table(self.fetchmany_columnar(1))
         else:
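
Review note: `ThriftResultSet` keeps its lazy-fill behavior through the refactor. The base class now initializes `self.results` to `None`, the subclass assigns whatever queue it built, and only falls back to `_fill_results_buffer()` when nothing is there yet. A runnable sketch of that control flow, with a hypothetical `_fetch_from_server` standing in for the real Thrift round trip:

```python
from typing import List, Optional


class _ResultSet:
    def __init__(self) -> None:
        # Base class no longer takes a results_queue parameter;
        # subclasses are expected to assign self.results themselves.
        self.results: Optional[List[str]] = None


class _ThriftResultSet(_ResultSet):
    def __init__(self, results_queue: Optional[List[str]] = None) -> None:
        super().__init__()
        self.results = results_queue
        # Lazy fill: only hit the server when no queue was handed over.
        if not self.results:
            self._fill_results_buffer()

    def _fill_results_buffer(self) -> None:
        self.results = self._fetch_from_server()

    def _fetch_from_server(self) -> List[str]:
        return ["row fetched lazily"]  # hypothetical server round trip


print(_ThriftResultSet(["prefetched row"]).results)  # uses the given queue
print(_ThriftResultSet().results)                    # lazy server fetch
```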
src/databricks/sql/session.py (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@
 from databricks.sql import __version__
 from databricks.sql import USER_AGENT_NAME
 from databricks.sql.backend.thrift_backend import ThriftDatabricksClient
-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.databricks_client import DatabricksClient
 from databricks.sql.backend.types import SessionId, BackendType
src/databricks/sql/utils.py (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@

 import lz4.frame

-from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.backend.sea.client import SeaDatabricksClient
 from databricks.sql.backend.sea.models.base import ResultData, ResultManifest

 try: