Skip to content

Commit bcf2636

Browse files
committed
wip: return columns from the worker and fix errors
1 parent 576ef04 commit bcf2636

File tree

4 files changed

+190
-34
lines changed

4 files changed

+190
-34
lines changed

src/backend.rs

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ use smallvec::SmallVec;
1313
use std::ffi::c_char;
1414
use std::time::Duration;
1515

16+
use crate::data_type::unpack_target_entry;
1617
use crate::error::FusionError;
17-
use crate::ipc::{my_slot, Bus, SlotStream};
18+
use crate::ipc::{my_slot, worker_id, Bus, SlotStream, INVALID_PROC_NUMBER};
1819
use crate::protocol::{
1920
consume_header, read_error, request_explain, send_metadata, send_params, send_query, Direction,
2021
NeedSchema, Packet,
@@ -88,6 +89,9 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
8889
}
8990
let mut skip_wait = true;
9091
loop {
92+
if worker_id() == INVALID_PROC_NUMBER {
93+
error!("Worker ID is invalid");
94+
}
9195
if !skip_wait {
9296
wait_latch(Some(BACKEND_WAIT_TIMEOUT));
9397
skip_wait = false;
@@ -96,7 +100,11 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
96100
continue;
97101
};
98102
let mut stream = SlotStream::from(slot);
99-
let header = consume_header(&mut stream).expect("Failed to consume header");
103+
let header = match consume_header(&mut stream) {
104+
Ok(header) => header,
105+
// TODO: before panic we should send a Failure message to the worker.
106+
Err(err) => error!("Failed to consume header: {}", err),
107+
};
100108
if header.direction != Direction::ToBackend {
101109
continue;
102110
}
@@ -105,37 +113,44 @@ unsafe extern "C" fn create_df_scan_state(cscan: *mut CustomScan) -> *mut Node {
105113
// No data, just continue waiting.
106114
continue;
107115
}
108-
Packet::Failure => {
109-
let msg = read_error(&mut stream).expect("Failed to read the error message");
110-
error!("Failed to compile the query: {}", msg);
111-
}
116+
Packet::Failure => match read_error(&mut stream) {
117+
Ok(msg) => error!("Failed to compile the query: {}", msg),
118+
Err(err) => error!("Double error: {}", err),
119+
},
112120
Packet::Metadata => {
113-
let oids = table_oids(&mut stream).expect("Failed to read table OIDs");
121+
let oids = match table_oids(&mut stream) {
122+
Ok(oids) => oids,
123+
Err(err) => error!("Failed to read the table OIDs: {}", err),
124+
};
114125
if let Err(err) = send_metadata(my_slot(), stream, &oids) {
115126
error!("Failed to send the table metadata: {}", err);
116127
}
128+
}
129+
Packet::Bind => {
130+
let mut params: &[ParamExternData] = &[];
131+
let param_list = (*list_nth(list, 1)).ptr_value as ParamListInfo;
132+
if !param_list.is_null() {
133+
let num_params = unsafe { (*param_list).numParams } as usize;
134+
params = unsafe { (*param_list).params.as_slice(num_params) };
135+
}
136+
if let Err(err) = send_params(my_slot(), stream, params) {
137+
error!("Failed to send the parameter list: {}", err);
138+
}
139+
}
140+
Packet::Columns => {
141+
if let Err(err) = unpack_target_entry(&mut stream, (*cscan).custom_scan_tlist) {
142+
error!("Failed to unpack target entry: {}", err);
143+
}
117144
break;
118145
}
119-
Packet::Bind | Packet::Parse | Packet::Explain => {
146+
Packet::Parse | Packet::Explain => {
120147
error!(
121-
"Unexpected packet for create custom plan: {:?}",
148+
"Unexpected packet while creating a custom plan: {:?}",
122149
header.packet
123150
)
124151
}
125152
}
126153
}
127-
let mut params: &[ParamExternData] = &[];
128-
let param_list = (*list_nth(list, 1)).ptr_value as ParamListInfo;
129-
if !param_list.is_null() {
130-
let num_params = unsafe { (*param_list).numParams } as usize;
131-
params = unsafe { (*param_list).params.as_slice(num_params) };
132-
}
133-
let stream = wait_stream();
134-
if let Err(err) = send_params(my_slot(), stream, params) {
135-
error!("Failed to send the parameter list: {}", err);
136-
}
137-
// TODO: request plan fields from the worker to build custom_scan_tlist
138-
// with repack_output().
139154
let css = CustomScanState {
140155
methods: exec_methods(),
141156
..Default::default()
@@ -200,7 +215,7 @@ unsafe extern "C" fn explain_df_scan(
200215
let msg = read_error(&mut stream).expect("Failed to read the error message");
201216
error!("Failed to compile the query: {}", msg);
202217
}
203-
Packet::Metadata | Packet::Bind | Packet::Parse => {
218+
Packet::Columns | Packet::Metadata | Packet::Bind | Packet::Parse => {
204219
error!("Unexpected packet for explain: {:?}", header.packet)
205220
}
206221
Packet::Explain => {

src/data_type.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ use pgrx::pg_sys::{
1313
TargetEntry, GETSTRUCT,
1414
};
1515
use rmp::decode::{
16-
read_bool, read_f32, read_f64, read_i16, read_i32, read_i64, read_str_len, RmpRead,
16+
read_array_len, read_bin_len, read_bool, read_f32, read_f64, read_i16, read_i32, read_i64,
17+
read_str_len, read_u8, RmpRead,
1718
};
1819
use rmp::encode::{
1920
write_bool, write_f32, write_f64, write_i16, write_i32, write_i64, write_pfix, write_str,
@@ -173,6 +174,27 @@ impl TryFrom<pg_sys::Oid> for EncodedType {
173174
}
174175
}
175176

177+
impl TryFrom<&DataType> for EncodedType {
178+
type Error = FusionError;
179+
180+
fn try_from(value: &DataType) -> Result<Self, Self::Error> {
181+
match value {
182+
DataType::Boolean => Ok(EncodedType::Boolean),
183+
DataType::Utf8 => Ok(EncodedType::Utf8),
184+
DataType::Int16 => Ok(EncodedType::Int16),
185+
DataType::Int32 => Ok(EncodedType::Int32),
186+
DataType::Int64 => Ok(EncodedType::Int64),
187+
DataType::Float32 => Ok(EncodedType::Float32),
188+
DataType::Float64 => Ok(EncodedType::Float64),
189+
DataType::Date32 => Ok(EncodedType::Date32),
190+
DataType::Time64(_) => Ok(EncodedType::Time64),
191+
DataType::Timestamp(_, _) => Ok(EncodedType::Timestamp),
192+
DataType::Interval(_) => Ok(EncodedType::Interval),
193+
_ => Err(FusionError::UnsupportedType(format!("{:?}", value))),
194+
}
195+
}
196+
}
197+
176198
#[inline]
177199
pub(crate) fn write_scalar_value(stream: &mut SlotStream, value: &ScalarValue) -> Result<()> {
178200
let write_null = |stream: &mut SlotStream| -> Result<()> {
@@ -416,6 +438,41 @@ pub(crate) fn repack_output(columns: &[Field]) -> *mut List {
416438
list
417439
}
418440

441+
// The header of the stream must already be consumed.
442+
pub(crate) fn unpack_target_entry(stream: &mut SlotStream, list: *mut List) -> Result<()> {
443+
let column_len = read_array_len(stream)?;
444+
assert!(column_len < i16::MAX as u32);
445+
for position in 0..column_len {
446+
let pos = position as i16 + 1;
447+
let etype = read_u8(stream)?;
448+
let oid = type_to_oid(&EncodedType::try_from(etype)?.to_arrow());
449+
let tuple =
450+
unsafe { pg_sys::SearchSysCache1(TYPEOID as i32, pg_sys::ObjectIdGetDatum(oid)) };
451+
if tuple.is_null() {
452+
bail!(FusionError::UnsupportedType(format!("{:?}", oid)));
453+
}
454+
let name_len = read_bin_len(stream)? as usize;
455+
let name = stream.look_ahead(name_len)?;
456+
unsafe {
457+
let typtup = GETSTRUCT(tuple) as pg_sys::Form_pg_type;
458+
let expr = makeVar(
459+
pg_sys::INDEX_VAR,
460+
pos,
461+
oid,
462+
(*typtup).typtypmod,
463+
(*typtup).typcollation,
464+
0,
465+
);
466+
let col_name = palloc0(name_len) as *mut u8;
467+
std::ptr::copy_nonoverlapping(name.as_ptr(), col_name, name_len);
468+
let entry = makeTargetEntry(expr as *mut Expr, pos, col_name as *mut i8, false);
469+
pg_sys::ReleaseSysCache(tuple);
470+
list_append_unique_ptr(list, entry as *mut c_void);
471+
}
472+
}
473+
Ok(())
474+
}
475+
419476
#[cfg(any(test, feature = "pg_test"))]
420477
#[pgrx::pg_schema]
421478
mod tests {

src/protocol.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::ipc::{worker_id, Bus, Slot, SlotNumber, SlotStream, DATA_SIZE};
77
use crate::sql::Table;
88
use ahash::AHashMap;
99
use anyhow::Result;
10-
use datafusion::arrow::datatypes::{Field, Schema};
10+
use datafusion::arrow::datatypes::{Field, Fields, Schema};
1111
use datafusion::logical_expr::TableSource;
1212
use datafusion::scalar::ScalarValue;
1313
use datafusion_sql::TableReference;
@@ -57,6 +57,7 @@ pub enum Packet {
5757
Metadata = 3,
5858
Parse = 4,
5959
Explain = 5,
60+
Columns = 6,
6061
}
6162

6263
impl TryFrom<u8> for Packet {
@@ -71,6 +72,7 @@ impl TryFrom<u8> for Packet {
7172
3 => Ok(Packet::Metadata),
7273
4 => Ok(Packet::Parse),
7374
5 => Ok(Packet::Explain),
75+
6 => Ok(Packet::Columns),
7476
_ => Err(FusionError::Deserialize("packet".to_string(), value.into())),
7577
}
7678
}
@@ -512,6 +514,34 @@ pub(crate) fn request_explain(slot_id: SlotNumber, mut stream: SlotStream) -> Re
512514
Ok(())
513515
}
514516

517+
// COLUMNS
518+
519+
pub(crate) fn prepare_columns(stream: &mut SlotStream, columns: &Fields) -> Result<()> {
520+
stream.reset();
521+
write_header(stream, &Header::default())?;
522+
let pos_init = stream.position();
523+
write_array_len(stream, u32::try_from(columns.len())?)?;
524+
for column in columns {
525+
write_u8(stream, EncodedType::try_from(column.data_type())? as u8)?;
526+
let len = u32::try_from(column.name().len() + 1)?;
527+
write_bin_len(stream, len)?;
528+
stream.write_bytes(column.name().as_bytes())?;
529+
write_pfix(stream, 0)?;
530+
}
531+
let pos_final = stream.position();
532+
let length = u16::try_from(pos_final - pos_init)?;
533+
let header = Header {
534+
direction: Direction::ToBackend,
535+
packet: Packet::Columns,
536+
length,
537+
flag: Flag::Last,
538+
};
539+
stream.reset();
540+
write_header(stream, &header)?;
541+
stream.rewind(length as usize)?;
542+
Ok(())
543+
}
544+
515545
#[cfg(any(test, feature = "pg_test"))]
516546
#[pg_schema]
517547
mod tests {

0 commit comments

Comments (0)