12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
14
15
+ use std:: collections:: HashMap ;
15
16
use std:: sync:: Arc ;
16
17
17
18
use common_exception:: ErrorCode ;
@@ -20,6 +21,7 @@ use headers::authorization::Basic;
20
21
use headers:: authorization:: Bearer ;
21
22
use headers:: authorization:: Credentials ;
22
23
use http:: header:: AUTHORIZATION ;
24
+ use http:: HeaderValue ;
23
25
use poem:: error:: Error as PoemError ;
24
26
use poem:: error:: Result as PoemResult ;
25
27
use poem:: http:: StatusCode ;
@@ -49,8 +51,8 @@ impl HTTPSessionMiddleware {
49
51
}
50
52
51
53
fn get_credential ( req : & Request , kind : HttpHandlerKind ) -> Result < Credential > {
52
- let auth_headers : Vec < _ > = req. headers ( ) . get_all ( AUTHORIZATION ) . iter ( ) . collect ( ) ;
53
- if auth_headers . len ( ) > 1 {
54
+ let std_auth_headers : Vec < _ > = req. headers ( ) . get_all ( AUTHORIZATION ) . iter ( ) . collect ( ) ;
55
+ if std_auth_headers . len ( ) > 1 {
54
56
let msg = & format ! ( "Multiple {} headers detected" , AUTHORIZATION ) ;
55
57
return Err ( ErrorCode :: AuthenticateFailure ( msg) ) ;
56
58
}
@@ -59,26 +61,24 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
59
61
Addr :: Custom ( ..) => Some ( "127.0.0.1" . to_string ( ) ) ,
60
62
_ => None ,
61
63
} ;
62
- if auth_headers. is_empty ( ) {
63
- if let HttpHandlerKind :: Clickhouse = kind {
64
- let ( user, key) = (
65
- req. headers ( ) . get ( "X-CLICKHOUSE-USER" ) ,
66
- req. headers ( ) . get ( "X-CLICKHOUSE-KEY" ) ,
67
- ) ;
68
- if let ( Some ( name) , Some ( password) ) = ( user, key) {
69
- let c = Credential :: Password {
70
- name : String :: from_utf8 ( name. as_bytes ( ) . to_vec ( ) ) . unwrap ( ) ,
71
- password : Some ( password. as_bytes ( ) . to_vec ( ) ) ,
72
- hostname : client_ip,
73
- } ;
74
- return Ok ( c) ;
75
- }
64
+ if std_auth_headers. is_empty ( ) {
65
+ if matches ! ( kind, HttpHandlerKind :: Clickhouse ) {
66
+ auth_clickhouse_name_password ( req, client_ip)
67
+ } else {
68
+ Err ( ErrorCode :: AuthenticateFailure (
69
+ "No authorization header detected" ,
70
+ ) )
76
71
}
77
- return Err ( ErrorCode :: AuthenticateFailure (
78
- "No authorization header detected" ,
79
- ) ) ;
72
+ } else {
73
+ auth_by_header ( & std_auth_headers, client_ip)
80
74
}
81
- let value = auth_headers[ 0 ] ;
75
+ }
76
+
77
+ fn auth_by_header (
78
+ std_auth_headers : & [ & HeaderValue ] ,
79
+ client_ip : Option < String > ,
80
+ ) -> Result < Credential > {
81
+ let value = & std_auth_headers[ 0 ] ;
82
82
if value. as_bytes ( ) . starts_with ( b"Basic " ) {
83
83
match Basic :: decode ( value) {
84
84
Some ( basic) => {
@@ -107,6 +107,37 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
107
107
}
108
108
}
109
109
110
+ fn auth_clickhouse_name_password ( req : & Request , client_ip : Option < String > ) -> Result < Credential > {
111
+ let ( user, key) = (
112
+ req. headers ( ) . get ( "X-CLICKHOUSE-USER" ) ,
113
+ req. headers ( ) . get ( "X-CLICKHOUSE-KEY" ) ,
114
+ ) ;
115
+ if let ( Some ( name) , Some ( password) ) = ( user, key) {
116
+ let c = Credential :: Password {
117
+ name : String :: from_utf8 ( name. as_bytes ( ) . to_vec ( ) ) . unwrap ( ) ,
118
+ password : Some ( password. as_bytes ( ) . to_vec ( ) ) ,
119
+ hostname : client_ip,
120
+ } ;
121
+ Ok ( c)
122
+ } else {
123
+ let query_str = req. uri ( ) . query ( ) . unwrap_or_default ( ) ;
124
+ let query_params = serde_urlencoded:: from_str :: < HashMap < String , String > > ( query_str)
125
+ . map_err ( |e| ErrorCode :: BadArguments ( format ! ( "{}" , e) ) ) ?;
126
+ let ( user, key) = ( query_params. get ( "user" ) , query_params. get ( "password" ) ) ;
127
+ if let ( Some ( name) , Some ( password) ) = ( user, key) {
128
+ Ok ( Credential :: Password {
129
+ name : name. clone ( ) ,
130
+ password : Some ( password. as_bytes ( ) . to_vec ( ) ) ,
131
+ hostname : client_ip,
132
+ } )
133
+ } else {
134
+ Err ( ErrorCode :: AuthenticateFailure (
135
+ "No header or query parameters for authorization detected" ,
136
+ ) )
137
+ }
138
+ }
139
+ }
140
+
110
141
impl < E : Endpoint > Middleware < E > for HTTPSessionMiddleware {
111
142
type Output = HTTPSessionEndpoint < E > ;
112
143
fn transform ( & self , ep : E ) -> Self :: Output {
0 commit comments