Skip to content

Commit 9468cce

Browse files
committed
pyarrow
1 parent 7bec374 commit 9468cce

File tree

1 file changed

+17
-74
lines changed

1 file changed

+17
-74
lines changed

test/python/test_pyarrow.py

Lines changed: 17 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,21 @@
1-
from pathlib import Path
2-
31
import duckdb
4-
import pyarrow as pa
5-
from substrait.gen.proto import algebra_pb2, plan_pb2, type_pb2
6-
2+
import pytest
73

8-
def create_connection() -> duckdb.DuckDBPyConnection:
9-
"""Create a connection to the backend."""
10-
connection = duckdb.connect(config={'max_memory': '100GB',
11-
"allow_unsigned_extensions": "true",
12-
'temp_directory': str(Path('.').resolve())})
13-
connection.install_extension('substrait')
14-
connection.load_extension('substrait')
15-
16-
return connection
4+
plan_pb2 = pytest.importorskip("substrait.gen.proto.plan_pb2")
5+
algebra_pb2 = pytest.importorskip("substrait.gen.proto.algebra_pb2")
6+
type_pb2 = pytest.importorskip("substrait.gen.proto.type_pb2")
7+
pa = pytest.importorskip("pyarrow")
178

189

1910
def execute_plan(connection: duckdb.DuckDBPyConnection, plan: plan_pb2.Plan) -> pa.lib.Table:
20-
"""Execute the given Substrait plan against DuckDB."""
2111
plan_data = plan.SerializeToString()
22-
2312
try:
2413
query_result = connection.from_substrait(proto=plan_data)
2514
except Exception as err:
2615
raise ValueError(f'DuckDB Execution Error: {err}') from err
2716
return query_result.arrow()
2817

29-
30-
def register_table(
31-
connection: duckdb.DuckDBPyConnection,
32-
table_name: str,
33-
location: Path,
34-
use_duckdb_python_api: bool = True) -> None:
35-
"""Register the given table with the backend."""
36-
if use_duckdb_python_api:
37-
table_data = connection.read_parquet(location)
38-
connection.register(table_name, table_data)
39-
else:
40-
files_sql = f"CREATE OR REPLACE TABLE {table_name} AS FROM read_parquet(['{location}'])"
41-
connection.execute(files_sql)
42-
43-
44-
def register_table_with_arrow_data(
45-
connection: duckdb.DuckDBPyConnection,
46-
table_name: str,
47-
data: bytes) -> None:
48-
"""Register the given arrow data as a table with the backend."""
49-
r = pa.ipc.open_stream(data).read_all()
50-
connection.register(table_name, r)
51-
52-
53-
def describe_table(connection, table_name: str):
54-
s = connection.execute(f"SELECT * FROM {name}")
55-
t = connection.table(name)
56-
v = connection.view(name)
57-
print(f's = %s' % s.fetch_arrow_table())
58-
print(f't = %s' % t)
59-
print(f'v = %s' % v)
60-
18+
def execute_query(connection, table_name: str):
6119
plan = plan_pb2.Plan(relations=[
6220
plan_pb2.PlanRel(
6321
root=algebra_pb2.RelRoot(
@@ -68,35 +26,20 @@ def describe_table(connection, table_name: str):
6826
struct=type_pb2.Type.Struct(
6927
types=[type_pb2.Type(i64=type_pb2.Type.I64()),
7028
type_pb2.Type(string=type_pb2.Type.String())])),
71-
named_table=algebra_pb2.ReadRel.NamedTable(names=[name])
29+
named_table=algebra_pb2.ReadRel.NamedTable(names=[table_name])
7230
)),
7331
names=['a', 'b']))])
74-
print('About to execute Substrait')
75-
x = execute_plan(connection, plan)
76-
print(f'x = %s' % x)
77-
78-
79-
def serialize_table(table: pa.Table) -> bytes:
80-
"""Serialize a PyArrow table to bytes."""
81-
sink = pa.BufferOutputStream()
82-
with pa.ipc.new_stream(sink, table.schema) as writer:
83-
writer.write_table(table)
84-
return sink.getvalue().to_pybytes()
85-
32+
return execute_plan(connection, plan)
8633

87-
if __name__ == '__main__':
88-
connection = create_connection()
89-
name = 'my_table'
34+
def test_substrait_pyarrow(require):
35+
connection = require('substrait')
9036

91-
use_parquet = False
92-
if use_parquet:
93-
register_table(connection, name,
94-
'/Users/davids/projects/voltrondata-spark-substrait-gateway/third_party/tpch/parquet/customer/part-0.parquet')
95-
else:
96-
table = pa.table({'column1': [1, 2, 3], 'column2': ['a', 'b', 'c']})
97-
serialized_data = serialize_table(table)
98-
register_table_with_arrow_data(connection, name, serialized_data)
37+
connection.execute('CREATE TABLE integers (a integer, b varchar )')
38+
connection.execute('INSERT INTO integers VALUES (0, \'a\'),(1, \'b\')')
39+
arrow_table = connection.execute('FROM integers').arrow()
9940

100-
describe_table(connection, name)
41+
connection.register("arrow_integers", arrow_table)
42+
43+
arrow_result = execute_query(connection, "arrow_integers")
10144

102-
connection.close()
45+
assert connection.execute("FROM arrow_result").fetchall() == 0

0 commit comments

Comments
 (0)