Skip to content

Commit cd001ac

Browse files
committed
[flight] Enhanced customizability
1 parent a9e1469 commit cd001ac

File tree

5 files changed

+114
-36
lines changed

5 files changed

+114
-36
lines changed

examples/flight-sql.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use std::sync::Arc;
2929
#[tokio::main]
3030
async fn main() -> datafusion::common::Result<()> {
3131
let ctx = SessionContext::new();
32-
let flight_sql = FlightTableFactory::new(Arc::new(FlightSqlDriver::default()));
32+
let flight_sql = FlightTableFactory::new(Arc::new(FlightSqlDriver::new()));
3333
let table = flight_sql
3434
.open_table(
3535
"http://localhost:32010",

src/flight.rs

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,23 @@ impl FlightTableFactory {
118118
.map_err(to_df_err)?;
119119
let num_rows = precision(metadata.info.total_records);
120120
let total_byte_size = precision(metadata.info.total_bytes);
121-
let logical_schema = metadata.schema;
121+
let logical_schema = metadata.schema.clone();
122122
let stats = Statistics {
123123
num_rows,
124124
total_byte_size,
125125
column_statistics: vec![],
126126
};
127+
let metadata_supplier = if metadata.props.reusable_flight_info {
128+
MetadataSupplier::Reusable(Arc::new(metadata))
129+
} else {
130+
MetadataSupplier::Refresh {
131+
driver: self.driver.clone(),
132+
channel,
133+
options,
134+
}
135+
};
127136
Ok(FlightTable {
128-
driver: self.driver.clone(),
129-
channel,
130-
options,
137+
metadata_supplier,
131138
origin,
132139
logical_schema,
133140
stats,
@@ -203,26 +210,41 @@ impl TryFrom<FlightInfo> for FlightMetadata {
203210
/// for controlling the protocol and query execution details.
204211
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
205212
pub struct FlightProperties {
206-
unbounded_stream: bool,
213+
unbounded_streams: bool,
207214
grpc_headers: HashMap<String, String>,
208215
size_limits: SizeLimits,
216+
reusable_flight_info: bool,
209217
}
210218

211219
impl FlightProperties {
212-
pub fn unbounded_stream(mut self, unbounded_stream: bool) -> Self {
213-
self.unbounded_stream = unbounded_stream;
220+
pub fn new() -> Self {
221+
Default::default()
222+
}
223+
224+
/// Whether the service will produce infinite (unbounded) streams.
225+
pub fn with_unbounded_streams(mut self, unbounded_streams: bool) -> Self {
226+
self.unbounded_streams = unbounded_streams;
214227
self
215228
}
216229

217-
pub fn grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
230+
/// gRPC headers to use on subsequent calls.
231+
pub fn with_grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
218232
self.grpc_headers = grpc_headers;
219233
self
220234
}
221235

222-
pub fn size_limits(mut self, size_limits: SizeLimits) -> Self {
236+
/// Maximum sizes, in bytes, for encoded/decoded gRPC messages.
237+
pub fn with_size_limits(mut self, size_limits: SizeLimits) -> Self {
223238
self.size_limits = size_limits;
224239
self
225240
}
241+
242+
/// Whether the FlightInfo objects produced by the service can be used multiple times
243+
/// or need to be refreshed before every table scan.
244+
pub fn with_reusable_flight_info(mut self, reusable_flight_info: bool) -> Self {
245+
self.reusable_flight_info = reusable_flight_info;
246+
self
247+
}
226248
}
227249

228250
/// Message size limits to be passed to the underlying gRPC library.
@@ -248,12 +270,38 @@ impl Default for SizeLimits {
248270
}
249271
}
250272

273+
#[derive(Clone, Debug)]
274+
enum MetadataSupplier {
275+
Reusable(Arc<FlightMetadata>),
276+
Refresh {
277+
driver: Arc<dyn FlightDriver>,
278+
channel: Channel,
279+
options: HashMap<String, String>,
280+
},
281+
}
282+
283+
impl MetadataSupplier {
284+
async fn flight_metadata(&self) -> datafusion::common::Result<Arc<FlightMetadata>> {
285+
match self {
286+
Self::Reusable(metadata) => Ok(metadata.clone()),
287+
Self::Refresh {
288+
driver,
289+
channel,
290+
options,
291+
} => Ok(Arc::new(
292+
driver
293+
.metadata(channel.clone(), options)
294+
.await
295+
.map_err(to_df_err)?,
296+
)),
297+
}
298+
}
299+
}
300+
251301
/// Table provider that wraps a specific flight from an Arrow Flight service
252302
#[derive(Debug)]
253303
pub struct FlightTable {
254-
driver: Arc<dyn FlightDriver>,
255-
channel: Channel,
256-
options: HashMap<String, String>,
304+
metadata_supplier: MetadataSupplier,
257305
origin: String,
258306
logical_schema: SchemaRef,
259307
stats: Statistics,
@@ -280,13 +328,9 @@ impl TableProvider for FlightTable {
280328
_filters: &[Expr],
281329
_limit: Option<usize>,
282330
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
283-
let metadata = self
284-
.driver
285-
.metadata(self.channel.clone(), &self.options)
286-
.await
287-
.map_err(to_df_err)?;
331+
let metadata = self.metadata_supplier.flight_metadata().await?;
288332
Ok(Arc::new(FlightExec::try_new(
289-
metadata,
333+
metadata.as_ref(),
290334
projection,
291335
&self.origin,
292336
)?))

src/flight/exec.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl FlightExec {
5555
/// Creates a FlightExec with the provided [FlightMetadata]
5656
/// and origin URL (used as fallback location as per the protocol spec).
5757
pub fn try_new(
58-
metadata: FlightMetadata,
58+
metadata: &FlightMetadata,
5959
projection: Option<&Vec<usize>>,
6060
origin: &str,
6161
) -> Result<Self> {
@@ -70,7 +70,7 @@ impl FlightExec {
7070
origin: origin.into(),
7171
schema,
7272
partitions,
73-
properties: metadata.props,
73+
properties: metadata.props.clone(),
7474
};
7575
Ok(config.into())
7676
}
@@ -82,7 +82,7 @@ impl FlightExec {
8282

8383
impl From<FlightConfig> for FlightExec {
8484
fn from(config: FlightConfig) -> Self {
85-
let exec_mode = if config.properties.unbounded_stream {
85+
let exec_mode = if config.properties.unbounded_streams {
8686
ExecutionMode::Unbounded
8787
} else {
8888
ExecutionMode::Bounded
@@ -347,12 +347,12 @@ mod tests {
347347
]
348348
.into();
349349
let properties = FlightProperties::default()
350-
.unbounded_stream(true)
351-
.grpc_headers(HashMap::from([
350+
.with_unbounded_streams(true)
351+
.with_grpc_headers(HashMap::from([
352352
("h1".into(), "v1".into()),
353353
("h2".into(), "v2".into()),
354354
]))
355-
.size_limits(SizeLimits::new(1024, 1024));
355+
.with_size_limits(SizeLimits::new(1024, 1024));
356356
let config = FlightConfig {
357357
origin: "http://localhost:50050".into(),
358358
schema,

src/flight/sql.rs

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,31 @@ pub const HEADER_PREFIX: &str = "flight.sql.header.";
4040
/// stored as a gRPC authorization header within the returned [FlightMetadata],
4141
/// to be sent with the subsequent `DoGet` requests.
4242
#[derive(Clone, Debug, Default)]
43-
pub struct FlightSqlDriver {}
43+
pub struct FlightSqlDriver {
44+
properties_template: FlightProperties,
45+
persistent_headers: bool,
46+
}
47+
48+
impl FlightSqlDriver {
49+
pub fn new() -> Self {
50+
Default::default()
51+
}
52+
53+
/// Custom flight properties to be returned from the metadata call instead of the default ones.
54+
/// The headers (if any) will only be used for the Handshake/GetFlightInfo calls by default.
55+
/// This behaviour can be changed by calling [Self::with_persistent_headers] below.
56+
/// Headers provided as options for the metadata call will overwrite the template ones.
57+
pub fn with_properties_template(mut self, properties_template: FlightProperties) -> Self {
58+
self.properties_template = properties_template;
59+
self
60+
}
61+
62+
/// Propagate the static headers configured for Handshake/GetFlightInfo to the subsequent DoGet calls.
63+
pub fn with_persistent_headers(mut self, persistent_headers: bool) -> Self {
64+
self.persistent_headers = persistent_headers;
65+
self
66+
}
67+
}
4468

4569
#[async_trait]
4670
impl FlightDriver for FlightSqlDriver {
@@ -50,11 +74,13 @@ impl FlightDriver for FlightSqlDriver {
5074
options: &HashMap<String, String>,
5175
) -> Result<FlightMetadata> {
5276
let mut client = FlightSqlServiceClient::new(channel);
53-
let headers = options.iter().filter_map(|(key, value)| {
77+
let mut handshake_headers = self.properties_template.grpc_headers.clone();
78+
let headers_overlay = options.into_iter().filter_map(|(key, value)| {
5479
key.strip_prefix(HEADER_PREFIX)
55-
.map(|header_name| (header_name, value))
80+
.map(|header_name| (header_name.to_owned(), value.to_owned()))
5681
});
57-
for (name, value) in headers {
82+
handshake_headers.extend(headers_overlay);
83+
for (name, value) in &handshake_headers {
5884
client.set_header(name, value)
5985
}
6086
if let Some(username) = options.get(USERNAME) {
@@ -63,10 +89,18 @@ impl FlightDriver for FlightSqlDriver {
6389
client.handshake(username, password).await.ok();
6490
}
6591
let info = client.execute(options[QUERY].clone(), None).await?;
66-
let mut grpc_headers = HashMap::default();
92+
let mut partition_headers = if self.persistent_headers {
93+
handshake_headers
94+
} else {
95+
HashMap::default()
96+
};
6797
if let Some(token) = client.token() {
68-
grpc_headers.insert("authorization".into(), format!("Bearer {}", token));
98+
partition_headers.insert("authorization".into(), format!("Bearer {token}"));
6999
}
70-
FlightMetadata::try_new(info, FlightProperties::default().grpc_headers(grpc_headers))
100+
let props = self
101+
.properties_template
102+
.clone()
103+
.with_grpc_headers(partition_headers);
104+
FlightMetadata::try_new(info, props)
71105
}
72106
}

tests/flight/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use tonic::transport::Server;
2828
use tonic::{Extensions, Request, Response, Status, Streaming};
2929

3030
use datafusion_table_providers::flight::sql::FlightSqlDriver;
31-
use datafusion_table_providers::flight::FlightTableFactory;
31+
use datafusion_table_providers::flight::{FlightProperties, FlightTableFactory};
3232

3333
const AUTH_HEADER: &str = "authorization";
3434
const BEARER_TOKEN: &str = "Bearer flight-sql-token";
@@ -191,11 +191,11 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
191191
};
192192
let port = service.run_in_background(rx).await.port();
193193
let ctx = SessionContext::new();
194+
let props_template = FlightProperties::new().with_reusable_flight_info(true);
195+
let driver = FlightSqlDriver::new().with_properties_template(props_template);
194196
ctx.state_ref().write().table_factories_mut().insert(
195197
"FLIGHT_SQL".into(),
196-
Arc::new(FlightTableFactory::new(
197-
Arc::new(FlightSqlDriver::default()),
198-
)),
198+
Arc::new(FlightTableFactory::new(Arc::new(driver))),
199199
);
200200
let _ = ctx
201201
.sql(&format!(

0 commit comments

Comments
 (0)