Skip to content

Commit 57c14e3

Browse files
committed
wip: process metadata request
1 parent 668d076 commit 57c14e3

File tree

6 files changed

+146
-15
lines changed

6 files changed

+146
-15
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pg17 = ["pgrx/pg17", "pgrx-tests/pg17"]
1616
pg_test = []
1717

1818
[dependencies]
19+
ahash = "0.8"
1920
anyhow = "1.0"
2021
datafusion = "44.0"
2122
datafusion-sql = "44.0"
@@ -24,6 +25,7 @@ libc = "0.2"
2425
pgrx = "0.12"
2526
rmp = "0.8"
2627
rust-fsm = { version = "0.7", features = ["diagram"] }
28+
smallvec = { version = "1.14", features = ["const_generics", "union"] }
2729
smol_str = "0.3"
2830
thiserror = "2.0"
2931
tokio = { version = "1.42", features = ["full"] }

src/backend.rs

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
use libc::c_long;
22
use pgrx::pg_sys::{
3-
error, palloc0, CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState,
4-
ExplainState, List, ListCell, MyLatch, Node, NodeTag, ParamListInfo, RegisterCustomScanMethods,
5-
ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION, WL_LATCH_SET, WL_POSTMASTER_DEATH,
6-
WL_TIMEOUT,
3+
error, fetch_search_path_array, get_namespace_oid, get_relname_relid, palloc0,
4+
CustomExecMethods, CustomScan, CustomScanMethods, CustomScanState, EState, ExplainState,
5+
InvalidOid, List, ListCell, MyLatch, Node, NodeTag, Oid, ParamListInfo,
6+
RegisterCustomScanMethods, ResetLatch, TupleTableSlot, WaitLatch, PG_WAIT_EXTENSION,
7+
WL_LATCH_SET, WL_POSTMASTER_DEATH, WL_TIMEOUT,
78
};
89
use pgrx::{check_for_interrupts, pg_guard};
10+
use rmp::decode::{read_array_len, read_bin_len};
11+
use smallvec::{smallvec, SmallVec};
912
use std::ffi::c_char;
1013
use std::time::Duration;
1114

@@ -145,6 +148,79 @@ unsafe extern "C" fn explain_df_scan(
145148
todo!()
146149
}
147150

151+
// We expect that the header is already consumed and the packet type is `Packet::Metadata`.
152+
#[pg_guard]
153+
fn table_oids(stream: &mut SlotStream) -> SmallVec<[Oid; 16]> {
154+
let table_num = read_array_len(stream).expect("Failed to read the number of tables");
155+
let mut oids: SmallVec<[Oid; 16]> = SmallVec::with_capacity(table_num as usize);
156+
for _ in 0..table_num {
157+
let elem_num = read_array_len(stream)
158+
.expect("Failed to read the number of elements in the table name");
159+
match elem_num {
160+
1 => {
161+
let table_len = read_bin_len(stream).expect("Failed to read the table name length");
162+
let table_name = stream
163+
.look_ahead(table_len as usize)
164+
.expect("Failed to read the table name");
165+
let mut search_path: [Oid; 16] = [InvalidOid; 16];
166+
let path_len = unsafe {
167+
fetch_search_path_array(search_path.as_mut_ptr(), search_path.len() as i32)
168+
};
169+
let path = &search_path[..path_len as usize];
170+
let mut rel_oid = InvalidOid;
171+
for ns_oid in path {
172+
rel_oid =
173+
unsafe { get_relname_relid(table_name.as_ptr() as *const c_char, *ns_oid) };
174+
if rel_oid != InvalidOid {
175+
oids.push(rel_oid);
176+
break;
177+
}
178+
}
179+
if rel_oid == InvalidOid {
180+
match std::str::from_utf8(table_name) {
181+
Ok(name) => error!("Table not found: {}", name),
182+
Err(_) => error!("Table not found: {:?}", table_name),
183+
}
184+
}
185+
stream
186+
.rewind(table_len as usize)
187+
.expect("Failed to rewind the stream");
188+
}
189+
2 => {
190+
let schema_len =
191+
read_bin_len(stream).expect("Failed to read the schema name length");
192+
let schema_name = stream
193+
.look_ahead(schema_len as usize)
194+
.expect("Failed to read the schema name");
195+
// Through an error if schema name not found.
196+
let ns_oid =
197+
unsafe { get_namespace_oid(schema_name.as_ptr() as *const c_char, false) };
198+
stream
199+
.rewind(schema_len as usize)
200+
.expect("Failed to rewind the stream");
201+
let table_len = read_bin_len(stream).expect("Failed to read the table name length");
202+
let table_name = stream
203+
.look_ahead(table_len as usize)
204+
.expect("Failed to read the table name");
205+
let rel_oid =
206+
unsafe { get_relname_relid(table_name.as_ptr() as *const c_char, ns_oid) };
207+
if rel_oid == InvalidOid {
208+
match std::str::from_utf8(table_name) {
209+
Ok(name) => error!("Table not found: {}", name),
210+
Err(_) => error!("Table not found: {:?}", table_name),
211+
}
212+
}
213+
stream
214+
.rewind(table_len as usize)
215+
.expect("Failed to rewind the stream");
216+
oids.push(rel_oid);
217+
}
218+
_ => error!("Table name should consist either of the name or the schema and the name"),
219+
}
220+
}
221+
oids
222+
}
223+
148224
fn wait_latch(timeout: Option<Duration>) {
149225
let timeout: c_long = timeout
150226
.map(|t| t.as_millis().try_into().unwrap())

src/data_type.rs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ fn type_to_oid(type_: &DataType) -> pg_sys::Oid {
9393
}
9494

9595
#[repr(u8)]
96-
enum EncodedType {
96+
pub(crate) enum EncodedType {
9797
Boolean = 0,
9898
Utf8 = 1,
9999
Int16 = 2,
@@ -123,9 +123,33 @@ impl TryFrom<u8> for EncodedType {
123123
8 => Ok(EncodedType::Time64),
124124
9 => Ok(EncodedType::Timestamp),
125125
10 => Ok(EncodedType::Interval),
126-
_ => Err(FusionError::DeserializeU8(
126+
_ => Err(FusionError::Deserialize(
127127
"encoded type".to_string(),
128-
value,
128+
value.into(),
129+
)),
130+
}
131+
}
132+
}
133+
134+
impl TryFrom<pg_sys::Oid> for EncodedType {
135+
type Error = FusionError;
136+
137+
fn try_from(value: pg_sys::Oid) -> Result<Self, Self::Error> {
138+
match value {
139+
pg_sys::BOOLOID => Ok(EncodedType::Boolean),
140+
pg_sys::TEXTOID => Ok(EncodedType::Utf8),
141+
pg_sys::INT2OID => Ok(EncodedType::Int16),
142+
pg_sys::INT4OID => Ok(EncodedType::Int32),
143+
pg_sys::INT8OID => Ok(EncodedType::Int64),
144+
pg_sys::FLOAT4OID => Ok(EncodedType::Float32),
145+
pg_sys::FLOAT8OID => Ok(EncodedType::Float64),
146+
pg_sys::DATEOID => Ok(EncodedType::Date32),
147+
pg_sys::TIMEOID => Ok(EncodedType::Time64),
148+
pg_sys::TIMESTAMPOID => Ok(EncodedType::Timestamp),
149+
pg_sys::INTERVALOID => Ok(EncodedType::Interval),
150+
_ => Err(FusionError::Deserialize(
151+
"encoded type".to_string(),
152+
value.as_u32().into(),
129153
)),
130154
}
131155
}

src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ pub enum FusionError {
1010
#[error("Payload is too large: {0} bytes")]
1111
PayloadTooLarge(usize),
1212
#[error("Failed to deserialize {0}: {1}")]
13-
DeserializeU8(String, u8),
13+
Deserialize(String, u64),
1414
}

src/protocol.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
use crate::data_type::{datum_to_scalar, read_scalar_value, write_scalar_value};
1+
use crate::data_type::{datum_to_scalar, read_scalar_value, write_scalar_value, EncodedType};
22
use crate::error::FusionError;
33
use crate::ipc::{Bus, Slot, SlotNumber, SlotStream, DATA_SIZE};
44
use crate::worker::worker_id;
55
use anyhow::Result;
66
use datafusion::scalar::ScalarValue;
77
use datafusion_sql::TableReference;
8-
use pgrx::pg_sys::{ParamExternData, ProcSendSignal};
8+
use pgrx::pg_sys::{Oid, ParamExternData, ProcSendSignal};
99
use pgrx::prelude::*;
1010
use rmp::decode::{read_array_len, read_bin_len, read_pfix, read_str_len, read_u16};
11-
use rmp::encode::{write_array_len, write_bin_len, write_pfix, write_str, write_u16, RmpWrite};
11+
use rmp::encode::{
12+
write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, RmpWrite,
13+
};
1214

1315
#[repr(u8)]
1416
#[derive(Clone, Debug, Default, PartialEq)]
@@ -26,7 +28,10 @@ impl TryFrom<u8> for Direction {
2628
match value {
2729
0 => Ok(Direction::ToWorker),
2830
1 => Ok(Direction::ToBackend),
29-
_ => Err(FusionError::DeserializeU8("direction".to_string(), value)),
31+
_ => Err(FusionError::Deserialize(
32+
"direction".to_string(),
33+
value.into(),
34+
)),
3035
}
3136
}
3237
}
@@ -53,7 +58,7 @@ impl TryFrom<u8> for Packet {
5358
2 => Ok(Packet::Failure),
5459
3 => Ok(Packet::Metadata),
5560
4 => Ok(Packet::Parse),
56-
_ => Err(FusionError::DeserializeU8("packet".to_string(), value)),
61+
_ => Err(FusionError::Deserialize("packet".to_string(), value.into())),
5762
}
5863
}
5964
}
@@ -74,7 +79,7 @@ impl TryFrom<u8> for Flag {
7479
match value {
7580
0 => Ok(Flag::More),
7681
1 => Ok(Flag::Last),
77-
_ => Err(FusionError::DeserializeU8("flag".to_string(), value)),
82+
_ => Err(FusionError::Deserialize("flag".to_string(), value.into())),
7883
}
7984
}
8085
}
@@ -110,6 +115,8 @@ fn signal(slot_id: SlotNumber, direction: Direction) {
110115
}
111116
}
112117

118+
// HEADER
119+
113120
pub(crate) fn consume_header(stream: &mut SlotStream) -> Result<Header> {
114121
assert_eq!(stream.position(), 0);
115122
let direction = Direction::try_from(read_pfix(stream)?)?;
@@ -132,6 +139,8 @@ pub(crate) fn write_header(stream: &mut SlotStream, header: &Header) -> Result<(
132139
Ok(())
133140
}
134141

142+
// PARSE
143+
135144
/// Reads the query from the stream, but leaves the stream position at the beginning of the query.
136145
/// It is required to return the reference to the query bytes without copying them. It is the
137146
/// caller's responsibility to move the stream position to the end of the query.
@@ -170,6 +179,8 @@ pub(crate) fn send_query(slot_id: SlotNumber, mut stream: SlotStream, query: &st
170179
Ok(())
171180
}
172181

182+
// BIND
183+
173184
fn prepare_params(stream: &mut SlotStream, params: &[ParamExternData]) -> Result<()> {
174185
stream.reset();
175186
// We don't know the length of the parameters yet. So we write an invalid header
@@ -217,6 +228,8 @@ pub(crate) fn send_params(
217228
Ok(())
218229
}
219230

231+
// FAILURE
232+
220233
pub(crate) fn read_error(stream: &mut SlotStream) -> Result<String> {
221234
let len = read_str_len(stream)?;
222235
let buf = stream.look_ahead(len as usize)?;
@@ -255,6 +268,8 @@ fn write_c_str(stream: &mut SlotStream, s: &str) -> Result<()> {
255268
Ok(())
256269
}
257270

271+
// METADATA
272+
258273
/// Writes a table reference as null-terminated strings to
259274
/// the stream. It would be used by the Rust wrappers to the
260275
/// C code, so if we serialize the table and schema as
@@ -312,6 +327,19 @@ pub(crate) fn send_table_refs(
312327
Ok(())
313328
}
314329

330+
#[inline]
331+
pub(crate) fn write_column(
332+
stream: &mut SlotStream,
333+
column: &str,
334+
is_null: bool,
335+
etype: EncodedType,
336+
) -> Result<()> {
337+
write_str(stream, column)?;
338+
write_bool(stream, is_null)?;
339+
write_pfix(stream, etype as u8)?;
340+
Ok(())
341+
}
342+
315343
#[cfg(any(test, feature = "pg_test"))]
316344
#[pg_schema]
317345
mod tests {

src/sql.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::cell::OnceCell;
22
use std::collections::HashMap;
33
use std::sync::Arc;
44

5+
use ahash::AHashMap;
56
use datafusion::arrow::datatypes::DataType;
67
use datafusion::arrow::datatypes::SchemaRef;
78
use datafusion::common::DFSchemaRef;
@@ -93,7 +94,7 @@ impl Builtin {
9394

9495
struct Catalog {
9596
builtin: Arc<Builtin>,
96-
tables: HashMap<SmolStr, Arc<dyn TableSource>>,
97+
tables: AHashMap<SmolStr, Arc<dyn TableSource>>,
9798
}
9899

99100
impl ContextProvider for Catalog {

0 commit comments

Comments
 (0)