Skip to content

Commit 7c98c23

Browse files
committed
feat: support explain packet
1 parent 4d6e7f3 commit 7c98c23

File tree

4 files changed

+199
-39
lines changed

4 files changed

+199
-39
lines changed

src/backend.rs

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use anyhow::Result as AnyResult;
22
use libc::c_long;
33
use pgrx::pg_sys::{
44
error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
5-
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainState,
6-
InvalidOid, List, ListCell, MyLatch, MyProcNumber, Node, NodeTag, Oid, ParamListInfo,
7-
RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION,
8-
WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
5+
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainPropertyText,
6+
ExplainState, InvalidOid, List, ListCell, MyLatch, MyProcNumber, Node, NodeTag, Oid,
7+
ParamListInfo, RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch,
8+
PG_WAIT_EXTENSION, WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
99
};
1010
use pgrx::{check_for_interrupts, pg_guard};
1111
use rmp::decode::{read_array_len, read_bin_len};
@@ -16,8 +16,8 @@ use std::time::Duration;
1616
use crate::error::FusionError;
1717
use crate::ipc::{my_slot, Bus, SlotStream};
1818
use crate::protocol::{
19-
consume_header, read_error, send_metadata, send_params, send_query, Direction, NeedSchema,
20-
Packet,
19+
consume_header, read_error, request_explain, send_metadata, send_params, send_query, Direction,
20+
NeedSchema, Packet,
2121
};
2222

2323
const BACKEND_WAIT_TIMEOUT: Duration = Duration::from_millis(100);
@@ -60,27 +60,23 @@ pub(crate) fn exec_methods() -> *const CustomExecMethods {
6060
EXEC_METHODS.with(|m| &*m as *const CustomExecMethods)
6161
}
6262

63-
#[repr(C)]
64-
struct ScanState {
65-
css: CustomScanState,
63+
#[pg_guard]
64+
#[inline(always)]
65+
fn wait_stream() -> SlotStream {
66+
let my_proc_number = unsafe { MyProcNumber };
67+
loop {
68+
let Some(slot) = Bus::new().slot_locked(my_slot(), my_proc_number) else {
69+
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
70+
continue;
71+
};
72+
return SlotStream::from(slot);
73+
}
6674
}
6775

6876
#[pg_guard]
6977
#[no_mangle]
7078
unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
7179
let my_proc_number = unsafe { MyProcNumber };
72-
let wait_stream = || -> SlotStream {
73-
let stream;
74-
loop {
75-
let Some(slot) = Bus::new().slot_locked(my_slot(), my_proc_number) else {
76-
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
77-
continue;
78-
};
79-
stream = Some(SlotStream::from(slot));
80-
break;
81-
}
82-
stream.expect("Failed to acquire a slot stream")
83-
};
8480
let list = (*cscan).custom_private;
8581
let pattern = (*list_nth(list, 0)).ptr_value as *mut c_char;
8682
let stream = wait_stream();
@@ -120,8 +116,11 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
120116
}
121117
break;
122118
}
123-
Packet::Bind | Packet::Parse => {
124-
error!("Unexpected packet in backend: {:?}", header.packet)
119+
Packet::Bind | Packet::Parse | Packet::Explain => {
120+
error!(
121+
"Unexpected packet for create custom plan: {:?}",
122+
header.packet
123+
)
125124
}
126125
}
127126
}
@@ -136,36 +135,85 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
136135
methods: exec_methods(),
137136
..Default::default()
138137
};
139-
let state = ScanState { css };
140-
let mut node = PgNode::empty(std::mem::size_of::<ScanState>());
138+
let mut node = PgNode::empty(std::mem::size_of::<CustomScanState>());
141139
node.set_tag(NodeTag::T_CustomScanState);
142140
node.set_data(unsafe {
143141
std::slice::from_raw_parts(
144-
&state as *const _ as *const u8,
145-
std::mem::size_of::<ScanState>(),
142+
&css as *const _ as *const u8,
143+
std::mem::size_of::<CustomScanState>(),
146144
)
147145
});
148146
node.mut_node()
149147
}
150148

149+
#[pg_guard]
150+
#[no_mangle]
151151
unsafe extern "C" fn begin_df_scan(node: *mut CustomScanState, estate: *mut EState, eflags: i32) {
152152
todo!()
153153
}
154154

155+
#[pg_guard]
156+
#[no_mangle]
155157
unsafe extern "C" fn exec_df_scan(node: *mut CustomScanState) -> *mut TupleTableSlot {
156158
todo!()
157159
}
158160

161+
#[pg_guard]
162+
#[no_mangle]
159163
unsafe extern "C" fn end_df_scan(node: *mut CustomScanState) {
160164
todo!()
161165
}
162166

167+
#[pg_guard]
168+
#[no_mangle]
163169
unsafe extern "C" fn explain_df_scan(
164-
node: *mut CustomScanState,
165-
ancestors: *mut List,
170+
_node: *mut CustomScanState,
171+
_ancestors: *mut List,
166172
es: *mut ExplainState,
167173
) {
168-
todo!()
174+
let my_proc_number = unsafe { MyProcNumber };
175+
let stream = wait_stream();
176+
if let Err(err) = request_explain(my_slot(), stream) {
177+
error!("Failed to request explain: {}", err);
178+
}
179+
loop {
180+
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
181+
let Some(slot) = Bus::new().slot_locked(my_slot(), my_proc_number) else {
182+
continue;
183+
};
184+
let mut stream = SlotStream::from(slot);
185+
let header = consume_header(&mut stream).expect("Failed to consume header");
186+
if header.direction != Direction::ToBackend {
187+
continue;
188+
}
189+
match header.packet {
190+
Packet::None => {
191+
// No data, just continue waiting.
192+
continue;
193+
}
194+
Packet::Failure => {
195+
let msg = read_error(&mut stream).expect("Failed to read the error message");
196+
error!("Failed to compile the query: {}", msg);
197+
}
198+
Packet::Metadata | Packet::Bind | Packet::Parse => {
199+
error!("Unexpected packet for explain: {:?}", header.packet)
200+
}
201+
Packet::Explain => {
202+
let len = read_bin_len(&mut stream)
203+
.expect("Failed to read the length in explain message");
204+
let explain = stream
205+
.look_ahead(len as usize)
206+
.expect("Failed to read the explain message");
207+
unsafe {
208+
ExplainPropertyText(
209+
"DataFusion Plan\0".as_ptr() as _,
210+
explain.as_ptr() as _,
211+
es,
212+
);
213+
}
214+
}
215+
}
216+
}
169217
}
170218

171219
// We expect that the header is already consumed and the packet type is `Packet::Metadata`.

src/fsm.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rust_fsm::*;
22

33
#[derive(Debug)]
4-
pub enum ExecutorOutput {
4+
pub enum Action {
55
Bind,
66
Parse,
77
Compile,
88
Flush,
9+
Explain,
910
}
1011

1112
pub enum ExecutorState {
@@ -15,7 +16,7 @@ pub enum ExecutorState {
1516
}
1617

1718
state_machine! {
18-
#[state_machine(input(crate::protocol::Packet), state(crate::fsm::ExecutorState), output(crate::fsm::ExecutorOutput))]
19+
#[state_machine(input(crate::protocol::Packet), state(crate::fsm::ExecutorState), output(crate::fsm::Action))]
1920
pub executor(Initialized)
2021

2122
Initialized => {
@@ -28,5 +29,6 @@ state_machine! {
2829
LogicalPlan => {
2930
Failure => Initialized[Flush],
3031
Bind => LogicalPlan[Bind],
32+
Explain => Initialized[Explain],
3133
}
3234
}

src/protocol.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ pub enum Packet {
5656
Failure = 2,
5757
Metadata = 3,
5858
Parse = 4,
59+
Explain = 5,
5960
}
6061

6162
impl TryFrom<u8> for Packet {
@@ -69,6 +70,7 @@ impl TryFrom<u8> for Packet {
6970
2 => Ok(Packet::Failure),
7071
3 => Ok(Packet::Metadata),
7172
4 => Ok(Packet::Parse),
73+
5 => Ok(Packet::Explain),
7274
_ => Err(FusionError::Deserialize("packet".to_string(), value.into())),
7375
}
7476
}
@@ -474,6 +476,42 @@ pub(crate) fn consume_metadata(
474476
Ok(tables)
475477
}
476478

479+
// EXPLAIN
480+
481+
pub(crate) fn prepare_explain(stream: &mut SlotStream, explain: &str) -> Result<()> {
482+
stream.reset();
483+
let header = Header::default();
484+
write_header(stream, &header)?;
485+
let pos_init = stream.position();
486+
write_c_str(stream, explain)?;
487+
let pos_final = stream.position();
488+
let length = u16::try_from(pos_final - pos_init)?;
489+
let header = Header {
490+
direction: Direction::ToBackend,
491+
packet: Packet::Explain,
492+
length,
493+
flag: Flag::Last,
494+
};
495+
stream.reset();
496+
write_header(stream, &header)?;
497+
stream.rewind(length as usize)?;
498+
Ok(())
499+
}
500+
501+
pub(crate) fn request_explain(slot_id: SlotNumber, mut stream: SlotStream) -> Result<()> {
502+
let header = Header {
503+
direction: Direction::ToWorker,
504+
packet: Packet::Explain,
505+
length: 0,
506+
flag: Flag::Last,
507+
};
508+
write_header(&mut stream, &header)?;
509+
// Unlock the slot after writing the explain.
510+
let _guard = Slot::from(stream);
511+
signal(slot_id, Direction::ToWorker);
512+
Ok(())
513+
}
514+
477515
#[cfg(any(test, feature = "pg_test"))]
478516
#[pg_schema]
479517
mod tests {
@@ -732,4 +770,24 @@ mod tests {
732770
assert_eq!(t2.schema().field(0).data_type(), &DataType::Int32);
733771
assert!(t2.schema().field(0).is_nullable());
734772
}
773+
774+
#[pg_test]
775+
fn test_explain() {
776+
let mut slot_buf: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
777+
let mut stream: SlotStream = make_slot(&mut slot_buf).into();
778+
let orig_explain = r#"Projection: * [a:Int32;N, b:Utf8]
779+
Filter: foo.a = $1 [a:Int32;N, b:Utf8]
780+
TableScan: foo [a:Int32;N, b:Utf8]"#;
781+
prepare_explain(&mut stream, orig_explain).unwrap();
782+
stream.reset();
783+
let header = consume_header(&mut stream).unwrap();
784+
assert_eq!(header.direction, Direction::ToBackend);
785+
assert_eq!(header.packet, Packet::Explain);
786+
assert_eq!(header.flag, Flag::Last);
787+
let explain_len = read_bin_len(&mut stream).unwrap();
788+
assert_eq!(explain_len as usize, orig_explain.len() + 1);
789+
let explain = stream.look_ahead(explain_len as usize).unwrap();
790+
let expected = format!("{}\0", orig_explain);
791+
assert_eq!(explain, expected.as_bytes());
792+
}
735793
}

0 commit comments

Comments
 (0)