Skip to content

Commit 7963237

Browse files
authored
Merge pull request #49 from pranav-bhatt/master
ImageHub(Feature addition): Generalised the paths and modularized the filter for easy addition of rules
2 parents 636118e + 7b4e81f commit 7963237

File tree

7 files changed

+197
-105
lines changed

7 files changed

+197
-105
lines changed

rate-limit-filter/Cargo.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ edition = "2018"
99
crate-type = ["cdylib"]
1010

1111
[dependencies]
12-
wasm-bindgen = "0.2"
12+
base64 = "0.13.0"
13+
bincode = "1.0"
1314
proxy-wasm = "^0.1"
1415
serde = { version = "1.0", default-features = false, features = ["derive"] }
15-
bincode = "1.0"
16-
#postgres = "^0.19.0"
17-
base64 = "0.12.1"
18-
serde_json ="1.0"
16+
serde_json ="1.0"
17+
wasm-bindgen = "0.2"
50.3 KB
Binary file not shown.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod rules;
2+
3+
pub use rules::JsonPath;
4+
pub use rules::RateLimiterJson;
5+
pub use rules::Rule;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use serde::Deserialize;
2+
3+
#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
4+
pub struct JsonPath {
5+
pub name: String,
6+
pub rule: Rule,
7+
}
8+
9+
#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
10+
#[serde(rename_all(deserialize = "kebab-case"))]
11+
#[serde(tag = "ruleType", content = "parameters")]
12+
pub enum Rule {
13+
RateLimiter(Vec<RateLimiterJson>),
14+
None,
15+
}
16+
17+
#[derive(Clone, Debug, Deserialize, PartialEq, PartialOrd)]
18+
pub struct RateLimiterJson {
19+
pub identifier: String,
20+
pub limit: u32,
21+
}

rate-limit-filter/src/lib.rs

Lines changed: 152 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,28 @@
1+
mod json_parse;
12
mod rate_limiter;
23

3-
//use postgres::{Client, Error, NoTls};
4+
use json_parse::{JsonPath, RateLimiterJson, Rule};
45
use proxy_wasm::traits::*;
56
use proxy_wasm::types::*;
67
use rate_limiter::RateLimiter;
7-
use serde::{Deserialize, Serialize};
8+
use serde::Deserialize;
89

10+
use std::collections::HashMap;
911
use std::time::SystemTime;
1012

13+
// We need to make sure a HTTP root context is created and initialized when the filter is initialized.
14+
// The _start() function initialises this root context
1115
#[no_mangle]
1216
pub fn _start() {
1317
proxy_wasm::set_log_level(LogLevel::Info);
14-
proxy_wasm::set_http_context(|_context_id, _root_context_id| -> Box<dyn HttpContext> {
15-
Box::new(UpstreamCall::new())
18+
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
19+
Box::new(UpstreamCallRoot {
20+
config_json: HashMap::new(),
21+
})
1622
});
1723
}
1824

19-
#[derive(Debug)]
20-
struct UpstreamCall {
21-
//paths: Vec<String>,
22-
}
23-
24-
impl UpstreamCall {
25-
fn new() -> Self {
26-
return Self {
27-
//paths: retrieve().unwrap(),
28-
};
29-
}
30-
}
31-
32-
/*
33-
fn retrieve() -> Result<Vec<String>, Error> {
34-
let mut client = Client::connect("host=localhost user=postgres dbname=mesherydb", NoTls)?;
35-
let array = client
36-
.query("SELECT PathName FROM Paths", &[])?
37-
.iter()
38-
.map(|x| x.get(0))
39-
.collect();
40-
Ok(array)
41-
}
42-
*/
43-
44-
//to be removed
45-
static ALLOWED_PATHS: [&str; 4] = ["/auth", "/signup", "/upgrade", "/pull"];
25+
// Defining standard CORS headers
4626
static CORS_HEADERS: [(&str, &str); 5] = [
4727
("Powered-By", "proxy-wasm"),
4828
("Access-Control-Allow-Origin", "*"),
@@ -51,57 +31,131 @@ static CORS_HEADERS: [(&str, &str); 5] = [
5131
("Access-Control-Max-Age", "3600"),
5232
];
5333

54-
#[derive(Serialize, Deserialize, Debug)]
34+
// This struct is what the JWT token sent by the user will deserialize to
35+
#[derive(Deserialize, Debug)]
5536
struct Data {
5637
username: String,
5738
plan: String,
5839
}
5940

41+
// This is the instance of a call made. It sorta derives from the root context
42+
#[derive(Debug)]
43+
struct UpstreamCall {
44+
config_json: HashMap<String, Rule>,
45+
}
46+
47+
impl UpstreamCall {
48+
// Takes in the HashMap created in the root context mapping path name to rule type
49+
fn new(json_hm: &HashMap<String, Rule>) -> Self {
50+
Self {
51+
//TODO this clone is super heavy, find a way to get rid of it
52+
config_json: json_hm.clone(),
53+
}
54+
}
55+
56+
// Check if the path specified in the incoming request's path header has rule type None.
57+
// Returns Option containing path name that was sent
58+
fn rule_is_none(&self, path: String) -> Option<String> {
59+
let rule = self.config_json.get(&path).unwrap();
60+
// checking based only on type
61+
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::None) {
62+
return Some(path);
63+
}
64+
return None;
65+
}
66+
67+
// Check if the path specified in the incoming request's path header has rule type RateLimiter.
68+
// Returns Option containing vector of RateLimiterJson objects (list of plan names with limits)
69+
fn rule_is_rate_limiter(&self, path: String) -> Option<Vec<RateLimiterJson>> {
70+
let rule = self.config_json.get(&path).unwrap();
71+
// checking based only on type
72+
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::RateLimiter(Vec::new())) {
73+
if let Rule::RateLimiter(plans_vec) = rule {
74+
return Some(plans_vec.to_vec());
75+
}
76+
}
77+
return None;
78+
}
79+
}
80+
81+
impl Context for UpstreamCall {}
82+
6083
impl HttpContext for UpstreamCall {
6184
fn on_http_request_headers(&mut self, _num_headers: usize) -> Action {
85+
// Options
6286
if let Some(method) = self.get_http_request_header(":method") {
6387
if method == "OPTIONS" {
6488
self.send_http_response(204, CORS_HEADERS.to_vec(), None);
6589
return Action::Pause;
6690
}
6791
}
68-
if let Some(path) = self.get_http_request_header(":path") {
69-
if ALLOWED_PATHS.binary_search(&path.as_str()).is_ok() {
70-
return Action::Continue;
71-
}
92+
93+
// Action for rule type: None
94+
if let Some(_) = self.rule_is_none(self.get_http_request_header(":path").unwrap()) {
95+
return Action::Continue;
7296
}
73-
/*
74-
if let Some(path) = self.get_http_request_header(":path") {
75-
if self.paths.binary_search(&path.to_string()).is_ok() {
76-
return Action::Continue;
77-
}
78-
}*/
79-
if let Some(header) = self.get_http_request_header("Authorization") {
80-
if let Ok(token) = base64::decode(header) {
81-
let obj: Data = serde_json::from_slice(&token).unwrap();
82-
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", obj).as_str()).ok();
83-
let curr = self.get_current_time();
84-
let tm = curr.duration_since(SystemTime::UNIX_EPOCH).unwrap();
85-
let mn = (tm.as_secs() / 60) % 60;
86-
let _sc = tm.as_secs() % 60;
87-
let mut rl = RateLimiter::get(obj.username, obj.plan);
88-
89-
let mut headers = CORS_HEADERS.to_vec();
90-
let count: String;
91-
92-
if !rl.update(mn as i32) {
97+
98+
// Action for rule type: RateLimiter
99+
if let Some(plans_vec) =
100+
self.rule_is_rate_limiter(self.get_http_request_header(":path").unwrap())
101+
{
102+
if let Some(header) = self.get_http_request_header("Authorization") {
103+
// Decoding JWT token
104+
if let Ok(token) = base64::decode(header) {
105+
//Deserializing token
106+
let obj: Data = serde_json::from_slice(&token).unwrap();
107+
108+
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", obj).as_str())
109+
.ok();
110+
111+
// Since the rate limit works on a rate per minute based quota, we find current time
112+
let curr = self.get_current_time();
113+
let tm = curr.duration_since(SystemTime::UNIX_EPOCH).unwrap();
114+
let mn = (tm.as_secs() / 60) % 60;
115+
let _sc = tm.as_secs() % 60;
116+
117+
// Initialise RateLimiter object
118+
let mut rl = RateLimiter::get(&obj.username, &obj.plan);
119+
120+
// Initialising headers to send back
121+
let mut headers = CORS_HEADERS.to_vec();
122+
let count: String;
123+
124+
// Extracting limits based on plan stated in JWT token from the corresponding RateLimiterJson
125+
let limit = plans_vec
126+
.into_iter()
127+
.filter(|x| x.identifier == obj.plan)
128+
.map(|x| x.limit)
129+
.collect::<Vec<u32>>();
130+
131+
// Checking if the appropriate plan exists
132+
if limit.len() != 1 {
133+
self.send_http_response(
134+
429,
135+
headers,
136+
Some(b"Invalid plan name or duplicate plan names defined.\n"),
137+
);
138+
return Action::Pause;
139+
}
140+
141+
//Update request count in RateLimiter object, and check if it exceeds limits
142+
if rl.update(mn as i32) > limit[0] {
143+
count = rl.count.to_string();
144+
headers
145+
.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
146+
self.send_http_response(429, headers, Some(b"Limit exceeded.\n"));
147+
rl.set();
148+
return Action::Pause;
149+
}
150+
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", &rl).as_str())
151+
.ok();
152+
// set the new count in headers, and proxy_wasm storage
93153
count = rl.count.to_string();
94-
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
95-
self.send_http_response(429, headers, Some(b"Limit exceeded.\n"));
96154
rl.set();
97-
return Action::Pause;
155+
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
156+
self.send_http_response(200, headers, Some(b"All Good!\n"));
157+
return Action::Continue;
98158
}
99-
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", &rl).as_str()).ok();
100-
count = rl.count.to_string();
101-
rl.set();
102-
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
103-
self.send_http_response(200, headers, Some(b"All Good!\n"));
104-
return Action::Continue;
105159
}
106160
}
107161
self.send_http_response(401, CORS_HEADERS.to_vec(), Some(b"Unauthorized\n"));
@@ -115,10 +169,38 @@ impl HttpContext for UpstreamCall {
115169
}
116170
}
117171

118-
impl UpstreamCall {
119-
// fn retrieve_rl(&self) -> RateLimiter {
120-
// }
172+
struct UpstreamCallRoot {
173+
config_json: HashMap<String, Rule>,
121174
}
122175

123-
impl Context for UpstreamCall {}
124-
impl RootContext for UpstreamCall {}
176+
impl Context for UpstreamCallRoot {}
177+
impl<'a> RootContext for UpstreamCallRoot {
178+
//TODO: Revisit this once the read only feature is released in Istio 1.10
179+
// Get Base64 encoded JSON from envoy config file when WASM VM starts
180+
fn on_vm_start(&mut self, _: usize) -> bool {
181+
if let Some(config_bytes) = self.get_configuration() {
182+
// bytestring passed by VM -> String of base64 encoded JSON
183+
let config_str = String::from_utf8(config_bytes).unwrap();
184+
// String of base64 encoded JSON -> bytestring of decoded JSON
185+
let config_b64 = base64::decode(config_str).unwrap();
186+
// bytestring of decoded JSON -> String of decoded JSON
187+
let json_str = String::from_utf8(config_b64).unwrap();
188+
// Deserializing JSON String into vector of JsonPath objects
189+
let json_vec: Vec<JsonPath> = serde_json::from_str(&json_str).unwrap();
190+
// Creating HashMap of pattern ("path name", "rule type") and saving into UpstreamCallRoot object
191+
for i in json_vec {
192+
self.config_json.insert(i.name, i.rule);
193+
}
194+
}
195+
true
196+
}
197+
198+
fn create_http_context(&self, _: u32) -> Option<Box<dyn HttpContext>> {
199+
// creating UpstreamCall object for each new call
200+
Some(Box::new(UpstreamCall::new(&self.config_json)))
201+
}
202+
203+
fn get_type(&self) -> Option<ContextType> {
204+
Some(ContextType::HttpContext)
205+
}
206+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod rate_limiter;
2+
3+
pub use rate_limiter::*;

rate-limit-filter/src/rate_limiter.rs renamed to rate-limit-filter/src/rate_limiter/rate_limiter.rs

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,66 +5,48 @@ use serde::{Deserialize, Serialize};
55

66
#[derive(Debug, Serialize, Deserialize, Clone)]
77
pub struct RateLimiter {
8-
pub rpm: Option<u32>,
8+
// Tracks time
99
pub min: i32,
10+
// Tracks number of calls made
1011
pub count: u32,
12+
// stores a key(username according to example)
1113
pub key: String,
1214
}
1315

1416
impl RateLimiter {
15-
fn new(key: &String, plan: &String) -> Self {
16-
let limit = match plan.as_str() {
17-
"Enterprise" => Some(100),
18-
"Team" => Some(50),
19-
"Personal" => Some(10),
20-
_ => None,
21-
};
17+
fn new(key: &String, _plan: &String) -> Self {
2218
Self {
23-
rpm: limit,
2419
min: -1,
2520
count: 0,
2621
key: key.clone(),
2722
}
2823
}
29-
pub fn get(key: String, plan: String) -> Self {
24+
// Get key and plan from proxy_wasm shared data store (username+plan name)
25+
pub fn get(key: &String, plan: &String) -> Self {
3026
if let Ok(data) = proxy_wasm::hostcalls::get_shared_data(&key.clone()) {
3127
if let Some(data) = data.0 {
3228
let data: Option<Self> = bincode::deserialize(&data).unwrap_or(None);
33-
if let Some(mut obj) = data {
34-
let limit = match plan.as_str() {
35-
"Enterprise" => Some(100),
36-
"Team" => Some(50),
37-
"Personal" => Some(10),
38-
_ => None,
39-
};
40-
obj.rpm = limit;
29+
if let Some(obj) = data {
4130
return obj;
4231
}
4332
}
4433
}
4534
return Self::new(&key, &plan);
4635
}
36+
// Set key and plan in proxy_wasm shared data store (username+plan name)
4737
pub fn set(&self) {
4838
let target: Option<Self> = Some(self.clone());
4939
let encoded: Vec<u8> = bincode::serialize(&target).unwrap();
5040
proxy_wasm::hostcalls::set_shared_data(&self.key.clone(), Some(&encoded), None).ok();
5141
}
52-
pub fn update(&mut self, time: i32) -> bool {
42+
// Update time (minute by minute) and increment count
43+
pub fn update(&mut self, time: i32) -> u32 {
5344
if self.min != time {
5445
self.min = time;
5546
self.count = 0;
5647
}
5748
self.count += 1;
58-
proxy_wasm::hostcalls::log(
59-
LogLevel::Debug,
60-
format!("Obj {:?} {:?}", self.count, self.rpm).as_str(),
61-
)
62-
.ok();
63-
if let Some(sm) = self.rpm {
64-
if self.count > sm {
65-
return false;
66-
}
67-
}
68-
return true;
49+
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?} ", self.count).as_str()).ok();
50+
self.count
6951
}
7052
}

0 commit comments

Comments
 (0)