@@ -7,9 +7,11 @@ use datafusion::scalar::ScalarValue;
7
7
use datafusion_sql:: TableReference ;
8
8
use pgrx:: pg_sys:: { Oid , ParamExternData , ProcSendSignal } ;
9
9
use pgrx:: prelude:: * ;
10
- use rmp:: decode:: { read_array_len, read_bin_len, read_pfix, read_str_len, read_u16} ;
10
+ use pgrx:: { pg_guard, PgRelation } ;
11
+ use rmp:: decode:: { read_array_len, read_bin_len, read_pfix, read_str_len, read_u16, read_u8} ;
11
12
use rmp:: encode:: {
12
- write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, RmpWrite ,
13
+ write_array_len, write_bin_len, write_bool, write_pfix, write_str, write_u16, write_u32,
14
+ write_u8, RmpWrite ,
13
15
} ;
14
16
15
17
#[ repr( u8 ) ]
@@ -331,23 +333,71 @@ pub(crate) fn send_table_refs(
331
333
}
332
334
333
335
#[ inline]
334
- pub ( crate ) fn write_column (
335
- stream : & mut SlotStream ,
336
- column : & str ,
337
- is_null : bool ,
338
- etype : EncodedType ,
336
+ #[ pg_guard]
337
+ fn serialize_table ( rel_oid : Oid , stream : & mut SlotStream ) -> Result < ( ) > {
338
+ // The destructor will release the lock.
339
+ let rel = unsafe { PgRelation :: with_lock ( rel_oid, pg_sys:: AccessShareLock as i32 ) } ;
340
+ let tuple_desc = rel. tuple_desc ( ) ;
341
+ let attr_num = u32:: try_from ( tuple_desc. iter ( ) . filter ( |a| !a. is_dropped ( ) ) . count ( ) ) ?;
342
+ write_u32 ( stream, rel_oid. as_u32 ( ) ) ?;
343
+ write_array_len ( stream, attr_num) ?;
344
+ for attr in tuple_desc. iter ( ) {
345
+ if attr. is_dropped ( ) {
346
+ continue ;
347
+ }
348
+ let etype = EncodedType :: try_from ( attr. type_oid ( ) . value ( ) ) ?;
349
+ let is_nullable = !attr. attnotnull ;
350
+ let name = attr. name ( ) ;
351
+ write_array_len ( stream, 3 ) ?;
352
+ write_str ( stream, name) ?;
353
+ write_u8 ( stream, etype as u8 ) ?;
354
+ write_bool ( stream, is_nullable) ?;
355
+ }
356
+ Ok ( ( ) )
357
+ }
358
+
359
+ pub ( crate ) fn prepare_metadata ( rel_oids : & [ Oid ] , stream : & mut SlotStream ) -> Result < ( ) > {
360
+ stream. reset ( ) ;
361
+ // We don't know the length of the table metadata yet. So we write
362
+ // an invalid header to replace it with the correct one later.
363
+ write_header ( stream, & Header :: default ( ) ) ?;
364
+ let pos_init = stream. position ( ) ;
365
+ write_array_len ( stream, rel_oids. len ( ) as u32 ) ?;
366
+ for & rel_oid in rel_oids {
367
+ serialize_table ( rel_oid, stream) ?;
368
+ }
369
+ let pos_final = stream. position ( ) ;
370
+ let length = u16:: try_from ( pos_final - pos_init) ?;
371
+ let header = Header {
372
+ direction : Direction :: ToWorker ,
373
+ packet : Packet :: Metadata ,
374
+ length,
375
+ flag : Flag :: Last ,
376
+ } ;
377
+ stream. reset ( ) ;
378
+ write_header ( stream, & header) ?;
379
+ stream. rewind ( length as usize ) ?;
380
+ Ok ( ( ) )
381
+ }
382
+
383
+ pub ( crate ) fn send_metadata (
384
+ slot_id : SlotNumber ,
385
+ mut stream : SlotStream ,
386
+ rel_oids : & [ Oid ] ,
339
387
) -> Result < ( ) > {
340
- write_str ( stream, column) ?;
341
- write_bool ( stream, is_null) ?;
342
- write_pfix ( stream, etype as u8 ) ?;
388
+ prepare_metadata ( rel_oids, & mut stream) ?;
389
+ // Unlock the slot after writing the metadata.
390
+ let _guard = Slot :: from ( stream) ;
391
+ signal ( slot_id, Direction :: ToWorker ) ;
343
392
Ok ( ( ) )
344
393
}
345
394
346
395
#[ cfg( any( test, feature = "pg_test" ) ) ]
347
396
#[ pg_schema]
348
397
mod tests {
349
398
use super :: * ;
350
- use pgrx:: pg_sys:: Datum ;
399
+ use pgrx:: pg_sys:: { Datum , Oid } ;
400
+ use rmp:: decode:: { read_bool, read_u32} ;
351
401
use std:: ptr:: addr_of_mut;
352
402
353
403
const SLOT_SIZE : usize = 8204 ;
@@ -482,4 +532,58 @@ mod tests {
482
532
let t2 = stream. look_ahead ( t2_len as usize ) . unwrap ( ) ;
483
533
assert_eq ! ( t2, b"table2\0 " ) ;
484
534
}
535
+
536
+ #[ pg_test]
537
+ fn test_metadata_response ( ) {
538
+ Spi :: run ( "create table if not exists t1(a int not null, b text);" ) . unwrap ( ) ;
539
+ let t1_oid = Spi :: get_one :: < Oid > ( "select 't1'::regclass::oid;" )
540
+ . unwrap ( )
541
+ . unwrap ( ) ;
542
+
543
+ let mut slot_buf: [ u8 ; SLOT_SIZE ] = [ 1 ; SLOT_SIZE ] ;
544
+ let ptr = addr_of_mut ! ( slot_buf) as * mut u8 ;
545
+ Slot :: init ( ptr, slot_buf. len ( ) ) ;
546
+ let slot = Slot :: from_bytes ( ptr, slot_buf. len ( ) ) ;
547
+ let mut stream: SlotStream = slot. into ( ) ;
548
+
549
+ prepare_metadata ( & [ t1_oid] , & mut stream) . unwrap ( ) ;
550
+ stream. reset ( ) ;
551
+ let header = consume_header ( & mut stream) . unwrap ( ) ;
552
+ assert_eq ! ( header. direction, Direction :: ToWorker ) ;
553
+ assert_eq ! ( header. packet, Packet :: Metadata ) ;
554
+ assert_eq ! ( header. flag, Flag :: Last ) ;
555
+
556
+ // Check table metadata deserialization
557
+ let table_num = read_array_len ( & mut stream) . unwrap ( ) ;
558
+ assert_eq ! ( table_num, 1 ) ;
559
+ // t1
560
+ let oid = read_u32 ( & mut stream) . unwrap ( ) ;
561
+ assert_eq ! ( oid, t1_oid. as_u32( ) ) ;
562
+ let attr_num = read_array_len ( & mut stream) . unwrap ( ) ;
563
+ assert_eq ! ( attr_num, 2 ) ;
564
+ // a
565
+ let elem_num = read_array_len ( & mut stream) . unwrap ( ) ;
566
+ assert_eq ! ( elem_num, 3 ) ;
567
+ let name_len = read_str_len ( & mut stream) . unwrap ( ) ;
568
+ assert_eq ! ( name_len, "a" . len( ) as u32 ) ;
569
+ let name = stream. look_ahead ( name_len as usize ) . unwrap ( ) ;
570
+ assert_eq ! ( name, b"a" ) ;
571
+ stream. rewind ( name_len as usize ) . unwrap ( ) ;
572
+ let etype = read_u8 ( & mut stream) . unwrap ( ) ;
573
+ assert_eq ! ( etype, EncodedType :: Int32 as u8 ) ;
574
+ let is_nullable = read_bool ( & mut stream) . unwrap ( ) ;
575
+ assert ! ( !is_nullable) ;
576
+ // b
577
+ let elem_num = read_array_len ( & mut stream) . unwrap ( ) ;
578
+ assert_eq ! ( elem_num, 3 ) ;
579
+ let name_len = read_str_len ( & mut stream) . unwrap ( ) ;
580
+ assert_eq ! ( name_len, "b" . len( ) as u32 ) ;
581
+ let name = stream. look_ahead ( name_len as usize ) . unwrap ( ) ;
582
+ assert_eq ! ( name, b"b" ) ;
583
+ stream. rewind ( name_len as usize ) . unwrap ( ) ;
584
+ let etype = read_u8 ( & mut stream) . unwrap ( ) ;
585
+ assert_eq ! ( etype, EncodedType :: Utf8 as u8 ) ;
586
+ let is_nullable = read_bool ( & mut stream) . unwrap ( ) ;
587
+ assert ! ( is_nullable) ;
588
+ }
485
589
}
0 commit comments