
Commit 18bc125

chore: separate async and postgres code in worker
1 parent 243990f commit 18bc125

File tree

1 file changed (+122, -99 lines)


src/worker.rs

Lines changed: 122 additions & 99 deletions
@@ -28,6 +28,9 @@ const TOKIO_THREAD_NUMBER: usize = 1;
 const WORKER_WAIT_TIMEOUT: Duration = Duration::from_millis(100);
 const SLOT_WAIT_TIMEOUT: Duration = Duration::from_millis(1);
 
+// POSTGRES WORLD
+// Do not use any async functions in this part of the code.
+
 #[pg_guard]
 pub(crate) fn init_datafusion_worker() {
     BackgroundWorkerBuilder::new("datafusion")
@@ -100,6 +103,14 @@ fn init_slots() -> Result<()> {
     Ok(())
 }
 
+fn response_error(id: SlotNumber, ctx: &mut WorkerContext, stream: SlotStream, message: &str) {
+    ctx.flush(id);
+    send_error(id, stream, message).expect("Failed to send error response");
+}
+
+// POSTGRES - ASYNC BRIDGE
+// The place where async world meets the postgres world.
+
 #[pg_guard]
 #[no_mangle]
 pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
@@ -117,103 +128,12 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
 
     log!("DataFusion worker is running");
     while do_retry || BackgroundWorker::wait_latch(Some(WORKER_WAIT_TIMEOUT)) {
-        // Do not use any pgrx API in this loop: tokio has a multithreaded runtime,
-        // while PostgreSQL functions can work only in single thread processes.
         rt.block_on(async {
             do_retry = false;
-            // Process packets from the slots.
-            for (id, locked_slot) in Bus::new().into_iter().enumerate() {
-                let Some(slot) = locked_slot else {
-                    continue;
-                };
-                let mut stream = SlotStream::from(slot);
-                let header = match consume_header(&mut stream) {
-                    Ok(header) => header,
-                    Err(err) => {
-                        errors[id] = Some(format_smolstr!("Failed to consume header: {:?}", err));
-                        continue;
-                    }
-                };
-                if header.direction == Direction::ToBackend {
-                    continue;
-                }
-                let machine = &mut ctx.states[id];
-                let slot_id = u32::try_from(id).expect("Failed to convert slot id to u32");
-                let output = match machine.consume(&header.packet) {
-                    Ok(output) => output,
-                    Err(err) => {
-                        let msg = format_smolstr!("Failed to change machine state: {:?}", err);
-                        errors[id] = Some(msg);
-                        continue;
-                    }
-                };
-                let handle = match output {
-                    Some(ExecutorOutput::Parse) => tokio::spawn(parse(header, stream)),
-                    Some(ExecutorOutput::Flush) => {
-                        ctx.flush(slot_id);
-                        continue;
-                    }
-                    Some(ExecutorOutput::Compile) => {
-                        let Some(stmt) = std::mem::take(&mut ctx.statements[id]) else {
-                            errors[id] = Some(format_smolstr!("No statement found for slot: {id}"));
-                            continue;
-                        };
-                        tokio::spawn(compile(header, stream, stmt))
-                    }
-                    Some(ExecutorOutput::Bind) => {
-                        let Some(plan) = std::mem::take(&mut ctx.logical_plans[id]) else {
-                            errors[id] =
-                                Some(format_smolstr!("No logical plan found for slot: {id}"));
-                            continue;
-                        };
-                        tokio::spawn(bind(header, stream, plan))
-                    }
-                    None => unreachable!("Empty output in the worker state machine"),
-                };
-                ctx.tasks.push((slot_id, handle));
-            }
-            // Wait for the tasks to finish and process their results.
-            for (id, task) in &mut ctx.tasks {
-                let result = task.await.expect("Failed to await task");
-                match result {
-                    Ok(TaskResult::Parsing((stmt, tables))) => {
-                        let mut stream = wait_stream(*id).await;
-                        if tables.is_empty() {
-                            // We don't need any table metadata for this query.
-                            // So, write a fake metadata packet to the slot and proceed it
-                            // in the next iteration.
-                            do_retry = true;
-                            if let Err(err) = prepare_empty_metadata(&mut stream) {
-                                errors[*id as usize] =
-                                    Some(format_smolstr!("Failed to prepare metadata: {:?}", err));
-                                continue;
-                            }
-                        } else {
-                            send_table_refs(*id, stream, tables.as_slice())
-                                .expect("Failed to reqest table references");
-                        }
-                        ctx.statements[*id as usize] = Some(stmt);
-                    }
-                    Ok(TaskResult::Compilation(plan)) => {
-                        let stream = wait_stream(*id).await;
-                        if let Err(err) = request_params(*id, stream) {
-                            errors[*id as usize] =
-                                Some(format_smolstr!("Failed to request params: {:?}", err));
-                            continue;
-                        }
-                        ctx.logical_plans[*id as usize] = Some(plan);
-                    }
-                    Ok(TaskResult::Bind(plan)) => {
-                        ctx.logical_plans[*id as usize] = Some(plan);
-                    }
-                    Err(err) => {
-                        errors[*id as usize] =
-                            Some(format_smolstr!("Failed to execute task: {:?}", err))
-                    }
-                }
-            }
+            create_tasks(&mut ctx, &mut errors, &mut do_retry).await;
+            wait_results(&mut ctx, &mut errors, &mut do_retry).await;
         });
-        // Process errors in the main PostgreSQL thread.
+        // Process errors returned by the tasks.
         for (slot_id, msg) in errors.iter_mut().enumerate() {
             if let Some(msg) = msg {
                 let stream;
@@ -233,6 +153,114 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
     set_worker_id(INVALID_PROC_NUMBER);
 }
 
+// ASYNC WORLD
+// Do not use any pgrx symbols in async functions. Tokio has a multithreaded
+// runtime, while postgres functions can work only in single thread.
+
+/// Process packets from the slots and create tasks for them.
+async fn create_tasks(
+    ctx: &mut WorkerContext,
+    errors: &mut [Option<SmolStr>],
+    do_retry: &mut bool,
+) {
+    // Process packets from the slots.
+    for (id, locked_slot) in Bus::new().into_iter().enumerate() {
+        let Some(slot) = locked_slot else {
+            continue;
+        };
+        let mut stream = SlotStream::from(slot);
+        let header = match consume_header(&mut stream) {
+            Ok(header) => header,
+            Err(err) => {
+                errors[id] = Some(format_smolstr!("Failed to consume header: {:?}", err));
+                continue;
+            }
+        };
+        if header.direction == Direction::ToBackend {
+            continue;
+        }
+        let machine = &mut ctx.states[id];
+        let slot_id = u32::try_from(id).expect("Failed to convert slot id to u32");
+        let output = match machine.consume(&header.packet) {
+            Ok(output) => output,
+            Err(err) => {
+                let msg = format_smolstr!("Failed to change machine state: {:?}", err);
+                errors[id] = Some(msg);
+                continue;
+            }
+        };
+        let handle = match output {
+            Some(ExecutorOutput::Parse) => tokio::spawn(parse(header, stream)),
+            Some(ExecutorOutput::Flush) => {
+                ctx.flush(slot_id);
+                continue;
+            }
+            Some(ExecutorOutput::Compile) => {
+                let Some(stmt) = std::mem::take(&mut ctx.statements[id]) else {
+                    errors[id] = Some(format_smolstr!("No statement found for slot: {id}"));
+                    continue;
+                };
+                tokio::spawn(compile(header, stream, stmt))
+            }
+            Some(ExecutorOutput::Bind) => {
+                let Some(plan) = std::mem::take(&mut ctx.logical_plans[id]) else {
+                    errors[id] = Some(format_smolstr!("No logical plan found for slot: {id}"));
+                    continue;
+                };
+                tokio::spawn(bind(header, stream, plan))
+            }
+            None => unreachable!("Empty output in the worker state machine"),
+        };
+        ctx.tasks.push((slot_id, handle));
+    }
+}
+
+/// Wait for the tasks to finish and process their results.
+async fn wait_results(
+    ctx: &mut WorkerContext,
+    errors: &mut [Option<SmolStr>],
+    do_retry: &mut bool,
+) {
+    for (id, task) in &mut ctx.tasks {
+        let result = task.await.expect("Failed to await task");
+        match result {
+            Ok(TaskResult::Parsing((stmt, tables))) => {
+                let mut stream = wait_stream(*id).await;
+                if tables.is_empty() {
+                    // We don't need any table metadata for this query.
+                    // So, write a fake metadata packet to the slot and proceed it
+                    // in the next iteration.
+                    *do_retry = true;
+                    if let Err(err) = prepare_empty_metadata(&mut stream) {
+                        errors[*id as usize] =
+                            Some(format_smolstr!("Failed to prepare metadata: {:?}", err));
+                        continue;
+                    }
+                } else {
+                    send_table_refs(*id, stream, tables.as_slice())
+                        .expect("Failed to reqest table references");
+                }
+                ctx.statements[*id as usize] = Some(stmt);
+            }
+            Ok(TaskResult::Compilation(plan)) => {
+                let stream = wait_stream(*id).await;
+                if let Err(err) = request_params(*id, stream) {
+                    errors[*id as usize] =
+                        Some(format_smolstr!("Failed to request params: {:?}", err));
+                    continue;
+                }
+                ctx.logical_plans[*id as usize] = Some(plan);
+            }
+            Ok(TaskResult::Bind(plan)) => {
+                ctx.logical_plans[*id as usize] = Some(plan);
+            }
+            Err(err) => {
+                errors[*id as usize] = Some(format_smolstr!("Failed to execute task: {:?}", err))
+            }
+        }
+    }
+}
+
 #[inline(always)]
 async fn wait_stream(slot_id: u32) -> SlotStream {
     loop {
@@ -283,11 +311,6 @@ async fn bind(
     Ok(TaskResult::Bind(plan))
 }
 
-fn response_error(id: SlotNumber, ctx: &mut WorkerContext, stream: SlotStream, message: &str) {
-    ctx.flush(id);
-    send_error(id, stream, message).expect("Failed to send error response");
-}
-
 #[cfg(test)]
 mod tests {
     use datafusion::scalar::ScalarValue;
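
The pattern this commit enforces is that the async side only drives the slot state machines and collects results or error strings, while everything that must stay on the single Postgres thread (latch waits, error responses via response_error) runs after block_on returns. The sketch below illustrates that split in a standalone program; it is not the extension's code. The runtime setup, slot count, and task bodies are hypothetical stand-ins, and only the overall structure mirrors worker_main, create_tasks, and wait_results above.

```rust
// Standalone sketch of the "async world vs. postgres world" split.
// All names and task bodies here are hypothetical; only the structure
// mirrors the diff above. Assumes tokio with the "full" feature set.
use tokio::runtime::Builder;

fn main() {
    // One worker thread, mirroring TOKIO_THREAD_NUMBER = 1 in worker.rs.
    let rt = Builder::new_multi_thread()
        .worker_threads(1)
        .enable_all()
        .build()
        .expect("Failed to build tokio runtime");

    // Per-slot error buffer: filled by the async side, drained by the sync side.
    let mut errors: Vec<Option<String>> = vec![None; 4];

    // ASYNC WORLD: spawn one task per slot and record failures.
    // Nothing thread-bound (like pgrx calls) is allowed in here.
    rt.block_on(async {
        let mut tasks = Vec::new();
        for id in 0..errors.len() {
            tasks.push((id, tokio::spawn(async move {
                if id % 2 == 0 {
                    Ok(id)
                } else {
                    Err(format!("slot {id} failed"))
                }
            })));
        }
        for (id, task) in tasks {
            match task.await.expect("Failed to await task") {
                Ok(_) => {}
                Err(msg) => errors[id] = Some(msg),
            }
        }
    });

    // "POSTGRES WORLD": back on the single main thread, report the errors.
    for (id, msg) in errors.iter_mut().enumerate() {
        if let Some(msg) = msg.take() {
            eprintln!("slot {id}: {msg}");
        }
    }
}
```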
