1818//! Default [FlightDriver] for Flight SQL
1919
2020use std:: collections:: HashMap ;
21- use std:: str:: FromStr ;
2221
2322use arrow_flight:: error:: Result ;
24- use arrow_flight:: flight_service_client:: FlightServiceClient ;
25- use arrow_flight:: sql:: { CommandStatementQuery , ProstMessageExt } ;
26- use arrow_flight:: { FlightDescriptor , FlightInfo , HandshakeRequest , HandshakeResponse } ;
27- use arrow_schema:: ArrowError ;
23+ use arrow_flight:: sql:: client:: FlightSqlServiceClient ;
2824use async_trait:: async_trait;
29- use base64:: prelude:: BASE64_STANDARD ;
30- use base64:: Engine ;
31- use bytes:: Bytes ;
32- use futures:: { stream, TryStreamExt } ;
33- use prost:: Message ;
34- use tonic:: metadata:: AsciiMetadataKey ;
3525use tonic:: transport:: Channel ;
36- use tonic:: IntoRequest ;
3726
3827use crate :: flight:: { FlightDriver , FlightMetadata } ;
3928
@@ -60,7 +49,7 @@ impl FlightDriver for FlightSqlDriver {
6049 channel : Channel ,
6150 options : & HashMap < String , String > ,
6251 ) -> Result < FlightMetadata > {
63- let mut client = FlightSqlClient :: new ( channel) ;
52+ let mut client = FlightSqlServiceClient :: new ( channel) ;
6453 let headers = options. iter ( ) . filter_map ( |( key, value) | {
6554 key. strip_prefix ( HEADER_PREFIX )
6655 . map ( |header_name| ( header_name, value) )
@@ -75,147 +64,9 @@ impl FlightDriver for FlightSqlDriver {
7564 }
7665 let info = client. execute ( options[ QUERY ] . clone ( ) , None ) . await ?;
7766 let mut grpc_headers = HashMap :: default ( ) ;
78- if let Some ( token) = client. token {
67+ if let Some ( token) = client. token ( ) {
7968 grpc_headers. insert ( "authorization" . into ( ) , format ! ( "Bearer {}" , token) ) ;
8069 }
8170 FlightMetadata :: try_new ( info, grpc_headers)
8271 }
8372}
84-
85- /////////////////////////////////////////////////////////////////////////
86- // Shameless copy/paste from arrow-flight FlightSqlServiceClient
87- // (only cherry-picked the functionality that we actually use).
88- // This is only needed in order to access the bearer token received
89- // during handshake, as the standard client does not expose this information.
90- // The bearer token has to be passed to the clients that perform
91- // the DoGet operation, since Dremio, Ballista and possibly others
92- // expect the bearer token they produce with the handshake response
93- // to be set on all subsequent requests, including DoGet.
94- //
95- // TODO: remove this and switch to the official client once
96- // https://github.com/apache/arrow-rs/pull/6254 is released,
97- // and remove a bunch of cargo dependencies, like base64 or bytes
98- #[ derive( Debug , Clone ) ]
99- struct FlightSqlClient {
100- token : Option < String > ,
101- headers : HashMap < String , String > ,
102- flight_client : FlightServiceClient < Channel > ,
103- }
104-
105- impl FlightSqlClient {
106- /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
107- fn new ( channel : Channel ) -> Self {
108- Self {
109- token : None ,
110- flight_client : FlightServiceClient :: new ( channel) ,
111- headers : HashMap :: default ( ) ,
112- }
113- }
114-
115- /// Perform a `handshake` with the server, passing credentials and establishing a session.
116- ///
117- /// If the server returns an "authorization" header, it is automatically parsed and set as
118- /// a token for future requests. Any other data returned by the server in the handshake
119- /// response is returned as a binary blob.
120- async fn handshake (
121- & mut self ,
122- username : & str ,
123- password : & str ,
124- ) -> std:: result:: Result < Bytes , ArrowError > {
125- let cmd = HandshakeRequest {
126- protocol_version : 0 ,
127- payload : Default :: default ( ) ,
128- } ;
129- let mut req = tonic:: Request :: new ( stream:: iter ( vec ! [ cmd] ) ) ;
130- let val = BASE64_STANDARD . encode ( format ! ( "{username}:{password}" ) ) ;
131- let val = format ! ( "Basic {val}" )
132- . parse ( )
133- . map_err ( |_| ArrowError :: ParseError ( "Cannot parse header" . to_string ( ) ) ) ?;
134- req. metadata_mut ( ) . insert ( "authorization" , val) ;
135- let req = self . set_request_headers ( req) ?;
136- let resp = self
137- . flight_client
138- . handshake ( req)
139- . await
140- . map_err ( |e| ArrowError :: IpcError ( format ! ( "Can't handshake {e}" ) ) ) ?;
141- if let Some ( auth) = resp. metadata ( ) . get ( "authorization" ) {
142- let auth = auth
143- . to_str ( )
144- . map_err ( |_| ArrowError :: ParseError ( "Can't read auth header" . to_string ( ) ) ) ?;
145- let bearer = "Bearer " ;
146- if !auth. starts_with ( bearer) {
147- Err ( ArrowError :: ParseError ( "Invalid auth header!" . to_string ( ) ) ) ?;
148- }
149- let auth = auth[ bearer. len ( ) ..] . to_string ( ) ;
150- self . token = Some ( auth) ;
151- }
152- let responses: Vec < HandshakeResponse > = resp
153- . into_inner ( )
154- . try_collect ( )
155- . await
156- . map_err ( |_| ArrowError :: ParseError ( "Can't collect responses" . to_string ( ) ) ) ?;
157- let resp = match responses. as_slice ( ) {
158- [ resp] => resp. payload . clone ( ) ,
159- [ ] => Bytes :: new ( ) ,
160- _ => Err ( ArrowError :: ParseError (
161- "Multiple handshake responses" . to_string ( ) ,
162- ) ) ?,
163- } ;
164- Ok ( resp)
165- }
166-
167- async fn execute (
168- & mut self ,
169- query : String ,
170- transaction_id : Option < Bytes > ,
171- ) -> std:: result:: Result < FlightInfo , ArrowError > {
172- let cmd = CommandStatementQuery {
173- query,
174- transaction_id,
175- } ;
176- self . get_flight_info_for_command ( cmd) . await
177- }
178-
179- async fn get_flight_info_for_command < M : ProstMessageExt > (
180- & mut self ,
181- cmd : M ,
182- ) -> std:: result:: Result < FlightInfo , ArrowError > {
183- let descriptor = FlightDescriptor :: new_cmd ( cmd. as_any ( ) . encode_to_vec ( ) ) ;
184- let req = self . set_request_headers ( descriptor. into_request ( ) ) ?;
185- let fi = self
186- . flight_client
187- . get_flight_info ( req)
188- . await
189- . map_err ( |status| ArrowError :: IpcError ( format ! ( "{status:?}" ) ) ) ?
190- . into_inner ( ) ;
191- Ok ( fi)
192- }
193-
194- fn set_header ( & mut self , key : impl Into < String > , value : impl Into < String > ) {
195- let key: String = key. into ( ) ;
196- let value: String = value. into ( ) ;
197- self . headers . insert ( key, value) ;
198- }
199-
200- fn set_request_headers < T > (
201- & self ,
202- mut req : tonic:: Request < T > ,
203- ) -> std:: result:: Result < tonic:: Request < T > , ArrowError > {
204- for ( k, v) in & self . headers {
205- let k = AsciiMetadataKey :: from_str ( k. as_str ( ) ) . map_err ( |e| {
206- ArrowError :: ParseError ( format ! ( "Cannot convert header key \" {k}\" : {e}" ) )
207- } ) ?;
208- let v = v. parse ( ) . map_err ( |e| {
209- ArrowError :: ParseError ( format ! ( "Cannot convert header value \" {v}\" : {e}" ) )
210- } ) ?;
211- req. metadata_mut ( ) . insert ( k, v) ;
212- }
213- if let Some ( token) = & self . token {
214- let val = format ! ( "Bearer {token}" ) . parse ( ) . map_err ( |e| {
215- ArrowError :: ParseError ( format ! ( "Cannot convert token to header value: {e}" ) )
216- } ) ?;
217- req. metadata_mut ( ) . insert ( "authorization" , val) ;
218- }
219- Ok ( req)
220- }
221- }
0 commit comments