Commit eeaee96

Implemented the columnar flow for non-Arrow users

1 parent: d31063c

File tree

5 files changed: +240 -32 lines changed
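In practice, the change means the standard DB-API fetch path keeps working when pyarrow is not installed. A minimal usage sketch, assuming a reachable Databricks SQL warehouse; the hostname, HTTP path, and token below are placeholders:

    from databricks import sql

    # With pyarrow absent, results arrive as columnar data (ColumnQueue) and the
    # row-level API below still works; fetchall_arrow() would require pyarrow.
    with sql.connect(
        server_hostname="example.cloud.databricks.com",  # placeholder
        http_path="/sql/1.0/warehouses/abc123",          # placeholder
        access_token="dapi-example-token",               # placeholder
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute("SELECT 1 AS n")
            for row in cursor.fetchall():
                print(row.n)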

src/databricks/sql/client.py

Lines changed: 84 additions & 8 deletions
@@ -1,7 +1,10 @@
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence
 
 import pandas
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 import requests
 import json
 import os
@@ -22,6 +25,8 @@
     ParamEscaper,
     inject_parameters,
     transform_paramstyle,
+    ArrowQueue,
+    ColumnQueue
 )
 from databricks.sql.parameters.native import (
     DbsqlParameterBase,
@@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]:
         else:
             raise Error("There is no active result set")
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchall_arrow()
         else:
             raise Error("There is no active result set")
 
-    def fetchmany_arrow(self, size) -> pyarrow.Table:
+    def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
             return self.active_result_set.fetchmany_arrow(size)
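The quoted annotations are what keep these signatures importable. A minimal sketch of the mechanism, under the same fallback as the diff above: annotations in a def are evaluated when the function is defined, so a bare pyarrow.Table would raise AttributeError whenever pyarrow fell back to None, while a quoted forward reference is stored as plain text and never touches the module.

    try:
        import pyarrow
    except ImportError:
        pyarrow = None

    def fetchall_arrow() -> "pyarrow.Table":  # fine even when pyarrow is None
        ...

    print(fetchall_arrow.__annotations__["return"])  # the string 'pyarrow.Table'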
@@ -1143,6 +1148,18 @@ def _fill_results_buffer(self):
         self.results = results
         self.has_more_rows = has_more_rows
 
+    def _convert_columnar_table(self, table):
+        column_names = [c[0] for c in self.description]
+        ResultRow = Row(*column_names)
+        result = []
+        for row_index in range(len(table[0])):
+            curr_row = []
+            for col_index in range(len(table)):
+                curr_row.append(table[col_index][row_index])
+            result.append(ResultRow(*curr_row))
+
+        return result
+
     def _convert_arrow_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
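_convert_columnar_table assumes the columnar table is a list of equal-length column value lists and pivots it into Row objects. A standalone sketch of the same pivot, using collections.namedtuple as a stand-in for the connector's Row factory:

    from collections import namedtuple

    # Stand-in for the connector's Row: a factory built from the column names.
    ResultRow = namedtuple("ResultRow", ["id", "name"])

    # A columnar table: one inner list per column, as a ColumnQueue would yield.
    table = [[1, 2, 3], ["a", "b", "c"]]

    rows = [
        ResultRow(*(table[col][row] for col in range(len(table))))
        for row in range(len(table[0]))
    ]
    print(rows[0])  # ResultRow(id=1, name='a')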
@@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table):
     def rownumber(self):
         return self._next_row_index
 
-    def fetchmany_arrow(self, size: int) -> pyarrow.Table:
+    def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
 
@@ -1210,7 +1227,42 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table:
 
         return results
 
-    def fetchall_arrow(self) -> pyarrow.Table:
+    def merge_columnar(self, result1, result2):
+        """
+        Merge two columnar result sets into a single columnar result.
+        :param result1:
+        :param result2:
+        :return:
+        """
+        merged_result = [result1[i] + result2[i] for i in range(len(result1))]
+        return merged_result
+
+    def fetchmany_columnar(self, size: int):
+        """
+        Fetch the next set of rows of a query result, returning a columnar table.
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0" % size)
+
+        results = self.results.next_n_rows(size)
+        n_remaining_rows = size - len(results[0])
+        self._next_row_index += len(results[0])
+
+        while (
+            n_remaining_rows > 0
+            and not self.has_been_closed_server_side
+            and self.has_more_rows
+        ):
+            self._fill_results_buffer()
+            partial_results = self.results.next_n_rows(n_remaining_rows)
+            results = self.merge_columnar(results, partial_results)
+            n_remaining_rows -= len(partial_results[0])
+            self._next_row_index += len(partial_results[0])
+
+        return results
+
+    def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
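merge_columnar concatenates results column by column, which is what lets fetchmany_columnar stitch together successive buffer fills. A tiny worked example of the same list comprehension:

    def merge_columnar(result1, result2):
        # Concatenate each column of result2 onto the matching column of result1.
        return [result1[i] + result2[i] for i in range(len(result1))]

    first = [[1, 2], ["a", "b"]]   # two columns, two rows
    second = [[3], ["c"]]          # the next buffer fill: one more row
    print(merge_columnar(first, second))  # [[1, 2, 3], ['a', 'b', 'c']]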
@@ -1223,12 +1275,30 @@ def fetchall_arrow(self) -> pyarrow.Table:
 
         return results
 
+    def fetchall_columnar(self):
+        """Fetch all (remaining) rows of a query result, returning them as a columnar table."""
+        results = self.results.remaining_rows()
+        self._next_row_index += len(results[0])
+
+        while not self.has_been_closed_server_side and self.has_more_rows:
+            self._fill_results_buffer()
+            partial_results = self.results.remaining_rows()
+            results = self.merge_columnar(results, partial_results)
+            self._next_row_index += len(partial_results[0])
+
+        return results
+
     def fetchone(self) -> Optional[Row]:
         """
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
-        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
+        if isinstance(self.results, ColumnQueue):
+            res = self._convert_columnar_table(self.fetchmany_columnar(1))
+        else:
+            res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
         if len(res) > 0:
             return res[0]
         else:
@@ -1238,15 +1308,21 @@ def fetchall(self) -> List[Row]:
         """
         Fetch all (remaining) rows of a query result, returning them as a list of rows.
         """
-        return self._convert_arrow_table(self.fetchall_arrow())
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchall_columnar())
+        else:
+            return self._convert_arrow_table(self.fetchall_arrow())
 
     def fetchmany(self, size: int) -> List[Row]:
         """
         Fetch the next set of rows of a query result, returning a list of rows.
 
         An empty sequence is returned when no more rows are available.
         """
-        return self._convert_arrow_table(self.fetchmany_arrow(size))
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchmany_columnar(size))
+        else:
+            return self._convert_arrow_table(self.fetchmany_arrow(size))
 
     def close(self) -> None:
         """

src/databricks/sql/thrift_backend.py

Lines changed: 22 additions & 8 deletions
@@ -7,7 +7,10 @@
 import threading
 from typing import List, Union
 
-import pyarrow
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 import thrift.transport.THttpClient
 import thrift.protocol.TBinaryProtocol
 import thrift.transport.TSocket
@@ -621,6 +624,12 @@ def _get_metadata_resp(self, op_handle):
 
     @staticmethod
     def _hive_schema_to_arrow_schema(t_table_schema):
+
+        if pyarrow is None:
+            raise ImportError(
+                "pyarrow is required to convert Hive schema to Arrow schema"
+            )
+
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -726,12 +735,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
         description = self._hive_schema_to_description(
             t_result_set_metadata_resp.schema
         )
-        schema_bytes = (
-            t_result_set_metadata_resp.arrowSchema
-            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
-            .serialize()
-            .to_pybytes()
-        )
+
+        if pyarrow:
+            schema_bytes = (
+                t_result_set_metadata_resp.arrowSchema
+                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+                .serialize()
+                .to_pybytes()
+            )
+        else:
+            schema_bytes = None
+
         lz4_compressed = t_result_set_metadata_resp.lz4Compressed
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         if direct_results and direct_results.resultSet:
@@ -827,7 +841,7 @@ def execute_command(
             getDirectResults=ttypes.TSparkGetDirectResults(
                 maxRows=max_rows, maxBytes=max_bytes
            ),
-            canReadArrowResult=True,
+            canReadArrowResult=True if pyarrow else False,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
             confOverlay={
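With canReadArrowResult driven by the import probe, a session opened without pyarrow negotiates column-based results end to end. Application code that wants the Arrow fast path when available can mirror the same guard; a minimal sketch, where fetch_results is a hypothetical helper and not part of this commit:

    try:
        import pyarrow
    except ImportError:
        pyarrow = None

    def fetch_results(cursor):
        # Prefer the Arrow path when pyarrow is installed; otherwise fall back
        # to the row-level API, which the columnar flow serves without Arrow.
        if pyarrow is not None:
            return cursor.fetchall_arrow()  # pyarrow.Table
        return cursor.fetchall()            # List[Row] via the columnar path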
