Skip to content

Commit 5995dbd

Browse files
committed
feat: serialize metadata response
1 parent 092e9cd commit 5995dbd

File tree

2 files changed

+118
-14
lines changed

2 files changed

+118
-14
lines changed

src/backend.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
use anyhow::Result as AnyResult;
22
use libc::c_long;
33
use pgrx::pg_sys::{
4-
error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
4+
self, error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
55
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainState,
66
InvalidOid, List, ListCell, MyLatch, Node, NodeTag, Oid, ParamListInfo,
77
RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION,
88
WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
99
};
1010
use pgrx::{check_for_interrupts, pg_guard};
1111
use rmp::decode::{read_array_len, read_bin_len};
12-
use smallvec::{smallvec, SmallVec};
12+
use smallvec::SmallVec;
1313
use std::ffi::c_char;
1414
use std::time::Duration;
1515

1616
use crate::error::FusionError;
17-
use crate::ipc::{my_slot, Bus, SlotHandler, SlotNumber, SlotStream, CURRENT_SLOT};
17+
use crate::ipc::{my_slot, Bus, SlotStream};
1818
use crate::protocol::{consume_header, read_error, send_params, send_query, Direction, Packet};
1919

2020
thread_local! {

src/protocol.rs

Lines changed: 115 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ use datafusion::scalar::ScalarValue;
77
use datafusion_sql::TableReference;
88
use pgrx::pg_sys::{Oid, ParamExternData, ProcSendSignal};
99
use pgrx::prelude::*;
10-
use rmp::decode::{read_array_len, read_bin_len, read_pfix, read_str_len, read_u16};
10+
use pgrx::{pg_guard, PgRelation};
11+
use rmp::decode::{read_array_len, read_bin_len, read_pfix, read_str_len, read_u16, read_u8};
1112
use rmp::encode::{
12-
write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, RmpWrite,
13+
write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, write_u32,
14+
write_u8, RmpWrite,
1315
};
1416

1517
#[repr(u8)]
@@ -331,23 +333,71 @@ pub(crate) fn send_table_refs(
331333
}
332334

333335
#[inline]
334-
pub(crate) fn write_column(
335-
stream: &mut SlotStream,
336-
column: &str,
337-
is_null: bool,
338-
etype: EncodedType,
336+
#[pg_guard]
337+
fn serialize_table(rel_oid: Oid, stream: &mut SlotStream) -> Result<()> {
338+
// The destructor will release the lock.
339+
let rel = unsafe { PgRelation::with_lock(rel_oid, pg_sys::AccessShareLock as i32) };
340+
let tuple_desc = rel.tuple_desc();
341+
let attr_num = u32::try_from(tuple_desc.iter().filter(|a| !a.is_dropped()).count())?;
342+
write_u32(stream, rel_oid.as_u32())?;
343+
write_array_len(stream, attr_num)?;
344+
for attr in tuple_desc.iter() {
345+
if attr.is_dropped() {
346+
continue;
347+
}
348+
let etype = EncodedType::try_from(attr.type_oid().value())?;
349+
let is_nullable = !attr.attnotnull;
350+
let name = attr.name();
351+
write_array_len(stream, 3)?;
352+
write_str(stream, name)?;
353+
write_u8(stream, etype as u8)?;
354+
write_bool(stream, is_nullable)?;
355+
}
356+
Ok(())
357+
}
358+
359+
pub(crate) fn prepare_metadata(rel_oids: &[Oid], stream: &mut SlotStream) -> Result<()> {
360+
stream.reset();
361+
// We don't know the length of the table metadata yet. So we write
362+
// an invalid header to replace it with the correct one later.
363+
write_header(stream, &Header::default())?;
364+
let pos_init = stream.position();
365+
write_array_len(stream, rel_oids.len() as u32)?;
366+
for &rel_oid in rel_oids {
367+
serialize_table(rel_oid, stream)?;
368+
}
369+
let pos_final = stream.position();
370+
let length = u16::try_from(pos_final - pos_init)?;
371+
let header = Header {
372+
direction: Direction::ToWorker,
373+
packet: Packet::Metadata,
374+
length,
375+
flag: Flag::Last,
376+
};
377+
stream.reset();
378+
write_header(stream, &header)?;
379+
stream.rewind(length as usize)?;
380+
Ok(())
381+
}
382+
383+
pub(crate) fn send_metadata(
384+
slot_id: SlotNumber,
385+
mut stream: SlotStream,
386+
rel_oids: &[Oid],
339387
) -> Result<()> {
340-
write_str(stream, column)?;
341-
write_bool(stream, is_null)?;
342-
write_pfix(stream, etype as u8)?;
388+
prepare_metadata(rel_oids, &mut stream)?;
389+
// Unlock the slot after writing the metadata.
390+
let _guard = Slot::from(stream);
391+
signal(slot_id, Direction::ToWorker);
343392
Ok(())
344393
}
345394

346395
#[cfg(any(test, feature = "pg_test"))]
347396
#[pg_schema]
348397
mod tests {
349398
use super::*;
350-
use pgrx::pg_sys::Datum;
399+
use pgrx::pg_sys::{Datum, Oid};
400+
use rmp::decode::{read_bool, read_u32};
351401
use std::ptr::addr_of_mut;
352402

353403
const SLOT_SIZE: usize = 8204;
@@ -482,4 +532,58 @@ mod tests {
482532
let t2 = stream.look_ahead(t2_len as usize).unwrap();
483533
assert_eq!(t2, b"table2\0");
484534
}
535+
536+
#[pg_test]
537+
fn test_metadata_response() {
538+
Spi::run("create table if not exists t1(a int not null, b text);").unwrap();
539+
let t1_oid = Spi::get_one::<Oid>("select 't1'::regclass::oid;")
540+
.unwrap()
541+
.unwrap();
542+
543+
let mut slot_buf: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
544+
let ptr = addr_of_mut!(slot_buf) as *mut u8;
545+
Slot::init(ptr, slot_buf.len());
546+
let slot = Slot::from_bytes(ptr, slot_buf.len());
547+
let mut stream: SlotStream = slot.into();
548+
549+
prepare_metadata(&[t1_oid], &mut stream).unwrap();
550+
stream.reset();
551+
let header = consume_header(&mut stream).unwrap();
552+
assert_eq!(header.direction, Direction::ToWorker);
553+
assert_eq!(header.packet, Packet::Metadata);
554+
assert_eq!(header.flag, Flag::Last);
555+
556+
// Check table metadata deserialization
557+
let table_num = read_array_len(&mut stream).unwrap();
558+
assert_eq!(table_num, 1);
559+
// t1
560+
let oid = read_u32(&mut stream).unwrap();
561+
assert_eq!(oid, t1_oid.as_u32());
562+
let attr_num = read_array_len(&mut stream).unwrap();
563+
assert_eq!(attr_num, 2);
564+
// a
565+
let elem_num = read_array_len(&mut stream).unwrap();
566+
assert_eq!(elem_num, 3);
567+
let name_len = read_str_len(&mut stream).unwrap();
568+
assert_eq!(name_len, "a".len() as u32);
569+
let name = stream.look_ahead(name_len as usize).unwrap();
570+
assert_eq!(name, b"a");
571+
stream.rewind(name_len as usize).unwrap();
572+
let etype = read_u8(&mut stream).unwrap();
573+
assert_eq!(etype, EncodedType::Int32 as u8);
574+
let is_nullable = read_bool(&mut stream).unwrap();
575+
assert!(!is_nullable);
576+
// b
577+
let elem_num = read_array_len(&mut stream).unwrap();
578+
assert_eq!(elem_num, 3);
579+
let name_len = read_str_len(&mut stream).unwrap();
580+
assert_eq!(name_len, "b".len() as u32);
581+
let name = stream.look_ahead(name_len as usize).unwrap();
582+
assert_eq!(name, b"b");
583+
stream.rewind(name_len as usize).unwrap();
584+
let etype = read_u8(&mut stream).unwrap();
585+
assert_eq!(etype, EncodedType::Utf8 as u8);
586+
let is_nullable = read_bool(&mut stream).unwrap();
587+
assert!(is_nullable);
588+
}
485589
}

0 commit comments

Comments
 (0)