Skip to content

Commit 13f9443

Browse files
committed
feat: make tests for async loop
Had to remove MyProcNumber symbol from the slot bus iterator to make it work with tokio.
1 parent 184819a commit 13f9443

File tree

3 files changed

+90
-38
lines changed

3 files changed

+90
-38
lines changed

src/backend.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use libc::c_long;
33
use pgrx::pg_sys::{
44
error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
55
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainState,
6-
InvalidOid, List, ListCell, MyLatch, Node, NodeTag, Oid, ParamListInfo,
6+
InvalidOid, List, ListCell, MyLatch, MyProcNumber, Node, NodeTag, Oid, ParamListInfo,
77
RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION,
88
WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
99
};
@@ -68,10 +68,11 @@ struct ScanState {
6868
#[pg_guard]
6969
#[no_mangle]
7070
unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
71+
let my_proc_number = unsafe { MyProcNumber };
7172
let wait_stream = || -> SlotStream {
7273
let stream;
7374
loop {
74-
let Some(slot) = Bus::new().slot_locked(my_slot()) else {
75+
let Some(slot) = Bus::new().slot_locked(my_slot(), my_proc_number) else {
7576
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
7677
continue;
7778
};
@@ -95,7 +96,7 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
9596
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
9697
skip_wait = false;
9798
}
98-
let Some(slot) = Bus::new().slot_locked(my_slot()) else {
99+
let Some(slot) = Bus::new().slot_locked(my_slot(), my_proc_number) else {
99100
continue;
100101
};
101102
let mut stream = SlotStream::from(slot);

src/ipc.rs

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use std::slice::from_raw_parts_mut;
1111
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
1212

1313
static mut SLOT_FREE_LIST_PTR: OnceCell<*mut c_void> = OnceCell::new();
14-
static mut BUS_PTR: OnceCell<*mut c_void> = OnceCell::new();
14+
pub(crate) static mut BUS_PTR: OnceCell<*mut c_void> = OnceCell::new();
1515
static mut WORKER_PID_PTR: OnceCell<*mut c_void> = OnceCell::new();
1616
pub(crate) const INVALID_PROC_NUMBER: i32 = -1;
1717
pub(crate) const DATA_SIZE: usize = 8 * 1024;
@@ -268,19 +268,25 @@ impl Slot {
268268
unsafe { (*self.locked).load(Ordering::Relaxed) }
269269
}
270270

271-
pub(crate) fn lock(&self) -> bool {
271+
pub(crate) fn lock(&self, holder: i32) -> bool {
272272
unsafe {
273-
if let Ok(_) =
274-
(*self.locked).compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
273+
if (*self.locked)
274+
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
275+
.is_ok()
275276
{
276-
(*self.holder).store(MyProcNumber, Ordering::Relaxed);
277+
(*self.holder).store(holder, Ordering::Relaxed);
277278
true
278279
} else {
279280
false
280281
}
281282
}
282283
}
283284

285+
pub(crate) fn pg_lock(&self) -> bool {
286+
let holder = unsafe { MyProcNumber };
287+
self.lock(holder)
288+
}
289+
284290
pub(crate) fn unlock(&self) {
285291
unsafe {
286292
(*self.holder).store(INVALID_PROC_NUMBER, Ordering::Release);
@@ -413,9 +419,9 @@ impl Bus {
413419
Slot::from_bytes(slot_ptr, Slot::estimated_size())
414420
}
415421

416-
pub(crate) fn slot_locked(&mut self, id: SlotNumber) -> Option<Slot> {
422+
pub(crate) fn slot_locked(&mut self, id: SlotNumber, holder: i32) -> Option<Slot> {
417423
let slot = self.slot(id);
418-
if slot.lock() {
424+
if slot.lock(holder) {
419425
Some(slot)
420426
} else {
421427
None
@@ -427,21 +433,23 @@ impl Bus {
427433
Self::from_bytes(ptr as *mut u8, Self::estimated_size())
428434
}
429435

430-
pub(crate) fn into_iter(self) -> BusIter {
431-
BusIter::from(self)
436+
pub(crate) fn into_iter(self, holder: i32) -> BusIter {
437+
BusIter::new(self, holder)
432438
}
433439
}
434440

435441
pub(crate) struct BusIter {
436442
pos: SlotNumber,
437443
inner: Bus,
444+
holder: i32,
438445
}
439446

440-
impl From<Bus> for BusIter {
441-
fn from(value: Bus) -> Self {
447+
impl BusIter {
448+
pub(crate) fn new(bus: Bus, holder: i32) -> Self {
442449
Self {
443450
pos: 0,
444-
inner: value,
451+
inner: bus,
452+
holder,
445453
}
446454
}
447455
}
@@ -453,7 +461,7 @@ impl Iterator for BusIter {
453461
if self.pos >= max_backends() {
454462
return None;
455463
}
456-
let res = self.inner.slot_locked(self.pos);
464+
let res = self.inner.slot_locked(self.pos, self.holder);
457465
self.pos += 1;
458466
Some(res)
459467
}
@@ -562,8 +570,8 @@ pub(crate) mod tests {
562570
assert_eq!(Slot::estimated_size(), SLOT_SIZE);
563571
let mut slot_buf: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
564572
let mut slot = make_slot(&mut slot_buf);
565-
assert!(slot.lock());
566-
assert!(!slot.lock());
573+
assert!(slot.pg_lock());
574+
assert!(!slot.pg_lock());
567575
unsafe {
568576
assert_eq!(&slot_buf[Slot::range1()], &[1]);
569577
write(slot.data_mut(), [1; DATA_SIZE]);
@@ -577,7 +585,7 @@ pub(crate) mod tests {
577585
fn test_slot_stream() {
578586
let mut buffer: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
579587
let slot = make_slot(&mut buffer);
580-
slot.lock();
588+
slot.pg_lock();
581589
let mut stream = SlotStream::from(slot);
582590
let data = [42; 10];
583591
let len = stream.write(&data).unwrap();
@@ -626,12 +634,13 @@ pub(crate) mod tests {
626634
assert_eq!(buffer[i * slot_size], 0);
627635
}
628636
}
637+
let my_proc_number = unsafe { MyProcNumber };
629638
let mut bus = Bus::from_bytes(ptr, len);
630639
for i in 0..max_backends() {
631-
let slot = bus.slot_locked(i);
640+
let slot = bus.slot_locked(i, my_proc_number);
632641
assert!(slot.is_some());
633642
let slot = slot.unwrap();
634-
assert!(!slot.lock());
643+
assert!(!slot.pg_lock());
635644
slot.unlock();
636645
assert_eq!(buffer[i as usize * slot_size], 0);
637646
}

src/worker.rs

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ struct WorkerContext {
5454
}
5555

5656
impl WorkerContext {
57-
fn new() -> Self {
58-
let capacity = max_backends() as usize;
57+
fn with_capacity(capacity: usize) -> Self {
5958
let tasks = Vec::with_capacity(capacity);
6059
let mut states = Vec::with_capacity(capacity);
6160
let mut statements = Vec::with_capacity(capacity);
@@ -88,8 +87,8 @@ impl WorkerContext {
8887
}
8988
}
9089

91-
fn init_slots() -> Result<()> {
92-
for locked_slot in Bus::new().into_iter().flatten() {
90+
fn init_slots(holder: i32) -> Result<()> {
91+
for locked_slot in Bus::new().into_iter(holder).flatten() {
9392
let mut stream = SlotStream::from(locked_slot);
9493
stream.reset();
9594
let header = Header {
@@ -115,32 +114,41 @@ fn response_error(id: SlotNumber, ctx: &mut WorkerContext, stream: SlotStream, m
115114
#[no_mangle]
116115
pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
117116
BackgroundWorker::attach_signal_handlers(SignalWakeFlags::SIGHUP | SignalWakeFlags::SIGTERM);
118-
let mut ctx = WorkerContext::new();
117+
let capacity = max_backends() as usize;
118+
let mut ctx = WorkerContext::with_capacity(capacity);
119119
let rt = Builder::new_multi_thread()
120120
.worker_threads(TOKIO_THREAD_NUMBER)
121121
.enable_all()
122122
.build()
123123
.unwrap();
124124
let mut do_retry = false;
125-
let capacity = max_backends() as usize;
126125
let mut errors: Vec<Option<SmolStr>> = vec![None; capacity];
127126
let mut signals: Vec<bool> = vec![false; capacity];
128-
init_slots().expect("Failed to initialize slots");
129-
unsafe { set_worker_id(MyProcNumber) };
127+
let worker_proc_number = unsafe { MyProcNumber };
128+
init_slots(worker_proc_number).expect("Failed to initialize slots");
129+
set_worker_id(worker_proc_number);
130130

131131
log!("DataFusion worker is running");
132132
while do_retry || BackgroundWorker::wait_latch(Some(WORKER_WAIT_TIMEOUT)) {
133133
rt.block_on(async {
134134
do_retry = false;
135-
create_tasks(&mut ctx, &mut errors).await;
136-
wait_results(&mut ctx, &mut errors, &mut signals, &mut do_retry).await;
135+
create_tasks(&mut ctx, &mut errors, worker_proc_number).await;
136+
wait_results(
137+
&mut ctx,
138+
&mut errors,
139+
&mut signals,
140+
&mut do_retry,
141+
worker_proc_number,
142+
)
143+
.await;
137144
});
138145
// Process errors returned by the tasks.
139146
for (slot_id, msg) in errors.iter_mut().enumerate() {
140147
if let Some(msg) = msg {
141148
let stream;
142149
loop {
143-
let Some(slot) = Bus::new().slot_locked(slot_id as u32) else {
150+
let Some(slot) = Bus::new().slot_locked(slot_id as u32, worker_proc_number)
151+
else {
144152
BackgroundWorker::wait_latch(Some(SLOT_WAIT_TIMEOUT));
145153
continue;
146154
};
@@ -168,8 +176,8 @@ pub extern "C" fn worker_main(_arg: pg_sys::Datum) {
168176
// runtime, while postgres functions can work only in single thread.
169177

170178
/// Process packets from the slots and create tasks for them.
171-
async fn create_tasks(ctx: &mut WorkerContext, errors: &mut [Option<SmolStr>]) {
172-
for (id, locked_slot) in Bus::new().into_iter().enumerate() {
179+
async fn create_tasks(ctx: &mut WorkerContext, errors: &mut [Option<SmolStr>], holder: i32) {
180+
for (id, locked_slot) in Bus::new().into_iter(holder).enumerate() {
173181
let Some(slot) = locked_slot else {
174182
continue;
175183
};
@@ -226,12 +234,13 @@ async fn wait_results(
226234
errors: &mut [Option<SmolStr>],
227235
signals: &mut [bool],
228236
do_retry: &mut bool,
237+
holder: i32,
229238
) {
230239
for (id, task) in &mut ctx.tasks {
231240
let result = task.await.expect("Failed to await task");
232241
match result {
233242
Ok(TaskResult::Parsing((stmt, tables))) => {
234-
let mut stream = wait_stream(*id).await;
243+
let mut stream = wait_stream(*id, holder).await;
235244
if tables.is_empty() {
236245
// We don't need any table metadata for this query.
237246
// So, write a fake metadata packet to the slot and proceed it
@@ -260,7 +269,7 @@ async fn wait_results(
260269
ctx.statements[*id as usize] = Some(stmt);
261270
}
262271
Ok(TaskResult::Compilation(plan)) => {
263-
let mut stream = wait_stream(*id).await;
272+
let mut stream = wait_stream(*id, holder).await;
264273
match request_params(&mut stream) {
265274
Ok(()) => signals[*id as usize] = true,
266275
Err(err) => {
@@ -282,9 +291,9 @@ async fn wait_results(
282291
}
283292

284293
#[inline(always)]
285-
async fn wait_stream(slot_id: u32) -> SlotStream {
294+
async fn wait_stream(slot_id: u32, holder: i32) -> SlotStream {
286295
loop {
287-
let Some(slot) = Bus::new().slot_locked(slot_id) else {
296+
let Some(slot) = Bus::new().slot_locked(slot_id, holder) else {
288297
tokio::time::sleep(SLOT_WAIT_TIMEOUT).await;
289298
continue;
290299
};
@@ -339,6 +348,7 @@ mod tests {
339348
use super::*;
340349
use crate::data_type::{write_scalar_value, EncodedType};
341350
use crate::ipc::tests::{make_slot, SLOT_SIZE};
351+
use crate::ipc::{Slot, BUS_PTR};
342352
use crate::protocol::prepare_query;
343353

344354
#[tokio::test]
@@ -503,4 +513,36 @@ mod tests {
503513
"Projection: Int64(1) [Int64(1):Int64]\n EmptyRelation []",
504514
);
505515
}
516+
517+
#[tokio::test]
518+
async fn test_loop() {
519+
let holder = 42;
520+
let capacity = 2;
521+
let mut ctx = WorkerContext::with_capacity(capacity);
522+
let mut errors: Vec<Option<SmolStr>> = vec![None; capacity];
523+
let mut signals: Vec<bool> = vec![false; capacity];
524+
let mut do_retry = false;
525+
let bus_size = Slot::estimated_size() * capacity;
526+
let mut buffer = vec![0; bus_size];
527+
unsafe { BUS_PTR.set(buffer.as_mut_ptr() as _).unwrap() };
528+
init_slots(holder).expect("Failed to initialize slots");
529+
// Check processing of the parse message.
530+
let sql = "SELECT * FROM foo";
531+
{
532+
let slot = Bus::new().slot_locked(0, holder).unwrap();
533+
let mut stream = SlotStream::from(slot);
534+
prepare_query(&mut stream, sql).unwrap();
535+
}
536+
create_tasks(&mut ctx, &mut errors, holder).await;
537+
wait_results(&mut ctx, &mut errors, &mut signals, &mut do_retry, holder).await;
538+
let error = errors[0].take();
539+
assert!(error.is_none(), "Error: {:?}", error);
540+
let stmt = ctx.statements[0].take().unwrap();
541+
let expected_stmt = DFParser::parse_sql(sql)
542+
.expect("Failed to parse SQL")
543+
.into_iter()
544+
.next()
545+
.expect("Failed to get statement");
546+
assert_eq!(stmt, expected_stmt);
547+
}
506548
}

0 commit comments

Comments
 (0)