@@ -28,6 +28,9 @@ const TOKIO_THREAD_NUMBER: usize = 1;
28
28
const WORKER_WAIT_TIMEOUT : Duration = Duration :: from_millis ( 100 ) ;
29
29
const SLOT_WAIT_TIMEOUT : Duration = Duration :: from_millis ( 1 ) ;
30
30
31
+ // POSTGRES WORLD
32
+ // Do not use any async functions in this part of the code.
33
+
31
34
#[ pg_guard]
32
35
pub ( crate ) fn init_datafusion_worker ( ) {
33
36
BackgroundWorkerBuilder :: new ( "datafusion" )
@@ -100,6 +103,14 @@ fn init_slots() -> Result<()> {
100
103
Ok ( ( ) )
101
104
}
102
105
106
+ fn response_error ( id : SlotNumber , ctx : & mut WorkerContext , stream : SlotStream , message : & str ) {
107
+ ctx. flush ( id) ;
108
+ send_error ( id, stream, message) . expect ( "Failed to send error response" ) ;
109
+ }
110
+
111
+ // POSTGRES - ASYNC BRIDGE
112
+ // The place where the async world meets the postgres world.
113
+
103
114
#[ pg_guard]
104
115
#[ no_mangle]
105
116
pub extern "C" fn worker_main ( _arg : pg_sys:: Datum ) {
@@ -117,103 +128,12 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
117
128
118
129
log ! ( "DataFusion worker is running" ) ;
119
130
while do_retry || BackgroundWorker :: wait_latch ( Some ( WORKER_WAIT_TIMEOUT ) ) {
120
- // Do not use any pgrx API in this loop: tokio has a multithreaded runtime,
121
- // while PostgreSQL functions can work only in single thread processes.
122
131
rt. block_on ( async {
123
132
do_retry = false ;
124
- // Process packets from the slots.
125
- for ( id, locked_slot) in Bus :: new ( ) . into_iter ( ) . enumerate ( ) {
126
- let Some ( slot) = locked_slot else {
127
- continue ;
128
- } ;
129
- let mut stream = SlotStream :: from ( slot) ;
130
- let header = match consume_header ( & mut stream) {
131
- Ok ( header) => header,
132
- Err ( err) => {
133
- errors[ id] = Some ( format_smolstr ! ( "Failed to consume header: {:?}" , err) ) ;
134
- continue ;
135
- }
136
- } ;
137
- if header. direction == Direction :: ToBackend {
138
- continue ;
139
- }
140
- let machine = & mut ctx. states [ id] ;
141
- let slot_id = u32:: try_from ( id) . expect ( "Failed to convert slot id to u32" ) ;
142
- let output = match machine. consume ( & header. packet ) {
143
- Ok ( output) => output,
144
- Err ( err) => {
145
- let msg = format_smolstr ! ( "Failed to change machine state: {:?}" , err) ;
146
- errors[ id] = Some ( msg) ;
147
- continue ;
148
- }
149
- } ;
150
- let handle = match output {
151
- Some ( ExecutorOutput :: Parse ) => tokio:: spawn ( parse ( header, stream) ) ,
152
- Some ( ExecutorOutput :: Flush ) => {
153
- ctx. flush ( slot_id) ;
154
- continue ;
155
- }
156
- Some ( ExecutorOutput :: Compile ) => {
157
- let Some ( stmt) = std:: mem:: take ( & mut ctx. statements [ id] ) else {
158
- errors[ id] = Some ( format_smolstr ! ( "No statement found for slot: {id}" ) ) ;
159
- continue ;
160
- } ;
161
- tokio:: spawn ( compile ( header, stream, stmt) )
162
- }
163
- Some ( ExecutorOutput :: Bind ) => {
164
- let Some ( plan) = std:: mem:: take ( & mut ctx. logical_plans [ id] ) else {
165
- errors[ id] =
166
- Some ( format_smolstr ! ( "No logical plan found for slot: {id}" ) ) ;
167
- continue ;
168
- } ;
169
- tokio:: spawn ( bind ( header, stream, plan) )
170
- }
171
- None => unreachable ! ( "Empty output in the worker state machine" ) ,
172
- } ;
173
- ctx. tasks . push ( ( slot_id, handle) ) ;
174
- }
175
- // Wait for the tasks to finish and process their results.
176
- for ( id, task) in & mut ctx. tasks {
177
- let result = task. await . expect ( "Failed to await task" ) ;
178
- match result {
179
- Ok ( TaskResult :: Parsing ( ( stmt, tables) ) ) => {
180
- let mut stream = wait_stream ( * id) . await ;
181
- if tables. is_empty ( ) {
182
- // We don't need any table metadata for this query.
183
- // So, write a fake metadata packet to the slot and proceed it
184
- // in the next iteration.
185
- do_retry = true ;
186
- if let Err ( err) = prepare_empty_metadata ( & mut stream) {
187
- errors[ * id as usize ] =
188
- Some ( format_smolstr ! ( "Failed to prepare metadata: {:?}" , err) ) ;
189
- continue ;
190
- }
191
- } else {
192
- send_table_refs ( * id, stream, tables. as_slice ( ) )
193
- . expect ( "Failed to reqest table references" ) ;
194
- }
195
- ctx. statements [ * id as usize ] = Some ( stmt) ;
196
- }
197
- Ok ( TaskResult :: Compilation ( plan) ) => {
198
- let stream = wait_stream ( * id) . await ;
199
- if let Err ( err) = request_params ( * id, stream) {
200
- errors[ * id as usize ] =
201
- Some ( format_smolstr ! ( "Failed to request params: {:?}" , err) ) ;
202
- continue ;
203
- }
204
- ctx. logical_plans [ * id as usize ] = Some ( plan) ;
205
- }
206
- Ok ( TaskResult :: Bind ( plan) ) => {
207
- ctx. logical_plans [ * id as usize ] = Some ( plan) ;
208
- }
209
- Err ( err) => {
210
- errors[ * id as usize ] =
211
- Some ( format_smolstr ! ( "Failed to execute task: {:?}" , err) )
212
- }
213
- }
214
- }
133
+ create_tasks ( & mut ctx, & mut errors, & mut do_retry) . await ;
134
+ wait_results ( & mut ctx, & mut errors, & mut do_retry) . await ;
215
135
} ) ;
216
- // Process errors in the main PostgreSQL thread .
136
+ // Process errors returned by the tasks .
217
137
for ( slot_id, msg) in errors. iter_mut ( ) . enumerate ( ) {
218
138
if let Some ( msg) = msg {
219
139
let stream;
@@ -233,6 +153,113 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
233
153
set_worker_id ( INVALID_PROC_NUMBER ) ;
234
154
}
235
155
156
+ // ASYNC WORLD
157
+ // Do not use any pgrx symbols in async functions. Tokio has a multithreaded
158
+ // runtime, while postgres functions can only run in a single thread.
159
+
160
+ /// Process packets from the slots and create tasks for them.
161
+ async fn create_tasks (
162
+ ctx : & mut WorkerContext ,
163
+ errors : & mut [ Option < SmolStr > ] ,
164
+ do_retry : & mut bool ,
165
+ ) {
166
+ for ( id, locked_slot) in Bus :: new ( ) . into_iter ( ) . enumerate ( ) {
167
+ let Some ( slot) = locked_slot else {
168
+ continue ;
169
+ } ;
170
+ let mut stream = SlotStream :: from ( slot) ;
171
+ let header = match consume_header ( & mut stream) {
172
+ Ok ( header) => header,
173
+ Err ( err) => {
174
+ errors[ id] = Some ( format_smolstr ! ( "Failed to consume header: {:?}" , err) ) ;
175
+ continue ;
176
+ }
177
+ } ;
178
+ if header. direction == Direction :: ToBackend {
179
+ continue ;
180
+ }
181
+ let machine = & mut ctx. states [ id] ;
182
+ let slot_id = u32:: try_from ( id) . expect ( "Failed to convert slot id to u32" ) ;
183
+ let output = match machine. consume ( & header. packet ) {
184
+ Ok ( output) => output,
185
+ Err ( err) => {
186
+ let msg = format_smolstr ! ( "Failed to change machine state: {:?}" , err) ;
187
+ errors[ id] = Some ( msg) ;
188
+ continue ;
189
+ }
190
+ } ;
191
+ let handle = match output {
192
+ Some ( ExecutorOutput :: Parse ) => tokio:: spawn ( parse ( header, stream) ) ,
193
+ Some ( ExecutorOutput :: Flush ) => {
194
+ ctx. flush ( slot_id) ;
195
+ continue ;
196
+ }
197
+ Some ( ExecutorOutput :: Compile ) => {
198
+ let Some ( stmt) = std:: mem:: take ( & mut ctx. statements [ id] ) else {
199
+ errors[ id] = Some ( format_smolstr ! ( "No statement found for slot: {id}" ) ) ;
200
+ continue ;
201
+ } ;
202
+ tokio:: spawn ( compile ( header, stream, stmt) )
203
+ }
204
+ Some ( ExecutorOutput :: Bind ) => {
205
+ let Some ( plan) = std:: mem:: take ( & mut ctx. logical_plans [ id] ) else {
206
+ errors[ id] = Some ( format_smolstr ! ( "No logical plan found for slot: {id}" ) ) ;
207
+ continue ;
208
+ } ;
209
+ tokio:: spawn ( bind ( header, stream, plan) )
210
+ }
211
+ None => unreachable ! ( "Empty output in the worker state machine" ) ,
212
+ } ;
213
+ ctx. tasks . push ( ( slot_id, handle) ) ;
214
+ }
215
+ }
216
+
217
+ /// Wait for the tasks to finish and process their results.
218
+ async fn wait_results (
219
+ ctx : & mut WorkerContext ,
220
+ errors : & mut [ Option < SmolStr > ] ,
221
+ do_retry : & mut bool ,
222
+ ) {
223
+ for ( id, task) in & mut ctx. tasks {
224
+ let result = task. await . expect ( "Failed to await task" ) ;
225
+ match result {
226
+ Ok ( TaskResult :: Parsing ( ( stmt, tables) ) ) => {
227
+ let mut stream = wait_stream ( * id) . await ;
228
+ if tables. is_empty ( ) {
229
+ // We don't need any table metadata for this query.
230
+ // So, write a fake metadata packet to the slot and proceed it
231
+ // in the next iteration.
232
+ * do_retry = true ;
233
+ if let Err ( err) = prepare_empty_metadata ( & mut stream) {
234
+ errors[ * id as usize ] =
235
+ Some ( format_smolstr ! ( "Failed to prepare metadata: {:?}" , err) ) ;
236
+ continue ;
237
+ }
238
+ } else {
239
+ send_table_refs ( * id, stream, tables. as_slice ( ) )
240
+ . expect ( "Failed to reqest table references" ) ;
241
+ }
242
+ ctx. statements [ * id as usize ] = Some ( stmt) ;
243
+ }
244
+ Ok ( TaskResult :: Compilation ( plan) ) => {
245
+ let stream = wait_stream ( * id) . await ;
246
+ if let Err ( err) = request_params ( * id, stream) {
247
+ errors[ * id as usize ] =
248
+ Some ( format_smolstr ! ( "Failed to request params: {:?}" , err) ) ;
249
+ continue ;
250
+ }
251
+ ctx. logical_plans [ * id as usize ] = Some ( plan) ;
252
+ }
253
+ Ok ( TaskResult :: Bind ( plan) ) => {
254
+ ctx. logical_plans [ * id as usize ] = Some ( plan) ;
255
+ }
256
+ Err ( err) => {
257
+ errors[ * id as usize ] = Some ( format_smolstr ! ( "Failed to execute task: {:?}" , err) )
258
+ }
259
+ }
260
+ }
261
+ }
262
+
236
263
#[ inline( always) ]
237
264
async fn wait_stream ( slot_id : u32 ) -> SlotStream {
238
265
loop {
@@ -283,11 +310,6 @@ async fn bind(
283
310
Ok ( TaskResult :: Bind ( plan) )
284
311
}
285
312
286
- fn response_error ( id : SlotNumber , ctx : & mut WorkerContext , stream : SlotStream , message : & str ) {
287
- ctx. flush ( id) ;
288
- send_error ( id, stream, message) . expect ( "Failed to send error response" ) ;
289
- }
290
-
291
313
#[ cfg( test) ]
292
314
mod tests {
293
315
use datafusion:: scalar:: ScalarValue ;
0 commit comments