Skip to content

Commit bcf1cac

Browse files
committed
feat: handle metadata and bind packets
1 parent b91b5d4 commit bcf1cac

File tree

6 files changed

+199
-81
lines changed

6 files changed

+199
-81
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ datafusion = "44.0"
2222
datafusion-sql = "44.0"
2323
fasthash = "0.4"
2424
libc = "0.2"
25+
once_cell = "1.21"
2526
pgrx = "0.12"
2627
rmp = "0.8"
2728
rust-fsm = { version = "0.7", features = ["diagram"] }

src/backend.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,16 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
113113
error!("Failed to send the table metadata: {}", err);
114114
}
115115
}
116-
Packet::Ack => break,
117-
_ => error!("Unexpected packet in backend: {:?}", header.packet),
116+
Packet::Bind => break,
117+
Packet::Parse => error!("Unexpected packet in backend: {:?}", header.packet),
118118
}
119119
}
120120
let param_list = (*list_nth(list, 1)).ptr_value as ParamListInfo;
121121
let num_params = unsafe { (*param_list).numParams } as usize;
122-
if num_params > 0 {
123-
let params = unsafe { (*param_list).params.as_slice(num_params) };
124-
let stream = wait_stream();
125-
if let Err(err) = send_params(my_slot(), stream, params) {
126-
error!("Failed to send the parameter list: {}", err);
127-
}
122+
let params = unsafe { (*param_list).params.as_slice(num_params) };
123+
let stream = wait_stream();
124+
if let Err(err) = send_params(my_slot(), stream, params) {
125+
error!("Failed to send the parameter list: {}", err);
128126
}
129127
let css = CustomScanState {
130128
methods: exec_methods(),

src/data_type.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::error::FusionError;
22
use crate::ipc::SlotStream;
33
use anyhow::{bail, Result};
4-
use datafusion::arrow::datatypes::DataType;
4+
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
55
use datafusion::common::arrow::array::types::IntervalMonthDayNano;
66
use datafusion::common::arrow::datatypes::Field;
77
use datafusion::common::ScalarValue;
@@ -107,6 +107,24 @@ pub(crate) enum EncodedType {
107107
Interval = 10,
108108
}
109109

110+
impl EncodedType {
111+
pub(crate) fn to_arrow(&self) -> DataType {
112+
match self {
113+
EncodedType::Boolean => DataType::Boolean,
114+
EncodedType::Utf8 => DataType::Utf8,
115+
EncodedType::Int16 => DataType::Int16,
116+
EncodedType::Int32 => DataType::Int32,
117+
EncodedType::Int64 => DataType::Int64,
118+
EncodedType::Float32 => DataType::Float32,
119+
EncodedType::Float64 => DataType::Float64,
120+
EncodedType::Date32 => DataType::Date32,
121+
EncodedType::Time64 => DataType::Time64(TimeUnit::Microsecond),
122+
EncodedType::Timestamp => DataType::Timestamp(TimeUnit::Microsecond, None),
123+
EncodedType::Interval => DataType::Interval(IntervalUnit::MonthDayNano),
124+
}
125+
}
126+
}
127+
110128
impl TryFrom<u8> for EncodedType {
111129
type Error = FusionError;
112130

src/protocol.rs

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ impl TryFrom<u8> for Direction {
4242
#[derive(Clone, Debug, Default, PartialEq)]
4343
pub enum Packet {
4444
#[default]
45-
Ack = 0,
46-
Bind = 1,
47-
Failure = 2,
48-
Metadata = 3,
49-
Parse = 4,
45+
Bind = 0,
46+
Failure = 1,
47+
Metadata = 2,
48+
Parse = 3,
5049
}
5150

5251
impl TryFrom<u8> for Packet {
@@ -55,11 +54,10 @@ impl TryFrom<u8> for Packet {
5554
fn try_from(value: u8) -> Result<Self, Self::Error> {
5655
assert!(value < 128);
5756
match value {
58-
0 => Ok(Packet::Ack),
59-
1 => Ok(Packet::Bind),
60-
2 => Ok(Packet::Failure),
61-
3 => Ok(Packet::Metadata),
62-
4 => Ok(Packet::Parse),
57+
0 => Ok(Packet::Bind),
58+
1 => Ok(Packet::Failure),
59+
2 => Ok(Packet::Metadata),
60+
3 => Ok(Packet::Parse),
6361
_ => Err(FusionError::Deserialize("packet".to_string(), value.into())),
6462
}
6563
}
@@ -230,6 +228,21 @@ pub(crate) fn send_params(
230228
Ok(())
231229
}
232230

231+
pub(crate) fn request_params(slot_id: SlotNumber, mut stream: SlotStream) -> Result<()> {
232+
stream.reset();
233+
let header = Header {
234+
direction: Direction::ToBackend,
235+
packet: Packet::Bind,
236+
length: 0,
237+
flag: Flag::Last,
238+
};
239+
write_header(&mut stream, &header)?;
240+
// Unlock the slot after writing the header; this packet carries no payload —
// it only asks the backend to send the parameters.
241+
let _guard = Slot::from(stream);
242+
signal(slot_id, Direction::ToBackend);
243+
Ok(())
244+
}
245+
233246
// FAILURE
234247

235248
pub(crate) fn read_error(stream: &mut SlotStream) -> Result<String> {
@@ -334,21 +347,23 @@ pub(crate) fn send_table_refs(
334347
fn serialize_table(rel_oid: Oid, stream: &mut SlotStream) -> Result<()> {
335348
// The destructor will release the lock.
336349
let rel = unsafe { PgRelation::with_lock(rel_oid, pg_sys::AccessShareLock as i32) };
350+
write_u32(stream, rel_oid.as_u32())?;
351+
write_str(stream, rel.namespace())?;
352+
write_str(stream, rel.name())?;
337353
let tuple_desc = rel.tuple_desc();
338354
let attr_num = u32::try_from(tuple_desc.iter().filter(|a| !a.is_dropped()).count())?;
339-
write_u32(stream, rel_oid.as_u32())?;
340355
write_array_len(stream, attr_num)?;
341356
for attr in tuple_desc.iter() {
342357
if attr.is_dropped() {
343358
continue;
344359
}
360+
write_array_len(stream, 3)?;
345361
let etype = EncodedType::try_from(attr.type_oid().value())?;
362+
write_u8(stream, etype as u8)?;
346363
let is_nullable = !attr.attnotnull;
364+
write_bool(stream, is_nullable)?;
347365
let name = attr.name();
348-
write_array_len(stream, 3)?;
349366
write_str(stream, name)?;
350-
write_u8(stream, etype as u8)?;
351-
write_bool(stream, is_nullable)?;
352367
}
353368
Ok(())
354369
}
@@ -403,7 +418,7 @@ mod tests {
403418
fn test_header() {
404419
let header = Header {
405420
direction: Direction::ToWorker,
406-
packet: Packet::Ack,
421+
packet: Packet::Parse,
407422
length: 42,
408423
flag: Flag::Last,
409424
};
@@ -556,31 +571,39 @@ mod tests {
556571
// t1
557572
let oid = read_u32(&mut stream).unwrap();
558573
assert_eq!(oid, t1_oid.as_u32());
574+
let ns_len = read_str_len(&mut stream).unwrap();
575+
let ns = stream.look_ahead(ns_len as usize).unwrap();
576+
assert_eq!(ns, b"public");
577+
stream.rewind(ns_len as usize).unwrap();
578+
let name_len = read_str_len(&mut stream).unwrap();
579+
let name = stream.look_ahead(name_len as usize).unwrap();
580+
assert_eq!(name, b"t1");
581+
stream.rewind(name_len as usize).unwrap();
559582
let attr_num = read_array_len(&mut stream).unwrap();
560583
assert_eq!(attr_num, 2);
561584
// a
562585
let elem_num = read_array_len(&mut stream).unwrap();
563586
assert_eq!(elem_num, 3);
587+
let etype = read_u8(&mut stream).unwrap();
588+
assert_eq!(etype, EncodedType::Int32 as u8);
589+
let is_nullable = read_bool(&mut stream).unwrap();
590+
assert!(!is_nullable);
564591
let name_len = read_str_len(&mut stream).unwrap();
565592
assert_eq!(name_len, "a".len() as u32);
566593
let name = stream.look_ahead(name_len as usize).unwrap();
567594
assert_eq!(name, b"a");
568595
stream.rewind(name_len as usize).unwrap();
569-
let etype = read_u8(&mut stream).unwrap();
570-
assert_eq!(etype, EncodedType::Int32 as u8);
571-
let is_nullable = read_bool(&mut stream).unwrap();
572-
assert!(!is_nullable);
573596
// b
574597
let elem_num = read_array_len(&mut stream).unwrap();
575598
assert_eq!(elem_num, 3);
599+
let etype = read_u8(&mut stream).unwrap();
600+
assert_eq!(etype, EncodedType::Utf8 as u8);
601+
let is_nullable = read_bool(&mut stream).unwrap();
602+
assert!(is_nullable);
576603
let name_len = read_str_len(&mut stream).unwrap();
577604
assert_eq!(name_len, "b".len() as u32);
578605
let name = stream.look_ahead(name_len as usize).unwrap();
579606
assert_eq!(name, b"b");
580607
stream.rewind(name_len as usize).unwrap();
581-
let etype = read_u8(&mut stream).unwrap();
582-
assert_eq!(etype, EncodedType::Utf8 as u8);
583-
let is_nullable = read_bool(&mut stream).unwrap();
584-
assert!(is_nullable);
585608
}
586609
}

src/sql.rs

Lines changed: 76 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
use std::cell::OnceCell;
2-
use std::collections::HashMap;
3-
use std::sync::Arc;
4-
1+
use crate::data_type::EncodedType;
2+
use crate::ipc::SlotStream;
53
use ahash::AHashMap;
6-
use datafusion::arrow::datatypes::DataType;
7-
use datafusion::arrow::datatypes::SchemaRef;
8-
use datafusion::common::DFSchemaRef;
9-
use datafusion::common::ScalarValue;
4+
use anyhow::Result;
5+
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
106
use datafusion::config::ConfigOptions;
117
use datafusion::error::DataFusionError;
128
use datafusion::error::Result as DataFusionResult;
@@ -16,34 +12,22 @@ use datafusion::functions_aggregate::all_default_aggregate_functions;
1612
use datafusion::functions_window::all_default_window_functions;
1713
use datafusion::logical_expr::planner::ContextProvider;
1814
use datafusion::logical_expr::planner::ExprPlanner;
19-
use datafusion::logical_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
20-
use datafusion_sql::planner::SqlToRel;
21-
use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
22-
use datafusion_sql::sqlparser::parser::Parser;
15+
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
2316
use datafusion_sql::TableReference;
17+
use once_cell::sync::Lazy;
2418
use pgrx::pg_sys::Oid;
19+
use rmp::decode::read_array_len;
20+
use rmp::decode::read_bool;
21+
use rmp::decode::read_str_len;
22+
use rmp::decode::{read_u32, read_u8};
2523
use smol_str::SmolStr;
24+
use std::collections::HashMap;
25+
use std::str::from_utf8;
26+
use std::sync::Arc;
2627

27-
// fn sql_to_logical_plan(
28-
// sql: &str,
29-
// params: Vec<ScalarValue>,
30-
// ) -> Result<LogicalPlan, DataFusionError> {
31-
// let dialect = PostgreSqlDialect {};
32-
// let ast = Parser::parse_sql(&dialect, sql).map_err(|e| DataFusionError::SQL(e, None))?;
33-
// assert_eq!(ast.len(), 1);
34-
// let statement = ast.into_iter().next().expect("ast is not empty");
35-
//
36-
// // Cache metadata provider in a static variable to avoid re-allocation on each query.
37-
// let base_plan = CATALOG.with(|catalog| {
38-
// let catalog = catalog.get_or_init(Builtin::new);
39-
// let sql_to_rel = SqlToRel::new(catalog);
40-
// sql_to_rel.sql_statement_to_plan(statement)
41-
// })?;
42-
// let plan = base_plan.with_param_values(params)?;
43-
//
44-
// Ok(plan)
45-
// }
28+
static BUILDIN: Lazy<Arc<Builtin>> = Lazy::new(|| Arc::new(Builtin::new()));
4629

30+
#[derive(PartialEq, Eq, Hash)]
4731
pub(crate) struct Table {
4832
oid: Oid,
4933
schema: SchemaRef,
@@ -92,36 +76,86 @@ impl Builtin {
9276
}
9377
}
9478

95-
struct Catalog {
79+
pub(crate) struct Catalog {
9680
builtin: Arc<Builtin>,
97-
tables: AHashMap<SmolStr, Arc<dyn TableSource>>,
81+
tables: AHashMap<TableReference, Arc<dyn TableSource>>,
82+
}
83+
84+
impl Catalog {
85+
pub(crate) fn from_stream(stream: &mut SlotStream) -> Result<Self> {
86+
// The header should be consumed before calling this function.
87+
let table_num = read_array_len(stream)?;
88+
let mut tables = AHashMap::with_capacity(table_num as usize);
89+
for _ in 0..table_num {
90+
let oid = read_u32(stream)?;
91+
let ns_len = read_str_len(stream)?;
92+
let ns_bytes = stream.look_ahead(ns_len as usize)?;
93+
let schema = SmolStr::from(from_utf8(ns_bytes)?);
94+
stream.rewind(ns_len as usize)?;
95+
let name_len = read_str_len(stream)?;
96+
let name_bytes = stream.look_ahead(name_len as usize)?;
97+
let name = from_utf8(name_bytes)?;
98+
let table_ref = TableReference::partial(schema.as_str(), name);
99+
stream.rewind(name_len as usize)?;
100+
let column_num = read_array_len(stream)?;
101+
let mut fields = Vec::with_capacity(column_num as usize);
102+
for _ in 0..column_num {
103+
let elem_num = read_array_len(stream)?;
104+
assert_eq!(elem_num, 3);
105+
let etype = read_u8(stream)?;
106+
let df_type = EncodedType::try_from(etype)?.to_arrow();
107+
let is_nullable = read_bool(stream)?;
108+
let name_len = read_str_len(stream)?;
109+
let name_bytes = stream.look_ahead(name_len as usize)?;
110+
let name = from_utf8(name_bytes)?;
111+
let field = Field::new(name, df_type, is_nullable);
112+
fields.push(field);
113+
}
114+
let schema = Schema::new(fields);
115+
let table = Table {
116+
oid: Oid::from(oid),
117+
schema: Arc::new(schema),
118+
};
119+
tables.insert(table_ref, Arc::new(table) as Arc<dyn TableSource>);
120+
}
121+
Ok(Self {
122+
builtin: Arc::clone(&*BUILDIN),
123+
tables,
124+
})
125+
}
98126
}
99127

100128
impl ContextProvider for Catalog {
101-
fn get_table_source(&self, name: TableReference) -> DataFusionResult<Arc<dyn TableSource>> {
102-
match self.tables.get(name.table()) {
129+
fn get_table_source(&self, table: TableReference) -> DataFusionResult<Arc<dyn TableSource>> {
130+
match self.tables.get(&table) {
103131
Some(table) => Ok(Arc::clone(table)),
104-
_ => Err(DataFusionError::Plan(format!(
105-
"Table not found: {}",
106-
name.table()
107-
))),
132+
_ => {
133+
let schema = match table.schema() {
134+
Some(schema) => &format!("{schema}."),
135+
_ => table.table(),
136+
};
137+
Err(DataFusionError::Plan(format!(
138+
"Table not found: {schema}{}",
139+
table.table()
140+
)))
141+
}
108142
}
109143
}
110144

111145
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
112-
self.builtin.scalar_udf.get(name).map(|f| Arc::clone(f))
146+
self.builtin.scalar_udf.get(name).map(Arc::clone)
113147
}
114148

115149
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
116-
self.builtin.agg_udf.get(name).map(|f| Arc::clone(f))
150+
self.builtin.agg_udf.get(name).map(Arc::clone)
117151
}
118152

119153
fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
120154
None
121155
}
122156

123157
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
124-
self.builtin.window_udf.get(name).map(|f| Arc::clone(f))
158+
self.builtin.window_udf.get(name).map(Arc::clone)
125159
}
126160

127161
fn options(&self) -> &ConfigOptions {

0 commit comments

Comments
 (0)