Skip to content

Commit d4a49ca

Browse files
introduce Result Set interface and concrete thrift implementation
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent a57f6c3 commit d4a49ca

File tree

4 files changed

+437
-370
lines changed

4 files changed

+437
-370
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from databricks.sql.utils import ExecuteResponse
66
from databricks.sql.types import SSLOptions
77

8+
# Forward reference for type hints
9+
from typing import TYPE_CHECKING
10+
if TYPE_CHECKING:
11+
from databricks.sql.result_set import ResultSet
12+
813

914
class DatabricksClient(ABC):
1015
# == Connection and Session Management ==
@@ -35,7 +40,7 @@ def execute_command(
3540
parameters: List[ttypes.TSparkParameter],
3641
async_op: bool,
3742
enforce_embedded_schema_correctness: bool,
38-
) -> Any:
43+
) -> "ResultSet": # Changed return type to ResultSet
3944
pass
4045

4146
@abstractmethod

src/databricks/sql/backend/thrift_backend.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
)
4343
from databricks.sql.types import SSLOptions
4444
from databricks.sql.backend.databricks_client import DatabricksClient
45+
from databricks.sql.result_set import ThriftResultSet
4546

4647
logger = logging.getLogger(__name__)
4748

@@ -833,7 +834,7 @@ def get_execution_result(self, op_handle, cursor):
833834
ssl_options=self._ssl_options,
834835
)
835836

836-
return ExecuteResponse(
837+
execute_response = ExecuteResponse(
837838
arrow_queue=queue,
838839
status=resp.status,
839840
has_been_closed_server_side=False,
@@ -844,6 +845,15 @@ def get_execution_result(self, op_handle, cursor):
844845
description=description,
845846
arrow_schema_bytes=schema_bytes,
846847
)
848+
849+
return ThriftResultSet(
850+
connection=cursor.connection,
851+
execute_response=execute_response,
852+
thrift_client=self,
853+
buffer_size_bytes=cursor.buffer_size_bytes,
854+
arraysize=cursor.arraysize,
855+
use_cloud_fetch=cursor.connection.use_cloud_fetch
856+
)
847857

848858
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
849859
if initial_operation_status_resp:
@@ -938,8 +948,18 @@ def execute_command(
938948

939949
if async_op:
940950
self._handle_execute_response_async(resp, cursor)
951+
return None
941952
else:
942-
return self._handle_execute_response(resp, cursor)
953+
execute_response = self._handle_execute_response(resp, cursor)
954+
955+
return ThriftResultSet(
956+
connection=cursor.connection,
957+
execute_response=execute_response,
958+
thrift_client=self,
959+
buffer_size_bytes=max_bytes,
960+
arraysize=max_rows,
961+
use_cloud_fetch=use_cloud_fetch
962+
)
943963

944964
def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
945965
assert session_handle is not None
@@ -951,7 +971,17 @@ def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
951971
),
952972
)
953973
resp = self.make_request(self._client.GetCatalogs, req)
954-
return self._handle_execute_response(resp, cursor)
974+
975+
execute_response = self._handle_execute_response(resp, cursor)
976+
977+
return ThriftResultSet(
978+
connection=cursor.connection,
979+
execute_response=execute_response,
980+
thrift_client=self,
981+
buffer_size_bytes=max_bytes,
982+
arraysize=max_rows,
983+
use_cloud_fetch=cursor.connection.use_cloud_fetch
984+
)
955985

956986
def get_schemas(
957987
self,
@@ -973,7 +1003,17 @@ def get_schemas(
9731003
schemaName=schema_name,
9741004
)
9751005
resp = self.make_request(self._client.GetSchemas, req)
976-
return self._handle_execute_response(resp, cursor)
1006+
1007+
execute_response = self._handle_execute_response(resp, cursor)
1008+
1009+
return ThriftResultSet(
1010+
connection=cursor.connection,
1011+
execute_response=execute_response,
1012+
thrift_client=self,
1013+
buffer_size_bytes=max_bytes,
1014+
arraysize=max_rows,
1015+
use_cloud_fetch=cursor.connection.use_cloud_fetch
1016+
)
9771017

9781018
def get_tables(
9791019
self,
@@ -999,7 +1039,17 @@ def get_tables(
9991039
tableTypes=table_types,
10001040
)
10011041
resp = self.make_request(self._client.GetTables, req)
1002-
return self._handle_execute_response(resp, cursor)
1042+
1043+
execute_response = self._handle_execute_response(resp, cursor)
1044+
1045+
return ThriftResultSet(
1046+
connection=cursor.connection,
1047+
execute_response=execute_response,
1048+
thrift_client=self,
1049+
buffer_size_bytes=max_bytes,
1050+
arraysize=max_rows,
1051+
use_cloud_fetch=cursor.connection.use_cloud_fetch
1052+
)
10031053

10041054
def get_columns(
10051055
self,
@@ -1025,7 +1075,17 @@ def get_columns(
10251075
columnName=column_name,
10261076
)
10271077
resp = self.make_request(self._client.GetColumns, req)
1028-
return self._handle_execute_response(resp, cursor)
1078+
1079+
execute_response = self._handle_execute_response(resp, cursor)
1080+
1081+
return ThriftResultSet(
1082+
connection=cursor.connection,
1083+
execute_response=execute_response,
1084+
thrift_client=self,
1085+
buffer_size_bytes=max_bytes,
1086+
arraysize=max_rows,
1087+
use_cloud_fetch=cursor.connection.use_cloud_fetch
1088+
)
10291089

10301090
def _handle_execute_response(self, resp, cursor):
10311091
cursor.active_op_handle = resp.operationHandle

0 commit comments

Comments
 (0)