Skip to content

Commit a3329ca

Browse files
committed
[flight] Enable TLS for tonic 0.12
1 parent 9975212 commit a3329ca

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ trust-dns-resolver = "0.23.2"
6868
url = "2.5.1"
6969
pem = { version = "3.0.4", optional = true }
7070
tokio-rusqlite = { version = "0.5.1", optional = true }
71-
tonic = { version = "0.12.2", optional = true }
71+
tonic = { version = "0.12", optional = true, default-features = true, features = ["tls-native-roots", "tls-webpki-roots"] }
7272
itertools = "0.13.0"
7373
dyn-clone = { version = "1.0.17", optional = true }
7474
geo-types = "0.7.13"

src/flight.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
use std::any::Any;
2222
use std::collections::HashMap;
23+
use std::error::Error;
2324
use std::fmt::Debug;
2425
use std::sync::Arc;
2526

@@ -35,7 +36,7 @@ use datafusion::datasource::TableProvider;
3536
use datafusion::physical_plan::ExecutionPlan;
3637
use datafusion_expr::{CreateExternalTable, Expr, TableType};
3738
use serde::{Deserialize, Serialize};
38-
use tonic::transport::Channel;
39+
use tonic::transport::{Channel, ClientTlsConfig};
3940

4041
pub mod codec;
4142
mod exec;
@@ -107,16 +108,12 @@ impl FlightTableFactory {
107108
options: HashMap<String, String>,
108109
) -> datafusion::common::Result<FlightTable> {
109110
let origin = entry_point.into();
110-
let channel = Channel::from_shared(origin.clone())
111-
.unwrap()
112-
.connect()
113-
.await
114-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
111+
let channel = flight_channel(&origin).await?;
115112
let metadata = self
116113
.driver
117114
.metadata(channel.clone(), &options)
118115
.await
119-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
116+
.map_err(to_df_err)?;
120117
let num_rows = precision(metadata.info.total_records);
121118
let total_byte_size = precision(metadata.info.total_bytes);
122119
let logical_schema = metadata.schema;
@@ -136,14 +133,6 @@ impl FlightTableFactory {
136133
}
137134
}
138135

139-
fn precision(total: i64) -> Precision<usize> {
140-
if total < 0 {
141-
Precision::Absent
142-
} else {
143-
Precision::Exact(total as usize)
144-
}
145-
}
146-
147136
#[async_trait]
148137
impl TableProviderFactory for FlightTableFactory {
149138
async fn create(
@@ -292,7 +281,7 @@ impl TableProvider for FlightTable {
292281
.driver
293282
.metadata(self.channel.clone(), &self.options)
294283
.await
295-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
284+
.map_err(to_df_err)?;
296285
Ok(Arc::new(FlightExec::try_new(
297286
metadata,
298287
projection,
@@ -304,3 +293,26 @@ impl TableProvider for FlightTable {
304293
Some(self.stats.clone())
305294
}
306295
}
296+
297+
fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
298+
DataFusionError::External(Box::new(err))
299+
}
300+
301+
async fn flight_channel(source: impl Into<String>) -> datafusion::common::Result<Channel> {
302+
let tls_config = ClientTlsConfig::new().with_enabled_roots();
303+
Channel::from_shared(source.into())
304+
.map_err(to_df_err)?
305+
.tls_config(tls_config)
306+
.map_err(to_df_err)?
307+
.connect()
308+
.await
309+
.map_err(to_df_err)
310+
}
311+
312+
fn precision(total: i64) -> Precision<usize> {
313+
if total < 0 {
314+
Precision::Absent
315+
} else {
316+
Precision::Exact(total as usize)
317+
}
318+
}

src/flight/codec.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
use std::sync::Arc;
2121

2222
use crate::flight::exec::{FlightConfig, FlightExec};
23+
use crate::flight::to_df_err;
2324
use datafusion::common::DataFusionError;
2425
use datafusion_expr::registry::FunctionRegistry;
2526
use datafusion_physical_plan::ExecutionPlan;
@@ -37,8 +38,7 @@ impl PhysicalExtensionCodec for FlightPhysicalCodec {
3738
_registry: &dyn FunctionRegistry,
3839
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
3940
if inputs.is_empty() {
40-
let config: FlightConfig =
41-
serde_json::from_slice(buf).map_err(|e| DataFusionError::External(Box::new(e)))?;
41+
let config: FlightConfig = serde_json::from_slice(buf).map_err(to_df_err)?;
4242
Ok(Arc::from(FlightExec::from(config)))
4343
} else {
4444
Err(DataFusionError::Internal(
@@ -53,8 +53,7 @@ impl PhysicalExtensionCodec for FlightPhysicalCodec {
5353
buf: &mut Vec<u8>,
5454
) -> datafusion::common::Result<()> {
5555
if let Some(flight) = node.as_any().downcast_ref::<FlightExec>() {
56-
let mut bytes = serde_json::to_vec(flight.config())
57-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
56+
let mut bytes = serde_json::to_vec(flight.config()).map_err(to_df_err)?;
5857
buf.append(&mut bytes);
5958
Ok(())
6059
} else {

src/flight/exec.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use std::fmt::{Debug, Formatter};
2323
use std::str::FromStr;
2424
use std::sync::Arc;
2525

26-
use crate::flight::{FlightMetadata, FlightProperties, SizeLimits};
26+
use crate::flight::{flight_channel, to_df_err, FlightMetadata, FlightProperties, SizeLimits};
2727
use arrow_array::RecordBatch;
2828
use arrow_flight::error::FlightError;
2929
use arrow_flight::flight_service_client::FlightServiceClient;
@@ -41,7 +41,6 @@ use datafusion_physical_plan::{
4141
use futures::{StreamExt, TryStreamExt};
4242
use serde::{Deserialize, Serialize};
4343
use tonic::metadata::{AsciiMetadataKey, MetadataMap};
44-
use tonic::transport::Channel;
4544

4645
/// Arrow Flight physical plan that maps flight endpoints to partitions
4746
#[derive(Clone, Debug)]
@@ -170,11 +169,7 @@ async fn flight_client(
170169
grpc_headers: &MetadataMap,
171170
size_limits: &SizeLimits,
172171
) -> Result<FlightClient> {
173-
let channel = Channel::from_shared(source.into())
174-
.map_err(|e| DataFusionError::External(Box::new(e)))?
175-
.connect()
176-
.await
177-
.map_err(|e| DataFusionError::External(Box::new(e)))?;
172+
let channel = flight_channel(source).await?;
178173
let inner_client = FlightServiceClient::new(channel)
179174
.max_encoding_message_size(size_limits.encoding)
180175
.max_decoding_message_size(size_limits.decoding);
@@ -212,10 +207,7 @@ async fn try_fetch_stream(
212207
schema: SchemaRef,
213208
) -> arrow_flight::error::Result<SendableRecordBatchStream> {
214209
let ticket = Ticket::new(ticket.0.to_vec());
215-
let stream = client
216-
.do_get(ticket)
217-
.await?
218-
.map_err(|e| DataFusionError::External(Box::new(e)));
210+
let stream = client.do_get(ticket).await?.map_err(to_df_err);
219211
Ok(Box::pin(RecordBatchStreamAdapter::new(
220212
schema.clone(),
221213
stream.map(move |item| item.and_then(|rb| enforce_schema(rb, &schema).map_err(Into::into))),

0 commit comments

Comments
 (0)