Commit 57db3c8

chore: separate async and postgres code in worker
1 parent 243990f commit 57db3c8

File tree

1 file changed: +121 -99 lines changed

src/worker.rs

Lines changed: 121 additions & 99 deletions
@@ -28,6 +28,9 @@ const TOKIO_THREAD_NUMBER: usize = 1;
 const WORKER_WAIT_TIMEOUT: Duration = Duration::from_millis(100);
 const SLOT_WAIT_TIMEOUT: Duration = Duration::from_millis(1);

+// POSTGRES WORLD
+// Do not use any async functions in this part of the code.
+
 #[pg_guard]
 pub(crate) fn init_datafusion_worker() {
     BackgroundWorkerBuilder::new("datafusion")
@@ -100,6 +103,14 @@ fn init_slots() -> Result<()> {
     Ok(())
 }

+fn response_error(id: SlotNumber, ctx: &mut WorkerContext, stream: SlotStream, message: &str) {
+    ctx.flush(id);
+    send_error(id, stream, message).expect("Failed to send error response");
+}
+
+// POSTGRES - ASYNC BRIDGE
+// The place where async world meets the postgres world.
+
 #[pg_guard]
 #[no_mangle]
 pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
@@ -117,103 +128,12 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {

     log!("DataFusion worker is running");
     while do_retry || BackgroundWorker::wait_latch(Some(WORKER_WAIT_TIMEOUT)) {
-        // Do not use any pgrx API in this loop: tokio has a multithreaded runtime,
-        // while PostgreSQL functions can work only in single thread processes.
         rt.block_on(async {
             do_retry = false;
-            // Process packets from the slots.
-            for (id, locked_slot) in Bus::new().into_iter().enumerate() {
-                let Some(slot) = locked_slot else {
-                    continue;
-                };
-                let mut stream = SlotStream::from(slot);
-                let header = match consume_header(&mut stream) {
-                    Ok(header) => header,
-                    Err(err) => {
-                        errors[id] = Some(format_smolstr!("Failed to consume header: {:?}", err));
-                        continue;
-                    }
-                };
-                if header.direction == Direction::ToBackend {
-                    continue;
-                }
-                let machine = &mut ctx.states[id];
-                let slot_id = u32::try_from(id).expect("Failed to convert slot id to u32");
-                let output = match machine.consume(&header.packet) {
-                    Ok(output) => output,
-                    Err(err) => {
-                        let msg = format_smolstr!("Failed to change machine state: {:?}", err);
-                        errors[id] = Some(msg);
-                        continue;
-                    }
-                };
-                let handle = match output {
-                    Some(ExecutorOutput::Parse) => tokio::spawn(parse(header, stream)),
-                    Some(ExecutorOutput::Flush) => {
-                        ctx.flush(slot_id);
-                        continue;
-                    }
-                    Some(ExecutorOutput::Compile) => {
-                        let Some(stmt) = std::mem::take(&mut ctx.statements[id]) else {
-                            errors[id] = Some(format_smolstr!("No statement found for slot: {id}"));
-                            continue;
-                        };
-                        tokio::spawn(compile(header, stream, stmt))
-                    }
-                    Some(ExecutorOutput::Bind) => {
-                        let Some(plan) = std::mem::take(&mut ctx.logical_plans[id]) else {
-                            errors[id] =
-                                Some(format_smolstr!("No logical plan found for slot: {id}"));
-                            continue;
-                        };
-                        tokio::spawn(bind(header, stream, plan))
-                    }
-                    None => unreachable!("Empty output in the worker state machine"),
-                };
-                ctx.tasks.push((slot_id, handle));
-            }
-            // Wait for the tasks to finish and process their results.
-            for (id, task) in &mut ctx.tasks {
-                let result = task.await.expect("Failed to await task");
-                match result {
-                    Ok(TaskResult::Parsing((stmt, tables))) => {
-                        let mut stream = wait_stream(*id).await;
-                        if tables.is_empty() {
-                            // We don't need any table metadata for this query.
-                            // So, write a fake metadata packet to the slot and proceed it
-                            // in the next iteration.
-                            do_retry = true;
-                            if let Err(err) = prepare_empty_metadata(&mut stream) {
-                                errors[*id as usize] =
-                                    Some(format_smolstr!("Failed to prepare metadata: {:?}", err));
-                                continue;
-                            }
-                        } else {
-                            send_table_refs(*id, stream, tables.as_slice())
-                                .expect("Failed to reqest table references");
-                        }
-                        ctx.statements[*id as usize] = Some(stmt);
-                    }
-                    Ok(TaskResult::Compilation(plan)) => {
-                        let stream = wait_stream(*id).await;
-                        if let Err(err) = request_params(*id, stream) {
-                            errors[*id as usize] =
-                                Some(format_smolstr!("Failed to request params: {:?}", err));
-                            continue;
-                        }
-                        ctx.logical_plans[*id as usize] = Some(plan);
-                    }
-                    Ok(TaskResult::Bind(plan)) => {
-                        ctx.logical_plans[*id as usize] = Some(plan);
-                    }
-                    Err(err) => {
-                        errors[*id as usize] =
-                            Some(format_smolstr!("Failed to execute task: {:?}", err))
-                    }
-                }
-            }
+            create_tasks(&mut ctx, &mut errors, &mut do_retry).await;
+            wait_results(&mut ctx, &mut errors, &mut do_retry).await;
         });
-        // Process errors in the main PostgreSQL thread.
+        // Process errors returned by the tasks.
         for (slot_id, msg) in errors.iter_mut().enumerate() {
             if let Some(msg) = msg {
                 let stream;
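
The refactored loop above is the bridge between the two worlds: all pgrx calls (the latch wait, slot flushing, error responses) stay on the single worker thread, and the async side is entered only through rt.block_on. The standalone sketch below shows the same shape with plain Tokio. It is a minimal illustration under assumptions, not code from this extension: process_slot, handle_error, the slot count, and the fixed two iterations are invented for the example.

// Standalone sketch (not the extension's code) of the postgres / async split.
// Requires tokio with the rt-multi-thread feature.
use std::time::Duration;
use tokio::runtime::Builder;

// ASYNC WORLD: plain computation, no backend calls.
async fn process_slot(id: usize) -> Result<String, String> {
    if id == 3 {
        Err("no statement found".to_string())
    } else {
        Ok(format!("result for slot {id}"))
    }
}

// POSTGRES WORLD: the only place that would touch backend APIs.
fn handle_error(id: usize, msg: &str) {
    eprintln!("slot {id}: {msg}");
}

fn main() {
    // One worker thread, mirroring TOKIO_THREAD_NUMBER = 1 in the diff.
    let rt = Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .expect("Failed to build tokio runtime");
    let mut errors: Vec<Option<String>> = vec![None; 4];

    // The synchronous loop drives the async side and collects errors...
    for _ in 0..2 {
        rt.block_on(async {
            let handles: Vec<_> = (0..errors.len())
                .map(|id| (id, tokio::spawn(process_slot(id))))
                .collect();
            for (id, handle) in handles {
                match handle.await.expect("Failed to await task") {
                    Ok(out) => println!("{out}"),
                    Err(msg) => errors[id] = Some(msg),
                }
            }
        });
        // ...and reports them outside the runtime, on the single worker thread.
        for (id, msg) in errors.iter_mut().enumerate() {
            if let Some(msg) = msg.take() {
                handle_error(id, &msg);
            }
        }
        std::thread::sleep(Duration::from_millis(100));
    }
}

The property worth copying is that nothing inside block_on touches the synchronous side directly; errors are only recorded in the array and reported after the runtime hands control back to the worker thread.
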
@@ -233,6 +153,113 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
     set_worker_id(INVALID_PROC_NUMBER);
 }

+// ASYNC WORLD
+// Do not use any pgrx symbols in async functions. Tokio has a multithreaded
+// runtime, while postgres functions can work only in single thread.
+
+/// Process packets from the slots and create tasks for them.
+async fn create_tasks(
+    ctx: &mut WorkerContext,
+    errors: &mut [Option<SmolStr>],
+    do_retry: &mut bool,
+) {
+    for (id, locked_slot) in Bus::new().into_iter().enumerate() {
+        let Some(slot) = locked_slot else {
+            continue;
+        };
+        let mut stream = SlotStream::from(slot);
+        let header = match consume_header(&mut stream) {
+            Ok(header) => header,
+            Err(err) => {
+                errors[id] = Some(format_smolstr!("Failed to consume header: {:?}", err));
+                continue;
+            }
+        };
+        if header.direction == Direction::ToBackend {
+            continue;
+        }
+        let machine = &mut ctx.states[id];
+        let slot_id = u32::try_from(id).expect("Failed to convert slot id to u32");
+        let output = match machine.consume(&header.packet) {
+            Ok(output) => output,
+            Err(err) => {
+                let msg = format_smolstr!("Failed to change machine state: {:?}", err);
+                errors[id] = Some(msg);
+                continue;
+            }
+        };
+        let handle = match output {
+            Some(ExecutorOutput::Parse) => tokio::spawn(parse(header, stream)),
+            Some(ExecutorOutput::Flush) => {
+                ctx.flush(slot_id);
+                continue;
+            }
+            Some(ExecutorOutput::Compile) => {
+                let Some(stmt) = std::mem::take(&mut ctx.statements[id]) else {
+                    errors[id] = Some(format_smolstr!("No statement found for slot: {id}"));
+                    continue;
+                };
+                tokio::spawn(compile(header, stream, stmt))
+            }
+            Some(ExecutorOutput::Bind) => {
+                let Some(plan) = std::mem::take(&mut ctx.logical_plans[id]) else {
+                    errors[id] = Some(format_smolstr!("No logical plan found for slot: {id}"));
+                    continue;
+                };
+                tokio::spawn(bind(header, stream, plan))
+            }
+            None => unreachable!("Empty output in the worker state machine"),
+        };
+        ctx.tasks.push((slot_id, handle));
+    }
+}
+
+/// Wait for the tasks to finish and process their results.
+async fn wait_results(
+    ctx: &mut WorkerContext,
+    errors: &mut [Option<SmolStr>],
+    do_retry: &mut bool,
+) {
+    for (id, task) in &mut ctx.tasks {
+        let result = task.await.expect("Failed to await task");
+        match result {
+            Ok(TaskResult::Parsing((stmt, tables))) => {
+                let mut stream = wait_stream(*id).await;
+                if tables.is_empty() {
+                    // We don't need any table metadata for this query.
+                    // So, write a fake metadata packet to the slot and proceed it
+                    // in the next iteration.
+                    *do_retry = true;
+                    if let Err(err) = prepare_empty_metadata(&mut stream) {
+                        errors[*id as usize] =
+                            Some(format_smolstr!("Failed to prepare metadata: {:?}", err));
+                        continue;
+                    }
+                } else {
+                    send_table_refs(*id, stream, tables.as_slice())
+                        .expect("Failed to reqest table references");
+                }
+                ctx.statements[*id as usize] = Some(stmt);
+            }
+            Ok(TaskResult::Compilation(plan)) => {
+                let stream = wait_stream(*id).await;
+                if let Err(err) = request_params(*id, stream) {
+                    errors[*id as usize] =
+                        Some(format_smolstr!("Failed to request params: {:?}", err));
+                    continue;
+                }
+                ctx.logical_plans[*id as usize] = Some(plan);
+            }
+            Ok(TaskResult::Bind(plan)) => {
+                ctx.logical_plans[*id as usize] = Some(plan);
+            }
+            Err(err) => {
+                errors[*id as usize] = Some(format_smolstr!("Failed to execute task: {:?}", err))
+            }
+        }
+    }
+}
+
 #[inline(always)]
 async fn wait_stream(slot_id: u32) -> SlotStream {
     loop {
@@ -283,11 +310,6 @@ async fn bind(
     Ok(TaskResult::Bind(plan))
 }

-fn response_error(id: SlotNumber, ctx: &mut WorkerContext, stream: SlotStream, message: &str) {
-    ctx.flush(id);
-    send_error(id, stream, message).expect("Failed to send error response");
-}
-
 #[cfg(test)]
 mod tests {
     use datafusion::scalar::ScalarValue;
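
The dispatch in create_tasks is driven by a per-slot state machine: consuming a packet advances the slot's state and tells the worker which task to spawn next. The sketch below models that idea with a simplified enum; Packet, State, and Output are invented names that only loosely mirror ExecutorOutput and the states used in the diff.

// Simplified, standalone model of the per-slot state machine behind create_tasks.
#[derive(Debug, Clone, Copy)]
enum Packet {
    Parse,
    Metadata,
    Params,
    Flush,
}

#[derive(Debug, Clone, Copy)]
enum State {
    Idle,
    Parsed,
    Compiled,
}

#[derive(Debug)]
enum Output {
    SpawnParse,
    SpawnCompile,
    SpawnBind,
    Flush,
}

struct Executor {
    state: State,
}

impl Executor {
    fn new() -> Self {
        Self { state: State::Idle }
    }

    // Consume a packet, advance the slot state, and report what to do next.
    fn consume(&mut self, packet: Packet) -> Result<Output, String> {
        match (self.state, packet) {
            (_, Packet::Flush) => {
                self.state = State::Idle;
                Ok(Output::Flush)
            }
            (State::Idle, Packet::Parse) => {
                self.state = State::Parsed;
                Ok(Output::SpawnParse)
            }
            (State::Parsed, Packet::Metadata) => {
                self.state = State::Compiled;
                Ok(Output::SpawnCompile)
            }
            (State::Compiled, Packet::Params) => Ok(Output::SpawnBind),
            (state, packet) => Err(format!("unexpected {packet:?} in state {state:?}")),
        }
    }
}

fn main() {
    let mut machine = Executor::new();
    // A well-formed session walks through parse, compile, and bind, then flushes.
    for packet in [Packet::Parse, Packet::Metadata, Packet::Params, Packet::Flush] {
        match machine.consume(packet) {
            Ok(output) => println!("{packet:?} -> {output:?}"),
            Err(msg) => eprintln!("{msg}"),
        }
    }
}

Keeping the whole transition table inside one consume function is what lets create_tasks remain a flat loop: each slot is either advanced, flushed, or marked with an error, without any backend calls.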
