Skip to content

Commit 092e9cd

Browse files
committed
feat: process metadata request (get table oids)
1 parent 668d076 commit 092e9cd

File tree

6 files changed

+200
-16
lines changed

6 files changed

+200
-16
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pg17 = ["pgrx/pg17", "pgrx-tests/pg17"]
1616
pg_test = []
1717

1818
[dependencies]
19+
ahash = "0.8"
1920
anyhow = "1.0"
2021
datafusion = "44.0"
2122
datafusion-sql = "44.0"
@@ -24,6 +25,7 @@ libc = "0.2"
2425
pgrx = "0.12"
2526
rmp = "0.8"
2627
rust-fsm = { version = "0.7", features = ["diagram"] }
28+
smallvec = { version = "1.14", features = ["const_generics", "union"] }
2729
smol_str = "0.3"
2830
thiserror = "2.0"
2931
tokio = { version = "1.42", features = ["full"] }

src/backend.rs

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1+
use anyhow::Result as AnyResult;
12
use libc::c_long;
23
use pgrx::pg_sys::{
3-
error, palloc0, CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState,
4-
ExplainState, List, ListCell, MyLatch, Node, NodeTag, ParamListInfo, RegisterCustomScanMethods,
5-
ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION, WL_LATCH_SET, WL_POSTMASTER_DEATH,
6-
WL_TIMEOUT,
4+
error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
5+
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainState,
6+
InvalidOid, List, ListCell, MyLatch, Node, NodeTag, Oid, ParamListInfo,
7+
RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION,
8+
WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
79
};
810
use pgrx::{check_for_interrupts, pg_guard};
11+
use rmp::decode::{read_array_len, read_bin_len};
12+
use smallvec::{smallvec, SmallVec};
913
use std::ffi::c_char;
1014
use std::time::Duration;
1115

16+
use crate::error::FusionError;
1217
use crate::ipc::{my_slot, Bus, SlotHandler, SlotNumber, SlotStream, CURRENT_SLOT};
1318
use crate::protocol::{consume_header, read_error, send_params, send_query, Direction, Packet};
1419

@@ -145,6 +150,74 @@ unsafe extern "C" fn explain_df_scan(
145150
todo!()
146151
}
147152

153+
// We expect that the header is already consumed and the packet type is `Packet::Metadata`.
154+
fn table_oids(stream: &mut SlotStream) -> AnyResult<SmallVec<[Oid; 16]>> {
155+
let table_not_found = |c_table_name: &[u8]| -> Result<(), FusionError> {
156+
assert!(!c_table_name.is_empty());
157+
let table_name = c_table_name[..c_table_name.len() - 1].as_ref();
158+
match std::str::from_utf8(table_name) {
159+
Ok(name) => Err(FusionError::NotFound("Table".into(), name.into())),
160+
Err(_) => Err(FusionError::NotFound(
161+
"Table".into(),
162+
format!("{:?}", table_name),
163+
)),
164+
}
165+
};
166+
let table_num = read_array_len(stream)?;
167+
let mut oids: SmallVec<[Oid; 16]> = SmallVec::with_capacity(table_num as usize);
168+
for _ in 0..table_num {
169+
let elem_num = read_array_len(stream)?;
170+
match elem_num {
171+
1 => {
172+
let table_len = read_bin_len(stream)?;
173+
let table_name = stream.look_ahead(table_len as usize)?;
174+
let mut search_path: [Oid; 16] = [InvalidOid; 16];
175+
let path_len = unsafe {
176+
fetch_search_path_array(search_path.as_mut_ptr(), search_path.len() as i32)
177+
};
178+
let path = &search_path[..path_len as usize];
179+
let mut rel_oid = InvalidOid;
180+
for ns_oid in path {
181+
rel_oid =
182+
unsafe { get_relname_relid(table_name.as_ptr() as *const c_char, *ns_oid) };
183+
if rel_oid != InvalidOid {
184+
oids.push(rel_oid);
185+
break;
186+
}
187+
}
188+
if rel_oid == InvalidOid {
189+
table_not_found(table_name)?;
190+
}
191+
stream.rewind(table_len as usize)?;
192+
}
193+
2 => {
194+
let schema_len = read_bin_len(stream)?;
195+
let schema_name = stream.look_ahead(schema_len as usize)?;
196+
// Through an error if schema name not found.
197+
let ns_oid =
198+
unsafe { get_namespace_oid(schema_name.as_ptr() as *const c_char, false) };
199+
stream.rewind(schema_len as usize)?;
200+
let table_len = read_bin_len(stream)?;
201+
let table_name = stream.look_ahead(table_len as usize)?;
202+
let rel_oid =
203+
unsafe { get_relname_relid(table_name.as_ptr() as *const c_char, ns_oid) };
204+
if rel_oid == InvalidOid {
205+
table_not_found(table_name)?;
206+
}
207+
stream.rewind(table_len as usize)?;
208+
oids.push(rel_oid);
209+
}
210+
_ => {
211+
return Err(FusionError::InvalidName(
212+
"Table".into(),
213+
"support only 'schema.table' format".into(),
214+
))?
215+
}
216+
}
217+
}
218+
Ok(oids)
219+
}
220+
148221
fn wait_latch(timeout: Option<Duration>) {
149222
let timeout: c_long = timeout
150223
.map(|t| t.as_millis().try_into().unwrap())
@@ -215,9 +288,17 @@ fn list_nth(list: *mut List, n: i32) -> *mut ListCell {
215288
#[cfg(any(test, feature = "pg_test"))]
216289
#[pgrx::pg_schema]
217290
mod tests {
291+
use crate::ipc::Slot;
292+
use crate::protocol::prepare_table_refs;
293+
218294
use super::*;
295+
use datafusion_sql::TableReference;
219296
use pgrx::prelude::*;
297+
use pgrx::spi::Spi;
220298
use std::ffi::c_void;
299+
use std::ptr::addr_of_mut;
300+
301+
const SLOT_SIZE: usize = 8204;
221302

222303
#[pg_test]
223304
fn test_node() {
@@ -231,4 +312,45 @@ mod tests {
231312
pg_sys::pfree(ptr);
232313
}
233314
}
315+
316+
#[pg_test]
317+
fn test_table_oids() {
318+
Spi::run("create table if not exists t1(a int, b text);").unwrap();
319+
Spi::run("create schema if not exists s1;").unwrap();
320+
Spi::run("create table if not exists s1.t2(a int, b text);").unwrap();
321+
let t1_oid = Spi::get_one::<Oid>("select 't1'::regclass::oid;")
322+
.unwrap()
323+
.unwrap();
324+
let t2_oid = Spi::get_one::<Oid>("select 's1.t2'::regclass::oid;")
325+
.unwrap()
326+
.unwrap();
327+
328+
let mut slot_buf: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
329+
let ptr = addr_of_mut!(slot_buf) as *mut u8;
330+
Slot::init(ptr, slot_buf.len());
331+
let slot = Slot::from_bytes(ptr, slot_buf.len());
332+
let mut stream: SlotStream = slot.into();
333+
334+
// Check valid tables.
335+
let t1 = TableReference::bare("t1");
336+
let t2 = TableReference::partial("s1", "t2");
337+
let tables = vec![&t1, &t2];
338+
prepare_table_refs(&mut stream, &tables).unwrap();
339+
stream.reset();
340+
let _ = consume_header(&mut stream).unwrap();
341+
let oids = table_oids(&mut stream).unwrap();
342+
assert_eq!(oids.len(), 2);
343+
assert_eq!(oids[0], t1_oid);
344+
assert_eq!(oids[1], t2_oid);
345+
stream.reset();
346+
347+
// Check invalid table.
348+
let t3 = TableReference::bare("t3");
349+
let tables = vec![&t3];
350+
prepare_table_refs(&mut stream, &tables).unwrap();
351+
stream.reset();
352+
let _ = consume_header(&mut stream).unwrap();
353+
let err = table_oids(&mut stream).unwrap_err();
354+
assert_eq!(err.to_string(), "Table not found: t3");
355+
}
234356
}

src/data_type.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ fn type_to_oid(type_: &DataType) -> pg_sys::Oid {
9393
}
9494

9595
#[repr(u8)]
96-
enum EncodedType {
96+
pub(crate) enum EncodedType {
9797
Boolean = 0,
9898
Utf8 = 1,
9999
Int16 = 2,
@@ -123,9 +123,33 @@ impl TryFrom<u8> for EncodedType {
123123
8 => Ok(EncodedType::Time64),
124124
9 => Ok(EncodedType::Timestamp),
125125
10 => Ok(EncodedType::Interval),
126-
_ => Err(FusionError::DeserializeU8(
126+
_ => Err(FusionError::Deserialize(
127127
"encoded type".to_string(),
128-
value,
128+
value.into(),
129+
)),
130+
}
131+
}
132+
}
133+
134+
impl TryFrom<pg_sys::Oid> for EncodedType {
135+
type Error = FusionError;
136+
137+
fn try_from(value: pg_sys::Oid) -> Result<Self, Self::Error> {
138+
match value {
139+
pg_sys::BOOLOID => Ok(EncodedType::Boolean),
140+
pg_sys::TEXTOID => Ok(EncodedType::Utf8),
141+
pg_sys::INT2OID => Ok(EncodedType::Int16),
142+
pg_sys::INT4OID => Ok(EncodedType::Int32),
143+
pg_sys::INT8OID => Ok(EncodedType::Int64),
144+
pg_sys::FLOAT4OID => Ok(EncodedType::Float32),
145+
pg_sys::FLOAT8OID => Ok(EncodedType::Float64),
146+
pg_sys::DATEOID => Ok(EncodedType::Date32),
147+
pg_sys::TIMEOID => Ok(EncodedType::Time64),
148+
pg_sys::TIMESTAMPOID => Ok(EncodedType::Timestamp),
149+
pg_sys::INTERVALOID => Ok(EncodedType::Interval),
150+
_ => Err(FusionError::Deserialize(
151+
"encoded type".to_string(),
152+
value.as_u32().into(),
129153
)),
130154
}
131155
}

src/error.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@ pub enum FusionError {
1010
#[error("Payload is too large: {0} bytes")]
1111
PayloadTooLarge(usize),
1212
#[error("Failed to deserialize {0}: {1}")]
13-
DeserializeU8(String, u8),
13+
Deserialize(String, u64),
14+
#[error("{0} not found: {1}")]
15+
NotFound(String, String),
16+
#[error("{0} name is not valid: {1}")]
17+
InvalidName(String, String),
1418
}

src/protocol.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
use crate::data_type::{datum_to_scalar, read_scalar_value, write_scalar_value};
1+
use crate::data_type::{datum_to_scalar, read_scalar_value, write_scalar_value, EncodedType};
22
use crate::error::FusionError;
33
use crate::ipc::{Bus, Slot, SlotNumber, SlotStream, DATA_SIZE};
44
use crate::worker::worker_id;
55
use anyhow::Result;
66
use datafusion::scalar::ScalarValue;
77
use datafusion_sql::TableReference;
8-
use pgrx::pg_sys::{ParamExternData, ProcSendSignal};
8+
use pgrx::pg_sys::{Oid, ParamExternData, ProcSendSignal};
99
use pgrx::prelude::*;
1010
use rmp::decode::{read_array_len, read_bin_len, read_pfix, read_str_len, read_u16};
11-
use rmp::encode::{write_array_len, write_bin_len, write_pfix, write_str, write_u16, RmpWrite};
11+
use rmp::encode::{
12+
write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, RmpWrite,
13+
};
1214

1315
#[repr(u8)]
1416
#[derive(Clone, Debug, Default, PartialEq)]
@@ -26,7 +28,10 @@ impl TryFrom<u8> for Direction {
2628
match value {
2729
0 => Ok(Direction::ToWorker),
2830
1 => Ok(Direction::ToBackend),
29-
_ => Err(FusionError::DeserializeU8("direction".to_string(), value)),
31+
_ => Err(FusionError::Deserialize(
32+
"direction".to_string(),
33+
value.into(),
34+
)),
3035
}
3136
}
3237
}
@@ -53,7 +58,7 @@ impl TryFrom<u8> for Packet {
5358
2 => Ok(Packet::Failure),
5459
3 => Ok(Packet::Metadata),
5560
4 => Ok(Packet::Parse),
56-
_ => Err(FusionError::DeserializeU8("packet".to_string(), value)),
61+
_ => Err(FusionError::Deserialize("packet".to_string(), value.into())),
5762
}
5863
}
5964
}
@@ -74,7 +79,7 @@ impl TryFrom<u8> for Flag {
7479
match value {
7580
0 => Ok(Flag::More),
7681
1 => Ok(Flag::Last),
77-
_ => Err(FusionError::DeserializeU8("flag".to_string(), value)),
82+
_ => Err(FusionError::Deserialize("flag".to_string(), value.into())),
7883
}
7984
}
8085
}
@@ -110,6 +115,8 @@ fn signal(slot_id: SlotNumber, direction: Direction) {
110115
}
111116
}
112117

118+
// HEADER
119+
113120
pub(crate) fn consume_header(stream: &mut SlotStream) -> Result<Header> {
114121
assert_eq!(stream.position(), 0);
115122
let direction = Direction::try_from(read_pfix(stream)?)?;
@@ -132,6 +139,8 @@ pub(crate) fn write_header(stream: &mut SlotStream, header: &Header) -> Result<(
132139
Ok(())
133140
}
134141

142+
// PARSE
143+
135144
/// Reads the query from the stream, but leaves the stream position at the beginning of the query.
136145
/// It is required to return the reference to the query bytes without copying them. It is the
137146
/// caller's responsibility to move the stream position to the end of the query.
@@ -170,6 +179,8 @@ pub(crate) fn send_query(slot_id: SlotNumber, mut stream: SlotStream, query: &st
170179
Ok(())
171180
}
172181

182+
// BIND
183+
173184
fn prepare_params(stream: &mut SlotStream, params: &[ParamExternData]) -> Result<()> {
174185
stream.reset();
175186
// We don't know the length of the parameters yet. So we write an invalid header
@@ -217,6 +228,8 @@ pub(crate) fn send_params(
217228
Ok(())
218229
}
219230

231+
// FAILURE
232+
220233
pub(crate) fn read_error(stream: &mut SlotStream) -> Result<String> {
221234
let len = read_str_len(stream)?;
222235
let buf = stream.look_ahead(len as usize)?;
@@ -255,6 +268,8 @@ fn write_c_str(stream: &mut SlotStream, s: &str) -> Result<()> {
255268
Ok(())
256269
}
257270

271+
// METADATA
272+
258273
/// Writes a table reference as null-terminated strings to
259274
/// the stream. It would be used by the Rust wrappers to the
260275
/// C code, so if we serialize the table and schema as
@@ -276,7 +291,10 @@ fn write_table_ref(stream: &mut SlotStream, table: &TableReference) -> Result<()
276291
Ok(())
277292
}
278293

279-
fn prepare_table_refs(stream: &mut SlotStream, tables: &[&TableReference]) -> Result<()> {
294+
pub(crate) fn prepare_table_refs(
295+
stream: &mut SlotStream,
296+
tables: &[&TableReference],
297+
) -> Result<()> {
280298
stream.reset();
281299
// We don't know the length of the tables yet. So we write an invalid header
282300
// to replace it with the correct one later.
@@ -312,6 +330,19 @@ pub(crate) fn send_table_refs(
312330
Ok(())
313331
}
314332

333+
#[inline]
334+
pub(crate) fn write_column(
335+
stream: &mut SlotStream,
336+
column: &str,
337+
is_null: bool,
338+
etype: EncodedType,
339+
) -> Result<()> {
340+
write_str(stream, column)?;
341+
write_bool(stream, is_null)?;
342+
write_pfix(stream, etype as u8)?;
343+
Ok(())
344+
}
345+
315346
#[cfg(any(test, feature = "pg_test"))]
316347
#[pg_schema]
317348
mod tests {

src/sql.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::cell::OnceCell;
22
use std::collections::HashMap;
33
use std::sync::Arc;
44

5+
use ahash::AHashMap;
56
use datafusion::arrow::datatypes::DataType;
67
use datafusion::arrow::datatypes::SchemaRef;
78
use datafusion::common::DFSchemaRef;
@@ -93,7 +94,7 @@ impl Builtin {
9394

9495
struct Catalog {
9596
builtin: Arc<Builtin>,
96-
tables: HashMap<SmolStr, Arc<dyn TableSource>>,
97+
tables: AHashMap<SmolStr, Arc<dyn TableSource>>,
9798
}
9899

99100
impl ContextProvider for Catalog {

0 commit comments

Comments
 (0)