Skip to content

Commit 4d1af1f

Browse files
committed
Add method filter
1 parent c4f2bce commit 4d1af1f

File tree

8 files changed

+150
-2
lines changed

8 files changed

+150
-2
lines changed

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use http::uri::Uri;
88
use serde::Deserialize;
99
use structopt::StructOpt;
1010

11+
use crate::method::MethodSet;
1112
use crate::route;
1213

1314
#[derive(Debug, StructOpt)]
@@ -39,6 +40,8 @@ pub struct Config {
3940
pub struct Route {
4041
pub route: route::Route,
4142
pub rewrite_path: Option<String>,
43+
#[serde(alias = "method", default)]
44+
pub methods: Option<MethodSet>,
4245
#[serde(with = "http_serde::header_map", default)]
4346
pub response_headers: http::HeaderMap,
4447
#[serde(flatten)]

src/handler/fs.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use tokio::fs;
66
use tokio::io::ErrorKind;
77
use urlencoding::decode;
88

9+
use crate::method::MethodFilter;
910
use crate::{config, response};
1011

1112
#[derive(Debug)]
@@ -18,6 +19,10 @@ pub struct DirHandler {
1819
config: config::DirRoute,
1920
}
2021

22+
pub fn default_method_filter() -> Box<dyn MethodFilter> {
23+
Box::new(|method: &http::Method| method == http::Method::GET)
24+
}
25+
2126
impl FileHandler {
2227
pub fn new(config: config::FileRoute) -> Self {
2328
FileHandler { config }

src/handler/json.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
1313
use tokio::sync::{Notify, RwLock};
1414
use urlencoding::decode;
1515

16+
use crate::method::MethodFilter;
1617
use crate::{config, response};
1718

1819
#[derive(Debug)]
@@ -34,6 +35,10 @@ struct Sync {
3435
buf: Vec<u8>,
3536
}
3637

38+
pub fn default_method_filter() -> Box<dyn MethodFilter> {
39+
Box::new(|method: &http::Method| method == http::Method::GET || method == http::Method::PATCH)
40+
}
41+
3742
impl JsonHandler {
3843
pub async fn new(config: config::JsonRoute) -> Result<Self> {
3944
let mut file = fs::OpenOptions::new()

src/handler/mock.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
use hyper::Body;
22

3+
use crate::method::{self, MethodFilter};
34
use crate::{config, response};
45

56
#[derive(Debug)]
67
pub struct MockHandler {
78
config: config::MockRoute,
89
}
910

11+
pub fn default_method_filter() -> Box<dyn MethodFilter> {
12+
method::any()
13+
}
14+
1015
impl MockHandler {
1116
pub fn new(config: config::MockRoute) -> Self {
1217
MockHandler { config }

src/handler/mod.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,24 @@ mod json;
33
mod mock;
44
mod proxy;
55

6+
use std::fmt;
7+
68
use anyhow::Result;
79
use hyper::Body;
810

911
use self::fs::{DirHandler, FileHandler};
1012
use self::json::JsonHandler;
1113
use self::mock::MockHandler;
1214
use self::proxy::ProxyHandler;
13-
use crate::config;
15+
use crate::method::MethodFilter;
1416
use crate::path::PathRewriter;
17+
use crate::{config, response};
1518

16-
#[derive(Debug)]
1719
pub struct Handler {
1820
kind: HandlerKind,
1921
path_rewriter: Option<PathRewriter>,
2022
response_headers: http::HeaderMap,
23+
method_filter: Box<dyn MethodFilter>,
2124
}
2225

2326
#[derive(Debug)]
@@ -36,6 +39,7 @@ impl Handler {
3639
route,
3740
kind,
3841
response_headers,
42+
methods,
3943
} = route;
4044
let path_rewriter = rewrite_path.map(|replace| {
4145
let regex = route.to_regex();
@@ -50,17 +54,30 @@ impl Handler {
5054
config::RouteKind::Mock(mock) => HandlerKind::Mock(MockHandler::new(mock)),
5155
};
5256

57+
let method_filter = match methods {
58+
Some(methods) => Box::new(methods),
59+
None => kind.default_method_filter(),
60+
};
61+
5362
Ok(Handler {
5463
path_rewriter,
5564
kind,
5665
response_headers,
66+
method_filter,
5767
})
5868
}
5969

6070
pub async fn handle(
6171
&self,
6272
request: http::Request<Body>,
6373
) -> Result<http::Response<Body>, (http::Request<Body>, http::Response<Body>)> {
74+
if !self.method_filter.is_match(request.method()) {
75+
return Err((
76+
request,
77+
response::from_status(http::StatusCode::METHOD_NOT_ALLOWED),
78+
));
79+
}
80+
6481
let path = match &self.path_rewriter {
6582
Some(path_rewriter) => path_rewriter.rewrite(request.uri().path()),
6683
None => request.uri().path().to_owned(),
@@ -81,3 +98,24 @@ impl Handler {
8198
result
8299
}
83100
}
101+
102+
impl HandlerKind {
103+
fn default_method_filter(&self) -> Box<dyn MethodFilter> {
104+
match self {
105+
HandlerKind::File(_) | HandlerKind::Dir(_) => fs::default_method_filter(),
106+
HandlerKind::Proxy(_) => proxy::default_method_filter(),
107+
HandlerKind::Json(_) => json::default_method_filter(),
108+
HandlerKind::Mock(_) => mock::default_method_filter(),
109+
}
110+
}
111+
}
112+
113+
impl fmt::Debug for Handler {
114+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115+
f.debug_struct("Handler")
116+
.field("kind", &self.kind)
117+
.field("path_rewriter", &self.path_rewriter)
118+
.field("response_headers", &self.response_headers)
119+
.finish()
120+
}
121+
}

src/handler/proxy.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use hyper::Body;
66
use hyper_rustls::HttpsConnector;
77
use once_cell::sync::Lazy;
88

9+
use crate::method::{self, MethodFilter};
910
use crate::{config, response};
1011

1112
#[derive(Debug)]
@@ -14,6 +15,10 @@ pub struct ProxyHandler {
1415
client: Arc<Client<HttpsConnector<HttpConnector>>>,
1516
}
1617

18+
pub fn default_method_filter() -> Box<dyn MethodFilter> {
19+
method::any()
20+
}
21+
1722
impl ProxyHandler {
1823
pub fn new(config: config::ProxyRoute) -> Self {
1924
static CLIENT: Lazy<Arc<Client<HttpsConnector<HttpConnector>>>> =

src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use structopt::StructOpt;
22

33
mod config;
44
mod handler;
5+
mod method;
56
mod path;
67
mod response;
78
mod route;

src/method.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use std::collections::HashSet;
2+
use std::fmt;
3+
use std::iter::{once, FromIterator};
4+
use std::str::FromStr;
5+
6+
use serde::de::{self, Deserialize, Deserializer, Error, SeqAccess};
7+
8+
pub fn any() -> Box<dyn MethodFilter> {
9+
Box::new(|_: &http::Method| true)
10+
}
11+
12+
#[derive(Debug)]
13+
pub struct MethodSet {
14+
set: HashSet<http::Method>,
15+
}
16+
17+
pub trait MethodFilter: Send + Sync {
18+
fn is_match(&self, method: &http::Method) -> bool;
19+
}
20+
21+
impl<F> MethodFilter for F
22+
where
23+
F: Fn(&http::Method) -> bool + Send + Sync,
24+
{
25+
fn is_match(&self, method: &http::Method) -> bool {
26+
self(method)
27+
}
28+
}
29+
30+
impl MethodFilter for MethodSet {
31+
fn is_match(&self, method: &http::Method) -> bool {
32+
self.set.contains(&method)
33+
}
34+
}
35+
36+
impl<'de> Deserialize<'de> for MethodSet {
37+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38+
where
39+
D: Deserializer<'de>,
40+
{
41+
struct MethodSetVisitor;
42+
43+
impl<'de> de::Visitor<'de> for MethodSetVisitor {
44+
type Value = MethodSet;
45+
46+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
47+
formatter.write_str("a set of HTTP methods")
48+
}
49+
50+
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
51+
where
52+
E: de::Error,
53+
{
54+
self.visit_string(v.to_owned())
55+
}
56+
57+
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
58+
where
59+
E: de::Error,
60+
{
61+
let method = parse_http_method(v).map_err(E::custom)?;
62+
Ok(MethodSet {
63+
set: HashSet::from_iter(once(method)),
64+
})
65+
}
66+
67+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
68+
where
69+
A: SeqAccess<'de>,
70+
{
71+
let mut set = HashSet::with_capacity(seq.size_hint().unwrap_or(4));
72+
while let Some(v) = seq.next_element::<String>()? {
73+
set.insert(parse_http_method(v).map_err(|err| A::Error::custom(err))?);
74+
}
75+
Ok(MethodSet { set })
76+
}
77+
}
78+
79+
deserializer.deserialize_any(MethodSetVisitor)
80+
}
81+
}
82+
83+
fn parse_http_method(mut string: String) -> Result<http::Method, http::method::InvalidMethod> {
84+
string.make_ascii_uppercase();
85+
http::Method::from_str(&string)
86+
}

0 commit comments

Comments
 (0)