@@ -7,13 +7,16 @@ use futures_core::future::BoxFuture;
7
7
use crate :: cache:: StatementCache ;
8
8
use crate :: connection:: Connection ;
9
9
use crate :: io:: { Buf , BufStream } ;
10
- use crate :: postgres:: protocol:: { self , Decode , Encode , Message , StatementId , SaslResponse , SaslInitialResponse , hi, Authentication } ;
10
+ use crate :: postgres:: protocol:: {
11
+ self , hi, Authentication , Decode , Encode , Message , SaslInitialResponse , SaslResponse ,
12
+ StatementId ,
13
+ } ;
11
14
use crate :: postgres:: PgError ;
12
15
use crate :: url:: Url ;
13
- use sha2:: { Sha256 , Digest } ;
14
- use hmac:: { Mac , Hmac } ;
15
16
use crate :: Result ;
17
+ use hmac:: { Hmac , Mac } ;
16
18
use rand:: Rng ;
19
+ use sha2:: { Digest , Sha256 } ;
17
20
18
21
/// An asynchronous connection to a [Postgres] database.
19
22
///
@@ -97,20 +100,35 @@ impl PgConnection {
97
100
}
98
101
99
102
protocol:: Authentication :: Sasl { mechanisms } => {
100
- match mechanisms. get ( 0 ) . map ( |m| & * * m) {
101
- Some ( "SCRAM-SHA-256" ) => {
102
- sasl_auth (
103
- self ,
104
- username,
105
- url. password ( ) . unwrap_or_default ( ) ,
106
- )
107
- . await ?;
103
+ let mut has_sasl: bool = false ;
104
+ let mut has_sasl_plus: bool = false ;
105
+
106
+ for mechanism in & * mechanisms {
107
+ match & * * mechanism {
108
+ "SCRAM-SHA-256" => {
109
+ has_sasl = true ;
110
+ }
111
+
112
+ "SCRAM-SHA-256-PLUS" => {
113
+ has_sasl_plus = true ;
114
+ }
115
+
116
+ _ => {
117
+ log:: info!( "unsupported auth mechanism: {}" , mechanism) ;
118
+ }
108
119
}
120
+ }
109
121
110
- _ => return Err ( protocol_err ! (
111
- "Expected mechanisms SCRAM-SHA-256, but received {:?}" ,
122
+ if has_sasl || has_sasl_plus {
123
+ // TODO: Handle -PLUS differently if we're in a TLS stream
124
+ sasl_auth ( self , username, url. password ( ) . unwrap_or_default ( ) )
125
+ . await ?;
126
+ } else {
127
+ return Err ( protocol_err ! (
128
+ "unsupported SASL auth mechanisms: {:?}" ,
112
129
mechanisms
113
- ) . into ( ) ) ,
130
+ )
131
+ . into ( ) ) ;
114
132
}
115
133
}
116
134
@@ -288,11 +306,7 @@ fn nonce() -> String {
288
306
289
307
// Performs authenticiton using Simple Authentication Security Layer (SASL) which is what
290
308
// Postgres uses
291
- async fn sasl_auth < T : AsRef < str > > (
292
- conn : & mut PgConnection ,
293
- username : T ,
294
- password : T ,
295
- ) -> Result < ( ) > {
309
+ async fn sasl_auth < T : AsRef < str > > ( conn : & mut PgConnection , username : T , password : T ) -> Result < ( ) > {
296
310
// channel-binding = "c=" base64
297
311
let channel_binding = format ! ( "{}={}" , CHANNEL_ATTR , base64:: encode( GS2_HEADER ) ) ;
298
312
// "n=" saslname ;; Usernames are prepared using SASLprep.
@@ -308,8 +322,7 @@ async fn sasl_auth<T: AsRef<str>>(
308
322
client_first_message_bare = client_first_message_bare
309
323
) ;
310
324
311
- SaslInitialResponse ( & client_first_message)
312
- . encode ( conn. stream . buffer_mut ( ) ) ;
325
+ SaslInitialResponse ( & client_first_message) . encode ( conn. stream . buffer_mut ( ) ) ;
313
326
conn. stream . flush ( ) . await ?;
314
327
315
328
let server_first_message = conn. receive ( ) . await ?;
@@ -379,8 +392,7 @@ async fn sasl_auth<T: AsRef<str>>(
379
392
client_proof = base64:: encode( & client_proof)
380
393
) ;
381
394
382
- SaslResponse ( & client_final_message)
383
- . encode ( conn. stream . buffer_mut ( ) ) ;
395
+ SaslResponse ( & client_final_message) . encode ( conn. stream . buffer_mut ( ) ) ;
384
396
conn. stream . flush ( ) . await ?;
385
397
let _server_final_response = conn. receive ( ) . await ?;
386
398
0 commit comments