1
- from pathlib import Path
2
-
3
1
import duckdb
4
- import pyarrow as pa
5
- from substrait .gen .proto import algebra_pb2 , plan_pb2 , type_pb2
6
-
2
+ import pytest
7
3
8
- def create_connection () -> duckdb .DuckDBPyConnection :
9
- """Create a connection to the backend."""
10
- connection = duckdb .connect (config = {'max_memory' : '100GB' ,
11
- "allow_unsigned_extensions" : "true" ,
12
- 'temp_directory' : str (Path ('.' ).resolve ())})
13
- connection .install_extension ('substrait' )
14
- connection .load_extension ('substrait' )
15
-
16
- return connection
4
+ plan_pb2 = pytest .importorskip ("substrait.gen.proto.plan_pb2" )
5
+ algebra_pb2 = pytest .importorskip ("substrait.gen.proto.algebra_pb2" )
6
+ type_pb2 = pytest .importorskip ("substrait.gen.proto.type_pb2" )
7
+ pa = pytest .importorskip ("pyarrow" )
17
8
18
9
19
10
def execute_plan (connection : duckdb .DuckDBPyConnection , plan : plan_pb2 .Plan ) -> pa .lib .Table :
20
- """Execute the given Substrait plan against DuckDB."""
21
11
plan_data = plan .SerializeToString ()
22
-
23
12
try :
24
13
query_result = connection .from_substrait (proto = plan_data )
25
14
except Exception as err :
26
15
raise ValueError (f'DuckDB Execution Error: { err } ' ) from err
27
16
return query_result .arrow ()
28
17
29
-
30
- def register_table (
31
- connection : duckdb .DuckDBPyConnection ,
32
- table_name : str ,
33
- location : Path ,
34
- use_duckdb_python_api : bool = True ) -> None :
35
- """Register the given table with the backend."""
36
- if use_duckdb_python_api :
37
- table_data = connection .read_parquet (location )
38
- connection .register (table_name , table_data )
39
- else :
40
- files_sql = f"CREATE OR REPLACE TABLE { table_name } AS FROM read_parquet(['{ location } '])"
41
- connection .execute (files_sql )
42
-
43
-
44
- def register_table_with_arrow_data (
45
- connection : duckdb .DuckDBPyConnection ,
46
- table_name : str ,
47
- data : bytes ) -> None :
48
- """Register the given arrow data as a table with the backend."""
49
- r = pa .ipc .open_stream (data ).read_all ()
50
- connection .register (table_name , r )
51
-
52
-
53
- def describe_table (connection , table_name : str ):
54
- s = connection .execute (f"SELECT * FROM { name } " )
55
- t = connection .table (name )
56
- v = connection .view (name )
57
- print (f's = %s' % s .fetch_arrow_table ())
58
- print (f't = %s' % t )
59
- print (f'v = %s' % v )
60
-
18
+ def execute_query (connection , table_name : str ):
61
19
plan = plan_pb2 .Plan (relations = [
62
20
plan_pb2 .PlanRel (
63
21
root = algebra_pb2 .RelRoot (
@@ -68,35 +26,20 @@ def describe_table(connection, table_name: str):
68
26
struct = type_pb2 .Type .Struct (
69
27
types = [type_pb2 .Type (i64 = type_pb2 .Type .I64 ()),
70
28
type_pb2 .Type (string = type_pb2 .Type .String ())])),
71
- named_table = algebra_pb2 .ReadRel .NamedTable (names = [name ])
29
+ named_table = algebra_pb2 .ReadRel .NamedTable (names = [table_name ])
72
30
)),
73
31
names = ['a' , 'b' ]))])
74
- print ('About to execute Substrait' )
75
- x = execute_plan (connection , plan )
76
- print (f'x = %s' % x )
77
-
78
-
79
- def serialize_table (table : pa .Table ) -> bytes :
80
- """Serialize a PyArrow table to bytes."""
81
- sink = pa .BufferOutputStream ()
82
- with pa .ipc .new_stream (sink , table .schema ) as writer :
83
- writer .write_table (table )
84
- return sink .getvalue ().to_pybytes ()
85
-
32
+ return execute_plan (connection , plan )
86
33
87
- if __name__ == '__main__' :
88
- connection = create_connection ()
89
- name = 'my_table'
34
+ def test_substrait_pyarrow (require ):
35
+ connection = require ('substrait' )
90
36
91
- use_parquet = False
92
- if use_parquet :
93
- register_table (connection , name ,
94
- '/Users/davids/projects/voltrondata-spark-substrait-gateway/third_party/tpch/parquet/customer/part-0.parquet' )
95
- else :
96
- table = pa .table ({'column1' : [1 , 2 , 3 ], 'column2' : ['a' , 'b' , 'c' ]})
97
- serialized_data = serialize_table (table )
98
- register_table_with_arrow_data (connection , name , serialized_data )
37
+ connection .execute ('CREATE TABLE integers (a integer, b varchar )' )
38
+ connection .execute ('INSERT INTO integers VALUES (0, \' a\' ),(1, \' b\' )' )
39
+ arrow_table = connection .execute ('FROM integers' ).arrow ()
99
40
100
- describe_table (connection , name )
41
+ connection .register ("arrow_integers" , arrow_table )
42
+
43
+ arrow_result = execute_query (connection , "arrow_integers" )
101
44
102
- connection .close ()
45
+ assert connection .execute ( "FROM arrow_result" ). fetchall () == 0
0 commit comments