1
+ mod json_parse;
1
2
mod rate_limiter;
2
3
3
- // use postgres ::{Client, Error, NoTls };
4
+ use json_parse :: { JsonPath , RateLimiterJson , Rule } ;
4
5
use proxy_wasm:: traits:: * ;
5
6
use proxy_wasm:: types:: * ;
6
7
use rate_limiter:: RateLimiter ;
7
- use serde:: { Deserialize , Serialize } ;
8
+ use serde:: Deserialize ;
8
9
10
+ use std:: collections:: HashMap ;
9
11
use std:: time:: SystemTime ;
10
12
13
+ // We need to make sure a HTTP root context is created and initialized when the filter is initialized.
14
+ // The _start() function initialises this root context
11
15
#[ no_mangle]
12
16
pub fn _start ( ) {
13
17
proxy_wasm:: set_log_level ( LogLevel :: Info ) ;
14
- proxy_wasm:: set_http_context ( |_context_id, _root_context_id| -> Box < dyn HttpContext > {
15
- Box :: new ( UpstreamCall :: new ( ) )
18
+ proxy_wasm:: set_root_context ( |_| -> Box < dyn RootContext > {
19
+ Box :: new ( UpstreamCallRoot {
20
+ config_json : HashMap :: new ( ) ,
21
+ } )
16
22
} ) ;
17
23
}
18
24
19
- #[ derive( Debug ) ]
20
- struct UpstreamCall {
21
- //paths: Vec<String>,
22
- }
23
-
24
- impl UpstreamCall {
25
- fn new ( ) -> Self {
26
- return Self {
27
- //paths: retrieve().unwrap(),
28
- } ;
29
- }
30
- }
31
-
32
- /*
33
- fn retrieve() -> Result<Vec<String>, Error> {
34
- let mut client = Client::connect("host=localhost user=postgres dbname=mesherydb", NoTls)?;
35
- let array = client
36
- .query("SELECT PathName FROM Paths", &[])?
37
- .iter()
38
- .map(|x| x.get(0))
39
- .collect();
40
- Ok(array)
41
- }
42
- */
43
-
44
- //to be removed
45
- static ALLOWED_PATHS : [ & str ; 4 ] = [ "/auth" , "/signup" , "/upgrade" , "/pull" ] ;
25
+ // Defining standard CORS headers
46
26
static CORS_HEADERS : [ ( & str , & str ) ; 5 ] = [
47
27
( "Powered-By" , "proxy-wasm" ) ,
48
28
( "Access-Control-Allow-Origin" , "*" ) ,
@@ -51,57 +31,131 @@ static CORS_HEADERS: [(&str, &str); 5] = [
51
31
( "Access-Control-Max-Age" , "3600" ) ,
52
32
] ;
53
33
54
- #[ derive( Serialize , Deserialize , Debug ) ]
34
+ // This struct is what the JWT token sent by the user will deserialize to
35
+ #[ derive( Deserialize , Debug ) ]
55
36
struct Data {
56
37
username : String ,
57
38
plan : String ,
58
39
}
59
40
41
+ // This is the instance of a call made. It sorta derives from the root context
42
+ #[ derive( Debug ) ]
43
+ struct UpstreamCall {
44
+ config_json : HashMap < String , Rule > ,
45
+ }
46
+
47
+ impl UpstreamCall {
48
+ // Takes in the HashMap created in the root context mapping path name to rule type
49
+ fn new ( json_hm : & HashMap < String , Rule > ) -> Self {
50
+ Self {
51
+ //TODO this clone is super heavy, find a way to get rid of it
52
+ config_json : json_hm. clone ( ) ,
53
+ }
54
+ }
55
+
56
+ // Check if the path specified in the incoming request's path header has rule type None.
57
+ // Returns Option containing path name that was sent
58
+ fn rule_is_none ( & self , path : String ) -> Option < String > {
59
+ let rule = self . config_json . get ( & path) . unwrap ( ) ;
60
+ // checking based only on type
61
+ if std:: mem:: discriminant ( rule) == std:: mem:: discriminant ( & Rule :: None ) {
62
+ return Some ( path) ;
63
+ }
64
+ return None ;
65
+ }
66
+
67
+ // Check if the path specified in the incoming request's path header has rule type RateLimiter.
68
+ // Returns Option containing vector of RateLimiterJson objects (list of plan names with limits)
69
+ fn rule_is_rate_limiter ( & self , path : String ) -> Option < Vec < RateLimiterJson > > {
70
+ let rule = self . config_json . get ( & path) . unwrap ( ) ;
71
+ // checking based only on type
72
+ if std:: mem:: discriminant ( rule) == std:: mem:: discriminant ( & Rule :: RateLimiter ( Vec :: new ( ) ) ) {
73
+ if let Rule :: RateLimiter ( plans_vec) = rule {
74
+ return Some ( plans_vec. to_vec ( ) ) ;
75
+ }
76
+ }
77
+ return None ;
78
+ }
79
+ }
80
+
81
+ impl Context for UpstreamCall { }
82
+
60
83
impl HttpContext for UpstreamCall {
61
84
fn on_http_request_headers ( & mut self , _num_headers : usize ) -> Action {
85
+ // Options
62
86
if let Some ( method) = self . get_http_request_header ( ":method" ) {
63
87
if method == "OPTIONS" {
64
88
self . send_http_response ( 204 , CORS_HEADERS . to_vec ( ) , None ) ;
65
89
return Action :: Pause ;
66
90
}
67
91
}
68
- if let Some ( path ) = self . get_http_request_header ( ":path" ) {
69
- if ALLOWED_PATHS . binary_search ( & path . as_str ( ) ) . is_ok ( ) {
70
- return Action :: Continue ;
71
- }
92
+
93
+ // Action for rule type: None
94
+ if let Some ( _ ) = self . rule_is_none ( self . get_http_request_header ( ":path" ) . unwrap ( ) ) {
95
+ return Action :: Continue ;
72
96
}
73
- /*
74
- if let Some(path) = self.get_http_request_header(":path") {
75
- if self.paths.binary_search(&path.to_string()).is_ok() {
76
- return Action::Continue;
77
- }
78
- }*/
79
- if let Some ( header) = self . get_http_request_header ( "Authorization" ) {
80
- if let Ok ( token) = base64:: decode ( header) {
81
- let obj: Data = serde_json:: from_slice ( & token) . unwrap ( ) ;
82
- proxy_wasm:: hostcalls:: log ( LogLevel :: Debug , format ! ( "Obj {:?}" , obj) . as_str ( ) ) . ok ( ) ;
83
- let curr = self . get_current_time ( ) ;
84
- let tm = curr. duration_since ( SystemTime :: UNIX_EPOCH ) . unwrap ( ) ;
85
- let mn = ( tm. as_secs ( ) / 60 ) % 60 ;
86
- let _sc = tm. as_secs ( ) % 60 ;
87
- let mut rl = RateLimiter :: get ( obj. username , obj. plan ) ;
88
-
89
- let mut headers = CORS_HEADERS . to_vec ( ) ;
90
- let count: String ;
91
-
92
- if !rl. update ( mn as i32 ) {
97
+
98
+ // Action for rule type: RateLimiter
99
+ if let Some ( plans_vec) =
100
+ self . rule_is_rate_limiter ( self . get_http_request_header ( ":path" ) . unwrap ( ) )
101
+ {
102
+ if let Some ( header) = self . get_http_request_header ( "Authorization" ) {
103
+ // Decoding JWT token
104
+ if let Ok ( token) = base64:: decode ( header) {
105
+ //Deserializing token
106
+ let obj: Data = serde_json:: from_slice ( & token) . unwrap ( ) ;
107
+
108
+ proxy_wasm:: hostcalls:: log ( LogLevel :: Debug , format ! ( "Obj {:?}" , obj) . as_str ( ) )
109
+ . ok ( ) ;
110
+
111
+ // Since the rate limit works on a rate per minute based quota, we find current time
112
+ let curr = self . get_current_time ( ) ;
113
+ let tm = curr. duration_since ( SystemTime :: UNIX_EPOCH ) . unwrap ( ) ;
114
+ let mn = ( tm. as_secs ( ) / 60 ) % 60 ;
115
+ let _sc = tm. as_secs ( ) % 60 ;
116
+
117
+ // Initialise RateLimiter object
118
+ let mut rl = RateLimiter :: get ( & obj. username , & obj. plan ) ;
119
+
120
+ // Initialising headers to send back
121
+ let mut headers = CORS_HEADERS . to_vec ( ) ;
122
+ let count: String ;
123
+
124
+ // Extracting limits based on plan stated in JWT token from the corresponding RateLimiterJson
125
+ let limit = plans_vec
126
+ . into_iter ( )
127
+ . filter ( |x| x. identifier == obj. plan )
128
+ . map ( |x| x. limit )
129
+ . collect :: < Vec < u32 > > ( ) ;
130
+
131
+ // Checking if the appropriate plan exists
132
+ if limit. len ( ) != 1 {
133
+ self . send_http_response (
134
+ 429 ,
135
+ headers,
136
+ Some ( b"Invalid plan name or duplicate plan names defined.\n " ) ,
137
+ ) ;
138
+ return Action :: Pause ;
139
+ }
140
+
141
+ //Update request count in RateLimiter object, and check if it exceeds limits
142
+ if rl. update ( mn as i32 ) > limit[ 0 ] {
143
+ count = rl. count . to_string ( ) ;
144
+ headers
145
+ . append ( & mut vec ! [ ( "x-rate-limit" , & count) , ( "x-app-user" , & rl. key) ] ) ;
146
+ self . send_http_response ( 429 , headers, Some ( b"Limit exceeded.\n " ) ) ;
147
+ rl. set ( ) ;
148
+ return Action :: Pause ;
149
+ }
150
+ proxy_wasm:: hostcalls:: log ( LogLevel :: Debug , format ! ( "Obj {:?}" , & rl) . as_str ( ) )
151
+ . ok ( ) ;
152
+ // set the new count in headers, and proxy_wasm storage
93
153
count = rl. count . to_string ( ) ;
94
- headers. append ( & mut vec ! [ ( "x-rate-limit" , & count) , ( "x-app-user" , & rl. key) ] ) ;
95
- self . send_http_response ( 429 , headers, Some ( b"Limit exceeded.\n " ) ) ;
96
154
rl. set ( ) ;
97
- return Action :: Pause ;
155
+ headers. append ( & mut vec ! [ ( "x-rate-limit" , & count) , ( "x-app-user" , & rl. key) ] ) ;
156
+ self . send_http_response ( 200 , headers, Some ( b"All Good!\n " ) ) ;
157
+ return Action :: Continue ;
98
158
}
99
- proxy_wasm:: hostcalls:: log ( LogLevel :: Debug , format ! ( "Obj {:?}" , & rl) . as_str ( ) ) . ok ( ) ;
100
- count = rl. count . to_string ( ) ;
101
- rl. set ( ) ;
102
- headers. append ( & mut vec ! [ ( "x-rate-limit" , & count) , ( "x-app-user" , & rl. key) ] ) ;
103
- self . send_http_response ( 200 , headers, Some ( b"All Good!\n " ) ) ;
104
- return Action :: Continue ;
105
159
}
106
160
}
107
161
self . send_http_response ( 401 , CORS_HEADERS . to_vec ( ) , Some ( b"Unauthorized\n " ) ) ;
@@ -115,10 +169,38 @@ impl HttpContext for UpstreamCall {
115
169
}
116
170
}
117
171
118
- impl UpstreamCall {
119
- // fn retrieve_rl(&self) -> RateLimiter {
120
- // }
172
+ struct UpstreamCallRoot {
173
+ config_json : HashMap < String , Rule > ,
121
174
}
122
175
123
- impl Context for UpstreamCall { }
124
- impl RootContext for UpstreamCall { }
176
+ impl Context for UpstreamCallRoot { }
177
+ impl < ' a > RootContext for UpstreamCallRoot {
178
+ //TODO: Revisit this once the read only feature is released in Istio 1.10
179
+ // Get Base64 encoded JSON from envoy config file when WASM VM starts
180
+ fn on_vm_start ( & mut self , _: usize ) -> bool {
181
+ if let Some ( config_bytes) = self . get_configuration ( ) {
182
+ // bytestring passed by VM -> String of base64 encoded JSON
183
+ let config_str = String :: from_utf8 ( config_bytes) . unwrap ( ) ;
184
+ // String of base64 encoded JSON -> bytestring of decoded JSON
185
+ let config_b64 = base64:: decode ( config_str) . unwrap ( ) ;
186
+ // bytestring of decoded JSON -> String of decoded JSON
187
+ let json_str = String :: from_utf8 ( config_b64) . unwrap ( ) ;
188
+ // Deserializing JSON String into vector of JsonPath objects
189
+ let json_vec: Vec < JsonPath > = serde_json:: from_str ( & json_str) . unwrap ( ) ;
190
+ // Creating HashMap of pattern ("path name", "rule type") and saving into UpstreamCallRoot object
191
+ for i in json_vec {
192
+ self . config_json . insert ( i. name , i. rule ) ;
193
+ }
194
+ }
195
+ true
196
+ }
197
+
198
+ fn create_http_context ( & self , _: u32 ) -> Option < Box < dyn HttpContext > > {
199
+ // creating UpstreamCall object for each new call
200
+ Some ( Box :: new ( UpstreamCall :: new ( & self . config_json ) ) )
201
+ }
202
+
203
+ fn get_type ( & self ) -> Option < ContextType > {
204
+ Some ( ContextType :: HttpContext )
205
+ }
206
+ }
0 commit comments