
Commit c576110

Setup working dynamic change from ColumnQueue to ArrowQueue

1 parent 1cfaae2
4 files changed: +70 −53 lines

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/client.py

Lines changed: 40 additions & 23 deletions
```diff
@@ -784,8 +784,8 @@ def execute(
             parameters=prepared_params,
         )
 
-        print("Line 781")
-        print(execute_response)
+        # print("Line 781")
+        # print(execute_response)
         self.active_result_set = ResultSet(
             self.connection,
             execute_response,
```
```diff
@@ -1141,7 +1141,7 @@ def _fill_results_buffer(self):
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
-        print("Table\n",table)
+        # print("Table\n",table)
         result = []
         for row_index in range(len(table[0])):
             curr_row = []
```
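For reference, the loop this hunk touches materializes rows out of a column-major table. A self-contained sketch of the same idea, using `namedtuple` in place of the connector's `Row` factory (a simplification for illustration, not the library's implementation):

```python
from collections import namedtuple

def convert_columnar_table(table, description):
    # table is column-major: table[i] is column i, so row j is
    # [col[j] for col in table]; description holds (name, ...) tuples.
    column_names = [c[0] for c in description]
    ResultRow = namedtuple("ResultRow", column_names)
    result = []
    for row_index in range(len(table[0])):
        result.append(ResultRow(*[col[row_index] for col in table]))
    return result

rows = convert_columnar_table(
    [[1, 2, 3], ["a", "b", "c"]],  # two columns, three rows
    [("id",), ("name",)],
)
assert rows[0].name == "a" and len(rows) == 3
```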
```diff
@@ -1164,23 +1164,20 @@ def _convert_arrow_table(self, table):
         # Need to use nullable types, as otherwise type can change when there are missing values.
         # See https://arrow.apache.org/docs/python/pandas.html#nullable-types
         # NOTE: This api is epxerimental https://pandas.pydata.org/pandas-docs/stable/user_guide/integer_na.html
-        try:
-            dtype_mapping = {
-                pyarrow.int8(): pandas.Int8Dtype(),
-                pyarrow.int16(): pandas.Int16Dtype(),
-                pyarrow.int32(): pandas.Int32Dtype(),
-                pyarrow.int64(): pandas.Int64Dtype(),
-                pyarrow.uint8(): pandas.UInt8Dtype(),
-                pyarrow.uint16(): pandas.UInt16Dtype(),
-                pyarrow.uint32(): pandas.UInt32Dtype(),
-                pyarrow.uint64(): pandas.UInt64Dtype(),
-                pyarrow.bool_(): pandas.BooleanDtype(),
-                pyarrow.float32(): pandas.Float32Dtype(),
-                pyarrow.float64(): pandas.Float64Dtype(),
-                pyarrow.string(): pandas.StringDtype(),
-            }
-        except AttributeError:
-            print("pyarrow is not present")
+        dtype_mapping = {
+            pyarrow.int8(): pandas.Int8Dtype(),
+            pyarrow.int16(): pandas.Int16Dtype(),
+            pyarrow.int32(): pandas.Int32Dtype(),
+            pyarrow.int64(): pandas.Int64Dtype(),
+            pyarrow.uint8(): pandas.UInt8Dtype(),
+            pyarrow.uint16(): pandas.UInt16Dtype(),
+            pyarrow.uint32(): pandas.UInt32Dtype(),
+            pyarrow.uint64(): pandas.UInt64Dtype(),
+            pyarrow.bool_(): pandas.BooleanDtype(),
+            pyarrow.float32(): pandas.Float32Dtype(),
+            pyarrow.float64(): pandas.Float64Dtype(),
+            pyarrow.string(): pandas.StringDtype(),
+        }
 
         # Need to rename columns, as the to_pandas function cannot handle duplicate column names
         table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
```
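The mapping above is the standard pyarrow-to-pandas nullable-dtype table; dropping the try/except is reasonable once pyarrow is a hard import in this code path. Such a mapping is typically fed to `Table.to_pandas` via `types_mapper`, which keeps integer columns nullable instead of silently coercing them to float64 when NULLs appear. A minimal standalone demonstration (assumes pyarrow and pandas are installed):

```python
import pandas
import pyarrow

dtype_mapping = {pyarrow.int64(): pandas.Int64Dtype()}

# One NULL in an int64 column: without a nullable dtype, pandas would
# convert the column to float64 (1.0, NaN, 3.0); with Int64Dtype the
# integers and the missing value both survive.
table = pyarrow.table({"n": pyarrow.array([1, None, 3], type=pyarrow.int64())})
df = table.to_pandas(types_mapper=dtype_mapping.get)
print(df["n"].dtype)  # Int64 (pandas nullable integer dtype)
```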
```diff
@@ -1222,6 +1219,20 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
 
         return results
 
+    def fetchmany_columnar(self, size: int):
+        """
+        Fetch the next set of rows of a query result, returning a Columnar Table.
+
+        An empty sequence is returned when no more rows are available.
+        """
+        if size < 0:
+            raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
+
+        results = self.results.next_n_rows(size)
+        self._next_row_index += results.num_rows
+
+        return results
+
     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
         results = self.results.remaining_rows()
```
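The new `fetchmany_columnar` mirrors `fetchmany_arrow` but drains a ColumnQueue. The queue contract it relies on, `next_n_rows` returning a chunk that exposes `num_rows`, can be sketched with a toy stand-in (a hypothetical simplification; the real ColumnQueue lives in utils.py):

```python
class ToyColumnTable:
    """Column-major chunk: columns[i] is the i-th column."""
    def __init__(self, columns):
        self.columns = columns

    @property
    def num_rows(self):
        return len(self.columns[0]) if self.columns else 0

class ToyColumnQueue:
    def __init__(self, columns):
        self.columns = columns
        self.cur_row_index = 0

    def next_n_rows(self, n):
        # Advance a shared row cursor and slice every column by it.
        end = self.cur_row_index + n
        chunk = ToyColumnTable(
            [col[self.cur_row_index:end] for col in self.columns]
        )
        self.cur_row_index = min(end, len(self.columns[0]))
        return chunk

q = ToyColumnQueue([[1, 2, 3, 4], ["a", "b", "c", "d"]])
print(q.next_n_rows(3).num_rows)  # 3
print(q.next_n_rows(3).num_rows)  # 1, only one row was left
```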
```diff
@@ -1245,7 +1256,11 @@ def fetchone(self) -> Optional[Row]:
         Fetch the next row of a query result set, returning a single sequence,
         or None when no more data is available.
         """
-        res = self._convert_arrow_table(self.fetchmany_arrow(1))
+        if isinstance(self.results, ColumnQueue):
+            res = self._convert_columnar_table(self.fetchmany_columnar(1))
+        else:
+            res = self._convert_arrow_table(self.fetchmany_arrow(1))
+
         if len(res) > 0:
             return res[0]
         else:
```
```diff
@@ -1260,14 +1275,16 @@ def fetchall(self) -> List[Row]:
         else:
             return self._convert_arrow_table(self.fetchall_arrow())
 
-
     def fetchmany(self, size: int) -> List[Row]:
         """
         Fetch the next set of rows of a query result, returning a list of rows.
 
         An empty sequence is returned when no more rows are available.
         """
-        return self._convert_arrow_table(self.fetchmany_arrow(size))
+        if isinstance(self.results, ColumnQueue):
+            return self._convert_columnar_table(self.fetchmany_columnar(size))
+        else:
+            return self._convert_arrow_table(self.fetchmany_arrow(size))
 
     def close(self) -> None:
         """
```

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/thrift_backend.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -743,7 +743,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
         else:
             t_result_set_metadata_resp = self._get_metadata_resp(resp.operationHandle)
 
-        print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
+        # print(f"Line 739 - {t_result_set_metadata_resp.resultFormat}")
         if t_result_set_metadata_resp.resultFormat not in [
             ttypes.TSparkRowSetType.ARROW_BASED_SET,
             ttypes.TSparkRowSetType.COLUMN_BASED_SET,
```
```diff
@@ -873,7 +873,7 @@ def execute_command(
             getDirectResults=ttypes.TSparkGetDirectResults(
                 maxRows=max_rows, maxBytes=max_bytes
             ),
-            canReadArrowResult=False,
+            canReadArrowResult=True if pyarrow else False,
             canDecompressLZ4Result=lz4_compression,
             canDownloadResult=use_cloud_fetch,
             confOverlay={
```
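The `canReadArrowResult` change keys the server handshake off whether pyarrow actually imported: presumably the module binds `pyarrow` to None when the import fails, so the expression degrades to False and the server responds with COLUMN_BASED_SET rows instead. A sketch of that guarded-import pattern (an assumption about the surrounding module, not shown in this diff):

```python
# Guarded optional dependency: pyarrow stays None when unavailable.
try:
    import pyarrow
except ImportError:
    pyarrow = None

# As in the hunk above; bool(pyarrow) would read the same.
can_read_arrow_result = True if pyarrow else False
```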

databricks_sql_connector_core/src/databricks_sql_connector_core/sql/utils.py

Lines changed: 27 additions & 27 deletions
```diff
@@ -75,16 +75,16 @@ def build_queue(
         ResultSetQueue
         """
 
-        def trow_to_json(trow):
-            # Step 1: Serialize TRow using Thrift's TJSONProtocol
-            transport = TTransport.TMemoryBuffer()
-            protocol = TJSONProtocol.TJSONProtocol(transport)
-            trow.write(protocol)
-
-            # Step 2: Extract JSON string from the transport
-            json_str = transport.getvalue().decode('utf-8')
-
-            return json_str
+        # def trow_to_json(trow):
+        #     # Step 1: Serialize TRow using Thrift's TJSONProtocol
+        #     transport = TTransport.TMemoryBuffer()
+        #     protocol = TJSONProtocol.TJSONProtocol(transport)
+        #     trow.write(protocol)
+        #
+        #     # Step 2: Extract JSON string from the transport
+        #     json_str = transport.getvalue().decode('utf-8')
+        #
+        #     return json_str
 
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
             arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
```
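The now-commented `trow_to_json` was a debugging aid: serialize a Thrift struct through TJSONProtocol to eyeball the TRowSet payload. Extracted as a standalone helper, the same logic reads (renamed for generality, otherwise identical to the commented block):

```python
from thrift.transport import TTransport
from thrift.protocol import TJSONProtocol

def thrift_struct_to_json(struct):
    # Works for any Thrift-generated struct (TRowSet, TColumn, ...):
    # write it through TJSONProtocol into an in-memory buffer, then
    # decode the buffer contents as a JSON string.
    transport = TTransport.TMemoryBuffer()
    protocol = TJSONProtocol.TJSONProtocol(transport)
    struct.write(protocol)
    return transport.getvalue().decode("utf-8")
```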
```diff
@@ -95,30 +95,30 @@ def trow_to_json(trow):
             )
             return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
-            print("Lin 79 ")
-            print(type(t_row_set))
-            print(t_row_set)
-            json_str = json.loads(trow_to_json(t_row_set))
-            pretty_json = json.dumps(json_str, indent=2)
-            print(pretty_json)
+            # print("Lin 79 ")
+            # print(type(t_row_set))
+            # print(t_row_set)
+            # json_str = json.loads(trow_to_json(t_row_set))
+            # pretty_json = json.dumps(json_str, indent=2)
+            # print(pretty_json)
 
             converted_column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns,
                 description)
-            print(converted_column_table, column_names)
+            # print(converted_column_table, column_names)
 
             return ColumnQueue(converted_column_table, column_names)
 
-            print(columnQueue.next_n_rows(2))
-            print(columnQueue.next_n_rows(2))
-            print(columnQueue.remaining_rows())
-            arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
-                t_row_set.columns, description
-            )
-            converted_arrow_table = convert_decimals_in_arrow_table(
-                arrow_table, description
-            )
-            return ArrowQueue(converted_arrow_table, n_valid_rows)
+            # print(columnQueue.next_n_rows(2))
+            # print(columnQueue.next_n_rows(2))
+            # print(columnQueue.remaining_rows())
+            # arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table(
+            #     t_row_set.columns, description
+            # )
+            # converted_arrow_table = convert_decimals_in_arrow_table(
+            #     arrow_table, description
+            # )
+            # return ArrowQueue(converted_arrow_table, n_valid_rows)
         elif row_set_type == TSparkRowSetType.URL_BASED_SET:
             return CloudFetchQueue(
                 schema_bytes=arrow_schema_bytes,
```
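With the dead Arrow-conversion path neutralized, the COLUMN_BASED_SET branch hands Thrift columns straight to `convert_column_based_set_to_column_table`. A hypothetical simplification of that conversion (the real helper also has to honor each column's nulls bitmask, elided here):

```python
def convert_column_based_set_to_column_table(columns, description):
    # A Thrift TColumn is a union: exactly one typed field (boolVal,
    # i32Val, stringVal, ...) is populated, each carrying a .values list.
    column_names = [c[0] for c in description]
    column_table = []
    for column in columns:
        for field in ("boolVal", "byteVal", "i16Val", "i32Val",
                      "i64Val", "doubleVal", "stringVal", "binaryVal"):
            typed_column = getattr(column, field, None)
            if typed_column is not None:
                column_table.append(typed_column.values)
                break
    return column_table, column_names
```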

setup_script.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -25,6 +25,6 @@ def build_and_install_library(directory_name):
 
 
 if __name__ == "__main__":
-    build_and_install_library("databricks_sql_connector_core")
+    # build_and_install_library("databricks_sql_connector_core")
     build_and_install_library("databricks_sql_connector")
     # build_and_install_library("databricks_sqlalchemy")
```

0 commit comments