Skip to content

Commit 1769f78

Browse files
committed
feat: refactor catalog creation from stream
1 parent 47956b8 commit 1769f78

File tree

2 files changed

+114
-55
lines changed

2 files changed

+114
-55
lines changed

src/protocol.rs

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1+
use std::str::from_utf8;
2+
use std::sync::Arc;
3+
14
use crate::data_type::{datum_to_scalar, read_scalar_value, write_scalar_value, EncodedType};
25
use crate::error::FusionError;
36
use crate::ipc::{Bus, Slot, SlotNumber, SlotStream, DATA_SIZE};
7+
use crate::sql::Table;
48
use crate::worker::worker_id;
9+
use ahash::AHashMap;
510
use anyhow::Result;
11+
use datafusion::arrow::datatypes::{Field, Schema};
12+
use datafusion::logical_expr::TableSource;
613
use datafusion::scalar::ScalarValue;
714
use datafusion_sql::TableReference;
815
use pgrx::pg_sys::{Oid, ParamExternData, ProcSendSignal};
916
use pgrx::prelude::*;
1017
use pgrx::{pg_guard, PgRelation};
11-
use rmp::decode::{read_array_len, read_bin_len, read_pfix, read_str_len, read_u16, read_u8};
18+
use rmp::decode::{
19+
read_array_len, read_bin_len, read_bool, read_pfix, read_str_len, read_u16, read_u32, read_u8,
20+
};
1221
use rmp::encode::{
1322
write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, write_u32,
1423
write_u8, RmpWrite,
1524
};
25+
use smol_str::SmolStr;
1626

1727
#[repr(u8)]
1828
#[derive(Clone, Debug, Default, PartialEq)]
@@ -416,10 +426,59 @@ pub(crate) fn send_metadata(
416426
Ok(())
417427
}
418428

429+
#[inline]
430+
pub(crate) fn consume_metadata(
431+
stream: &mut SlotStream,
432+
) -> Result<AHashMap<TableReference, Arc<dyn TableSource>>> {
433+
// The header should be consumed before calling this function.
434+
let table_num = read_array_len(stream)?;
435+
let mut tables = AHashMap::with_capacity(table_num as usize);
436+
for _ in 0..table_num {
437+
let name_part_num = read_array_len(stream)?;
438+
assert!(name_part_num == 2 || name_part_num == 3);
439+
let oid = read_u32(stream)?;
440+
let mut schema = None;
441+
if name_part_num == 3 {
442+
let ns_len = read_str_len(stream)?;
443+
let ns_bytes = stream.look_ahead(ns_len as usize)?;
444+
schema = Some(SmolStr::new(from_utf8(ns_bytes)?));
445+
stream.rewind(ns_len as usize)?;
446+
}
447+
let name_len = read_str_len(stream)?;
448+
let name_bytes = stream.look_ahead(name_len as usize)?;
449+
let name = from_utf8(name_bytes)?;
450+
let table_ref = match schema {
451+
Some(schema) => TableReference::partial(schema, name),
452+
None => TableReference::bare(name),
453+
};
454+
stream.rewind(name_len as usize)?;
455+
let column_num = read_array_len(stream)?;
456+
let mut fields = Vec::with_capacity(column_num as usize);
457+
for _ in 0..column_num {
458+
let elem_num = read_array_len(stream)?;
459+
assert_eq!(elem_num, 3);
460+
let etype = read_u8(stream)?;
461+
let df_type = EncodedType::try_from(etype)?.to_arrow();
462+
let is_nullable = read_bool(stream)?;
463+
let name_len = read_str_len(stream)?;
464+
let name_bytes = stream.look_ahead(name_len as usize)?;
465+
let name = from_utf8(name_bytes)?;
466+
let field = Field::new(name, df_type, is_nullable);
467+
stream.rewind(name_len as usize)?;
468+
fields.push(field);
469+
}
470+
let schema = Schema::new(fields);
471+
let table = Table::new(Oid::from(oid), Arc::new(schema));
472+
tables.insert(table_ref, Arc::new(table) as Arc<dyn TableSource>);
473+
}
474+
Ok(tables)
475+
}
476+
419477
#[cfg(any(test, feature = "pg_test"))]
420478
#[pg_schema]
421479
mod tests {
422480
use super::*;
481+
use datafusion::arrow::datatypes::DataType;
423482
use pgrx::pg_sys::{Datum, Oid};
424483
use rmp::decode::{read_bool, read_u32};
425484
use std::ptr::addr_of_mut;
@@ -651,4 +710,47 @@ mod tests {
651710
assert_eq!(name, b"a");
652711
stream.rewind(name_len as usize).unwrap();
653712
}
713+
#[pg_test]
714+
fn test_metadata_to_tables() {
715+
Spi::run("create table if not exists t1(a int not null, b text);").unwrap();
716+
Spi::run("create schema if not exists s1;").unwrap();
717+
Spi::run("create table if not exists s1.t2(a int);").unwrap();
718+
let t1_oid = Spi::get_one::<Oid>("select 't1'::regclass::oid;")
719+
.unwrap()
720+
.unwrap();
721+
let t2_oid = Spi::get_one::<Oid>("select 's1.t2'::regclass::oid;")
722+
.unwrap()
723+
.unwrap();
724+
725+
let mut slot_buf: [u8; SLOT_SIZE] = [1; SLOT_SIZE];
726+
let ptr = addr_of_mut!(slot_buf) as *mut u8;
727+
Slot::init(ptr, slot_buf.len());
728+
let slot = Slot::from_bytes(ptr, slot_buf.len());
729+
let mut stream: SlotStream = slot.into();
730+
731+
prepare_metadata(&[(t1_oid, false), (t2_oid, true)], &mut stream).unwrap();
732+
stream.reset();
733+
let header = consume_header(&mut stream).unwrap();
734+
assert_eq!(header.direction, Direction::ToWorker);
735+
assert_eq!(header.packet, Packet::Metadata);
736+
assert_eq!(header.flag, Flag::Last);
737+
738+
let tables = consume_metadata(&mut stream).unwrap();
739+
assert_eq!(tables.len(), 2);
740+
// t1
741+
let t1 = tables.get(&TableReference::bare("t1")).unwrap();
742+
assert_eq!(t1.schema().fields().len(), 2);
743+
assert_eq!(t1.schema().field(0).name(), "a");
744+
assert_eq!(t1.schema().field(1).name(), "b");
745+
assert_eq!(t1.schema().field(0).data_type(), &DataType::Int32);
746+
assert_eq!(t1.schema().field(1).data_type(), &DataType::Utf8);
747+
assert!(!t1.schema().field(0).is_nullable());
748+
assert!(t1.schema().field(1).is_nullable());
749+
// s1.t2
750+
let t2 = tables.get(&TableReference::partial("s1", "t2")).unwrap();
751+
assert_eq!(t2.schema().fields().len(), 1);
752+
assert_eq!(t2.schema().field(0).name(), "a");
753+
assert_eq!(t2.schema().field(0).data_type(), &DataType::Int32);
754+
assert!(t2.schema().field(0).is_nullable());
755+
}
654756
}

src/sql.rs

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
use crate::data_type::EncodedType;
21
use crate::ipc::SlotStream;
2+
use crate::protocol::consume_metadata;
33
use ahash::AHashMap;
44
use anyhow::Result;
5-
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
5+
use datafusion::arrow::datatypes::{DataType, SchemaRef};
66
use datafusion::config::ConfigOptions;
77
use datafusion::error::DataFusionError;
88
use datafusion::error::Result as DataFusionResult;
@@ -16,21 +16,21 @@ use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
1616
use datafusion_sql::TableReference;
1717
use once_cell::sync::Lazy;
1818
use pgrx::pg_sys::Oid;
19-
use rmp::decode::read_array_len;
20-
use rmp::decode::read_bool;
21-
use rmp::decode::read_str_len;
22-
use rmp::decode::{read_u32, read_u8};
23-
use smol_str::SmolStr;
2419
use std::collections::HashMap;
25-
use std::str::from_utf8;
2620
use std::sync::Arc;
2721

2822
static BUILDIN: Lazy<Arc<Builtin>> = Lazy::new(|| Arc::new(Builtin::new()));
2923

3024
#[derive(PartialEq, Eq, Hash)]
3125
pub(crate) struct Table {
32-
oid: Oid,
33-
schema: SchemaRef,
26+
pub(crate) oid: Oid,
27+
pub(crate) schema: SchemaRef,
28+
}
29+
30+
impl Table {
31+
pub(crate) fn new(oid: Oid, schema: SchemaRef) -> Self {
32+
Self { oid, schema }
33+
}
3434
}
3535

3636
impl TableSource for Table {
@@ -83,52 +83,9 @@ pub(crate) struct Catalog {
8383

8484
impl Catalog {
8585
pub(crate) fn from_stream(stream: &mut SlotStream) -> Result<Self> {
86-
// The header should be consumed before calling this function.
87-
let table_num = read_array_len(stream)?;
88-
let mut tables = AHashMap::with_capacity(table_num as usize);
89-
for _ in 0..table_num {
90-
let name_part_num = read_array_len(stream)?;
91-
assert!(name_part_num == 2 || name_part_num == 3);
92-
let oid = read_u32(stream)?;
93-
let mut schema = None;
94-
if name_part_num == 3 {
95-
let ns_len = read_str_len(stream)?;
96-
let ns_bytes = stream.look_ahead(ns_len as usize)?;
97-
schema = Some(SmolStr::new(from_utf8(ns_bytes)?));
98-
stream.rewind(ns_len as usize)?;
99-
}
100-
let name_len = read_str_len(stream)?;
101-
let name_bytes = stream.look_ahead(name_len as usize)?;
102-
let name = from_utf8(name_bytes)?;
103-
let table_ref = match schema {
104-
Some(schema) => TableReference::partial(schema, name),
105-
None => TableReference::bare(name),
106-
};
107-
stream.rewind(name_len as usize)?;
108-
let column_num = read_array_len(stream)?;
109-
let mut fields = Vec::with_capacity(column_num as usize);
110-
for _ in 0..column_num {
111-
let elem_num = read_array_len(stream)?;
112-
assert_eq!(elem_num, 3);
113-
let etype = read_u8(stream)?;
114-
let df_type = EncodedType::try_from(etype)?.to_arrow();
115-
let is_nullable = read_bool(stream)?;
116-
let name_len = read_str_len(stream)?;
117-
let name_bytes = stream.look_ahead(name_len as usize)?;
118-
let name = from_utf8(name_bytes)?;
119-
let field = Field::new(name, df_type, is_nullable);
120-
fields.push(field);
121-
}
122-
let schema = Schema::new(fields);
123-
let table = Table {
124-
oid: Oid::from(oid),
125-
schema: Arc::new(schema),
126-
};
127-
tables.insert(table_ref, Arc::new(table) as Arc<dyn TableSource>);
128-
}
12986
Ok(Self {
13087
builtin: Arc::clone(&*BUILDIN),
131-
tables,
88+
tables: consume_metadata(stream)?,
13289
})
13390
}
13491
}

0 commit comments

Comments
 (0)