Skip to content

Commit 9332074

Browse files
authored
Merge branch 'main' into rebase-spiceai-main-branches
2 parents 2cf7ade + 55cfc67 commit 9332074

File tree

2 files changed: 6 additions and 159 deletions.

Cargo.toml

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "datafusion-table-providers"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
readme = "README.md"
55
edition = "2021"
66
repository = "https://github.com/datafusion-contrib/datafusion-table-providers"
@@ -17,8 +17,6 @@ arrow-json = "53"
1717
async-stream = { version = "0.3.5", optional = true }
1818
async-trait = "0.1.80"
1919
num-bigint = "0.4.4"
20-
base64 = { version = "0.22.1", optional = true }
21-
bytes = { version = "1.7.1", optional = true }
2220
bigdecimal = "0.4.5"
2321
byteorder = "1.5.0"
2422
chrono = "0.4.38"
@@ -87,7 +85,8 @@ test-log = { version = "0.2.16", features = ["trace"] }
8785
rstest = "0.22.0"
8886
geozero = { version = "0.13.0", features = ["with-wkb"] }
8987
tokio-stream = { version = "0.1.15", features = ["net"] }
90-
arrow-schema = "52.2.0"
88+
arrow-schema = "53.1.0"
89+
prost = { version = "0.13"}
9190

9291
[features]
9392
mysql = ["dep:mysql_async", "dep:async-stream"]
@@ -99,13 +98,10 @@ flight = [
9998
"dep:arrow-cast",
10099
"dep:arrow-flight",
101100
"dep:arrow-schema",
102-
"dep:base64",
103-
"dep:bytes",
104101
"dep:datafusion-expr",
105102
"dep:datafusion-physical-expr",
106103
"dep:datafusion-physical-plan",
107104
"dep:datafusion-proto",
108-
"dep:prost",
109105
"dep:serde",
110106
"dep:tonic",
111107
]

src/flight/sql.rs

Lines changed: 3 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,11 @@
1818
//! Default [FlightDriver] for Flight SQL
1919
2020
use std::collections::HashMap;
21-
use std::str::FromStr;
2221

2322
use arrow_flight::error::Result;
24-
use arrow_flight::flight_service_client::FlightServiceClient;
25-
use arrow_flight::sql::{CommandStatementQuery, ProstMessageExt};
26-
use arrow_flight::{FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse};
27-
use arrow_schema::ArrowError;
23+
use arrow_flight::sql::client::FlightSqlServiceClient;
2824
use async_trait::async_trait;
29-
use base64::prelude::BASE64_STANDARD;
30-
use base64::Engine;
31-
use bytes::Bytes;
32-
use futures::{stream, TryStreamExt};
33-
use prost::Message;
34-
use tonic::metadata::AsciiMetadataKey;
3525
use tonic::transport::Channel;
36-
use tonic::IntoRequest;
3726

3827
use crate::flight::{FlightDriver, FlightMetadata};
3928

@@ -60,7 +49,7 @@ impl FlightDriver for FlightSqlDriver {
6049
channel: Channel,
6150
options: &HashMap<String, String>,
6251
) -> Result<FlightMetadata> {
63-
let mut client = FlightSqlClient::new(channel);
52+
let mut client = FlightSqlServiceClient::new(channel);
6453
let headers = options.iter().filter_map(|(key, value)| {
6554
key.strip_prefix(HEADER_PREFIX)
6655
.map(|header_name| (header_name, value))
@@ -75,147 +64,9 @@ impl FlightDriver for FlightSqlDriver {
7564
}
7665
let info = client.execute(options[QUERY].clone(), None).await?;
7766
let mut grpc_headers = HashMap::default();
78-
if let Some(token) = client.token {
67+
if let Some(token) = client.token() {
7968
grpc_headers.insert("authorization".into(), format!("Bearer {}", token));
8069
}
8170
FlightMetadata::try_new(info, grpc_headers)
8271
}
8372
}
84-
85-
/////////////////////////////////////////////////////////////////////////
86-
// Shameless copy/paste from arrow-flight FlightSqlServiceClient
87-
// (only cherry-picked the functionality that we actually use).
88-
// This is only needed in order to access the bearer token received
89-
// during handshake, as the standard client does not expose this information.
90-
// The bearer token has to be passed to the clients that perform
91-
// the DoGet operation, since Dremio, Ballista and possibly others
92-
// expect the bearer token they produce with the handshake response
93-
// to be set on all subsequent requests, including DoGet.
94-
//
95-
// TODO: remove this and switch to the official client once
96-
// https://github.com/apache/arrow-rs/pull/6254 is released,
97-
// and remove a bunch of cargo dependencies, like base64 or bytes
98-
#[derive(Debug, Clone)]
99-
struct FlightSqlClient {
100-
token: Option<String>,
101-
headers: HashMap<String, String>,
102-
flight_client: FlightServiceClient<Channel>,
103-
}
104-
105-
impl FlightSqlClient {
106-
/// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
107-
fn new(channel: Channel) -> Self {
108-
Self {
109-
token: None,
110-
flight_client: FlightServiceClient::new(channel),
111-
headers: HashMap::default(),
112-
}
113-
}
114-
115-
/// Perform a `handshake` with the server, passing credentials and establishing a session.
116-
///
117-
/// If the server returns an "authorization" header, it is automatically parsed and set as
118-
/// a token for future requests. Any other data returned by the server in the handshake
119-
/// response is returned as a binary blob.
120-
async fn handshake(
121-
&mut self,
122-
username: &str,
123-
password: &str,
124-
) -> std::result::Result<Bytes, ArrowError> {
125-
let cmd = HandshakeRequest {
126-
protocol_version: 0,
127-
payload: Default::default(),
128-
};
129-
let mut req = tonic::Request::new(stream::iter(vec![cmd]));
130-
let val = BASE64_STANDARD.encode(format!("{username}:{password}"));
131-
let val = format!("Basic {val}")
132-
.parse()
133-
.map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
134-
req.metadata_mut().insert("authorization", val);
135-
let req = self.set_request_headers(req)?;
136-
let resp = self
137-
.flight_client
138-
.handshake(req)
139-
.await
140-
.map_err(|e| ArrowError::IpcError(format!("Can't handshake {e}")))?;
141-
if let Some(auth) = resp.metadata().get("authorization") {
142-
let auth = auth
143-
.to_str()
144-
.map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
145-
let bearer = "Bearer ";
146-
if !auth.starts_with(bearer) {
147-
Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
148-
}
149-
let auth = auth[bearer.len()..].to_string();
150-
self.token = Some(auth);
151-
}
152-
let responses: Vec<HandshakeResponse> = resp
153-
.into_inner()
154-
.try_collect()
155-
.await
156-
.map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;
157-
let resp = match responses.as_slice() {
158-
[resp] => resp.payload.clone(),
159-
[] => Bytes::new(),
160-
_ => Err(ArrowError::ParseError(
161-
"Multiple handshake responses".to_string(),
162-
))?,
163-
};
164-
Ok(resp)
165-
}
166-
167-
async fn execute(
168-
&mut self,
169-
query: String,
170-
transaction_id: Option<Bytes>,
171-
) -> std::result::Result<FlightInfo, ArrowError> {
172-
let cmd = CommandStatementQuery {
173-
query,
174-
transaction_id,
175-
};
176-
self.get_flight_info_for_command(cmd).await
177-
}
178-
179-
async fn get_flight_info_for_command<M: ProstMessageExt>(
180-
&mut self,
181-
cmd: M,
182-
) -> std::result::Result<FlightInfo, ArrowError> {
183-
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
184-
let req = self.set_request_headers(descriptor.into_request())?;
185-
let fi = self
186-
.flight_client
187-
.get_flight_info(req)
188-
.await
189-
.map_err(|status| ArrowError::IpcError(format!("{status:?}")))?
190-
.into_inner();
191-
Ok(fi)
192-
}
193-
194-
fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
195-
let key: String = key.into();
196-
let value: String = value.into();
197-
self.headers.insert(key, value);
198-
}
199-
200-
fn set_request_headers<T>(
201-
&self,
202-
mut req: tonic::Request<T>,
203-
) -> std::result::Result<tonic::Request<T>, ArrowError> {
204-
for (k, v) in &self.headers {
205-
let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| {
206-
ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}"))
207-
})?;
208-
let v = v.parse().map_err(|e| {
209-
ArrowError::ParseError(format!("Cannot convert header value \"{v}\": {e}"))
210-
})?;
211-
req.metadata_mut().insert(k, v);
212-
}
213-
if let Some(token) = &self.token {
214-
let val = format!("Bearer {token}").parse().map_err(|e| {
215-
ArrowError::ParseError(format!("Cannot convert token to header value: {e}"))
216-
})?;
217-
req.metadata_mut().insert("authorization", val);
218-
}
219-
Ok(req)
220-
}
221-
}

0 commit comments

Comments
 (0)