Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ datafusion-expr = { version = "42.0.0", optional = true }
datafusion-physical-expr = { version = "42.0.0", optional = true }
datafusion-physical-plan = { version = "42.0.0", optional = true }
datafusion-proto = { version = "42.0.0", optional = true }
datafusion-federation = { version = "0.3.0", features = ["sql"], optional = true }
datafusion-federation = { version = "=0.3.0", features = ["sql"], optional = true }
duckdb = { version = "1.1.1", features = [
"bundled",
"r2d2",
Expand Down
2 changes: 1 addition & 1 deletion examples/flight-sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use std::sync::Arc;
#[tokio::main]
async fn main() -> datafusion::common::Result<()> {
let ctx = SessionContext::new();
let flight_sql = FlightTableFactory::new(Arc::new(FlightSqlDriver::default()));
let flight_sql = FlightTableFactory::new(Arc::new(FlightSqlDriver::new()));
let table = flight_sql
.open_table(
"http://localhost:32010",
Expand Down
80 changes: 62 additions & 18 deletions src/flight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,23 @@ impl FlightTableFactory {
.map_err(to_df_err)?;
let num_rows = precision(metadata.info.total_records);
let total_byte_size = precision(metadata.info.total_bytes);
let logical_schema = metadata.schema;
let logical_schema = metadata.schema.clone();
let stats = Statistics {
num_rows,
total_byte_size,
column_statistics: vec![],
};
let metadata_supplier = if metadata.props.reusable_flight_info {
MetadataSupplier::Reusable(Arc::new(metadata))
} else {
MetadataSupplier::Refresh {
driver: self.driver.clone(),
channel,
options,
}
};
Ok(FlightTable {
driver: self.driver.clone(),
channel,
options,
metadata_supplier,
origin,
logical_schema,
stats,
Expand Down Expand Up @@ -203,26 +210,41 @@ impl TryFrom<FlightInfo> for FlightMetadata {
/// for controlling the protocol and query execution details.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub struct FlightProperties {
unbounded_stream: bool,
unbounded_streams: bool,
grpc_headers: HashMap<String, String>,
size_limits: SizeLimits,
reusable_flight_info: bool,
}

impl FlightProperties {
pub fn unbounded_stream(mut self, unbounded_stream: bool) -> Self {
self.unbounded_stream = unbounded_stream;
pub fn new() -> Self {
Default::default()
}

/// Whether the service will produce infinite streams
pub fn with_unbounded_streams(mut self, unbounded_streams: bool) -> Self {
self.unbounded_streams = unbounded_streams;
self
}

pub fn grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
/// gRPC headers to use on subsequent calls.
pub fn with_grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
self.grpc_headers = grpc_headers;
self
}

pub fn size_limits(mut self, size_limits: SizeLimits) -> Self {
/// Max sizes in bytes for encoded/decoded gRPC messages.
pub fn with_size_limits(mut self, size_limits: SizeLimits) -> Self {
self.size_limits = size_limits;
self
}

/// Whether the FlightInfo objects produced by the service can be used multiple times
/// or need to be refreshed before every table scan.
pub fn with_reusable_flight_info(mut self, reusable_flight_info: bool) -> Self {
self.reusable_flight_info = reusable_flight_info;
self
}
}

/// Message size limits to be passed to the underlying gRPC library.
Expand All @@ -248,12 +270,38 @@ impl Default for SizeLimits {
}
}

#[derive(Clone, Debug)]
enum MetadataSupplier {
Reusable(Arc<FlightMetadata>),
Refresh {
driver: Arc<dyn FlightDriver>,
channel: Channel,
options: HashMap<String, String>,
},
}

impl MetadataSupplier {
async fn flight_metadata(&self) -> datafusion::common::Result<Arc<FlightMetadata>> {
match self {
Self::Reusable(metadata) => Ok(metadata.clone()),
Self::Refresh {
driver,
channel,
options,
} => Ok(Arc::new(
driver
.metadata(channel.clone(), options)
.await
.map_err(to_df_err)?,
)),
}
}
}

/// Table provider that wraps a specific flight from an Arrow Flight service
#[derive(Debug)]
pub struct FlightTable {
driver: Arc<dyn FlightDriver>,
channel: Channel,
options: HashMap<String, String>,
metadata_supplier: MetadataSupplier,
origin: String,
logical_schema: SchemaRef,
stats: Statistics,
Expand All @@ -280,13 +328,9 @@ impl TableProvider for FlightTable {
_filters: &[Expr],
_limit: Option<usize>,
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
let metadata = self
.driver
.metadata(self.channel.clone(), &self.options)
.await
.map_err(to_df_err)?;
let metadata = self.metadata_supplier.flight_metadata().await?;
Ok(Arc::new(FlightExec::try_new(
metadata,
metadata.as_ref(),
projection,
&self.origin,
)?))
Expand Down
12 changes: 6 additions & 6 deletions src/flight/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl FlightExec {
/// Creates a FlightExec with the provided [FlightMetadata]
/// and origin URL (used as fallback location as per the protocol spec).
pub fn try_new(
metadata: FlightMetadata,
metadata: &FlightMetadata,
projection: Option<&Vec<usize>>,
origin: &str,
) -> Result<Self> {
Expand All @@ -70,7 +70,7 @@ impl FlightExec {
origin: origin.into(),
schema,
partitions,
properties: metadata.props,
properties: metadata.props.clone(),
};
Ok(config.into())
}
Expand All @@ -82,7 +82,7 @@ impl FlightExec {

impl From<FlightConfig> for FlightExec {
fn from(config: FlightConfig) -> Self {
let exec_mode = if config.properties.unbounded_stream {
let exec_mode = if config.properties.unbounded_streams {
ExecutionMode::Unbounded
} else {
ExecutionMode::Bounded
Expand Down Expand Up @@ -347,12 +347,12 @@ mod tests {
]
.into();
let properties = FlightProperties::default()
.unbounded_stream(true)
.grpc_headers(HashMap::from([
.with_unbounded_streams(true)
.with_grpc_headers(HashMap::from([
("h1".into(), "v1".into()),
("h2".into(), "v2".into()),
]))
.size_limits(SizeLimits::new(1024, 1024));
.with_size_limits(SizeLimits::new(1024, 1024));
let config = FlightConfig {
origin: "http://localhost:50050".into(),
schema,
Expand Down
48 changes: 41 additions & 7 deletions src/flight/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,31 @@ pub const HEADER_PREFIX: &str = "flight.sql.header.";
/// stored as a gRPC authorization header within the returned [FlightMetadata],
/// to be sent with the subsequent `DoGet` requests.
#[derive(Clone, Debug, Default)]
pub struct FlightSqlDriver {}
pub struct FlightSqlDriver {
properties_template: FlightProperties,
persistent_headers: bool,
}

impl FlightSqlDriver {
pub fn new() -> Self {
Default::default()
}

/// Custom flight properties to be returned from the metadata call instead of the default ones.
/// The headers (if any) will only be used for the Handshake/GetFlightInfo calls by default.
/// This behaviour can be changed by calling [Self::with_persistent_headers] below.
/// Headers provided as options for the metadata call will overwrite the template ones.
pub fn with_properties_template(mut self, properties_template: FlightProperties) -> Self {
self.properties_template = properties_template;
self
}

/// Propagate the static headers configured for Handshake/GetFlightInfo to the subsequent DoGet calls.
pub fn with_persistent_headers(mut self, persistent_headers: bool) -> Self {
self.persistent_headers = persistent_headers;
self
}
}

#[async_trait]
impl FlightDriver for FlightSqlDriver {
Expand All @@ -50,11 +74,13 @@ impl FlightDriver for FlightSqlDriver {
options: &HashMap<String, String>,
) -> Result<FlightMetadata> {
let mut client = FlightSqlServiceClient::new(channel);
let headers = options.iter().filter_map(|(key, value)| {
let mut handshake_headers = self.properties_template.grpc_headers.clone();
let headers_overlay = options.iter().filter_map(|(key, value)| {
key.strip_prefix(HEADER_PREFIX)
.map(|header_name| (header_name, value))
.map(|header_name| (header_name.to_owned(), value.to_owned()))
});
for (name, value) in headers {
handshake_headers.extend(headers_overlay);
for (name, value) in &handshake_headers {
client.set_header(name, value)
}
if let Some(username) = options.get(USERNAME) {
Expand All @@ -63,10 +89,18 @@ impl FlightDriver for FlightSqlDriver {
client.handshake(username, password).await.ok();
}
let info = client.execute(options[QUERY].clone(), None).await?;
let mut grpc_headers = HashMap::default();
let mut partition_headers = if self.persistent_headers {
handshake_headers
} else {
HashMap::default()
};
if let Some(token) = client.token() {
grpc_headers.insert("authorization".into(), format!("Bearer {}", token));
partition_headers.insert("authorization".into(), format!("Bearer {token}"));
}
FlightMetadata::try_new(info, FlightProperties::default().grpc_headers(grpc_headers))
let props = self
.properties_template
.clone()
.with_grpc_headers(partition_headers);
FlightMetadata::try_new(info, props)
}
}
8 changes: 4 additions & 4 deletions tests/flight/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use tonic::transport::Server;
use tonic::{Extensions, Request, Response, Status, Streaming};

use datafusion_table_providers::flight::sql::FlightSqlDriver;
use datafusion_table_providers::flight::FlightTableFactory;
use datafusion_table_providers::flight::{FlightProperties, FlightTableFactory};

const AUTH_HEADER: &str = "authorization";
const BEARER_TOKEN: &str = "Bearer flight-sql-token";
Expand Down Expand Up @@ -191,11 +191,11 @@ async fn test_flight_sql_data_source() -> datafusion::common::Result<()> {
};
let port = service.run_in_background(rx).await.port();
let ctx = SessionContext::new();
let props_template = FlightProperties::new().with_reusable_flight_info(true);
let driver = FlightSqlDriver::new().with_properties_template(props_template);
ctx.state_ref().write().table_factories_mut().insert(
"FLIGHT_SQL".into(),
Arc::new(FlightTableFactory::new(
Arc::new(FlightSqlDriver::default()),
)),
Arc::new(FlightTableFactory::new(Arc::new(driver))),
);
let _ = ctx
.sql(&format!(
Expand Down