
Commit 38c2b88

correct fetch*_arrow
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent b3273c7 commit 38c2b88

File tree

3 files changed: +17 -11 lines changed


examples/experimental/tests/test_sea_async_query.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from databricks.sql.client import Connection
 from databricks.sql.backend.types import CommandState
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
 

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import logging
 from databricks.sql.client import Connection
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
 
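Both example scripts raise the root logger from INFO to DEBUG so the experimental SEA backend's debug output is visible when they run. If root-level DEBUG is too noisy, a narrower alternative (an assumption, not part of this commit) is to raise verbosity only for the connector's own logger hierarchy:

import logging

# Keep the root logger at INFO, but surface DEBUG messages
# from the databricks.sql package only (its modules log via
# logging.getLogger(__name__), so "databricks.sql" is their parent logger).
logging.basicConfig(level=logging.INFO)
logging.getLogger("databricks.sql").setLevel(logging.DEBUG)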

src/databricks/sql/result_set.py

Lines changed: 15 additions & 9 deletions
@@ -24,7 +24,12 @@
 from databricks.sql.thrift_api.TCLIService import ttypes
 from databricks.sql.types import Row
 from databricks.sql.exc import Error, RequestError, CursorAlreadyClosedError
-from databricks.sql.utils import ColumnTable, ColumnQueue, JsonQueue
+from databricks.sql.utils import (
+    ColumnTable,
+    ColumnQueue,
+    JsonQueue,
+    SeaResultSetQueueFactory,
+)
 from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
 
 logger = logging.getLogger(__name__)
@@ -475,6 +480,7 @@ def __init__(
             result_data,
             manifest,
             str(execute_response.command_id.to_sea_statement_id()),
+            ssl_options=connection.session.ssl_options,
             description=execute_response.description,
             max_download_threads=sea_client.max_download_threads,
             sea_client=sea_client,
@@ -618,11 +624,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         if size < 0:
             raise ValueError(f"size argument for fetchmany is {size} but must be >= 0")
 
-        if not isinstance(self.results, JsonQueue):
-            raise NotImplementedError("fetchmany_arrow only supported for JSON data")
+        results = self.results.next_n_rows(size)
+        if isinstance(self.results, JsonQueue):
+            results = self._convert_json_types(results)
+            results = self._convert_json_to_arrow(results)
 
-        rows = self._convert_json_types(self.results.next_n_rows(size))
-        results = self._convert_json_to_arrow(rows)
         self._next_row_index += results.num_rows
 
         return results
@@ -632,11 +638,11 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         Fetch all remaining rows as an Arrow table.
         """
 
-        if not isinstance(self.results, JsonQueue):
-            raise NotImplementedError("fetchall_arrow only supported for JSON data")
+        results = self.results.remaining_rows()
+        if isinstance(self.results, JsonQueue):
+            results = self._convert_json_types(results)
+            results = self._convert_json_to_arrow(results)
 
-        rows = self._convert_json_types(self.results.remaining_rows())
-        results = self._convert_json_to_arrow(rows)
         self._next_row_index += results.num_rows
 
         return results
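After this change, fetchmany_arrow and fetchall_arrow no longer reject non-JSON queues: they pull rows from whatever queue backs the result set and only run the JSON-to-Arrow conversion when that queue is a JsonQueue. A minimal usage sketch follows, assuming the standard databricks-sql-connector connection API; the hostname, HTTP path, and token are placeholders, and whatever flag selects the experimental SEA backend in this branch is not shown in the diff.

from databricks import sql

# Placeholder connection parameters; enabling the SEA backend is not
# covered by this commit and is intentionally omitted here.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<http-path>",
    access_token="<token>",
) as conn:
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1 AS id")

        # Both calls now return a pyarrow.Table whether the underlying
        # queue holds JSON rows or already-materialized Arrow data.
        first_batch = cursor.fetchmany_arrow(100)
        remaining = cursor.fetchall_arrow()
        print(first_batch.num_rows, remaining.num_rows)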
