diff --git a/robyn/__init__.py b/robyn/__init__.py index fbedb24b..38effeaa 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -131,9 +131,6 @@ def add_route( list_openapi_tags: List[str] = openapi_tags if openapi_tags else [] - if auth_required: - self.middleware_router.add_auth_middleware(endpoint)(handler) - if isinstance(route_type, str): http_methods = { "GET": HttpMethod.GET, @@ -146,6 +143,9 @@ def add_route( } route_type = http_methods[route_type] + if auth_required: + self.middleware_router.add_auth_middleware(endpoint, route_type)(handler) + add_route_response = self.router.add_route( route_type=route_type, endpoint=endpoint, diff --git a/robyn/processpool.py b/robyn/processpool.py index 7669ec85..ac9aa888 100644 --- a/robyn/processpool.py +++ b/robyn/processpool.py @@ -188,8 +188,8 @@ def spawn_process( for middleware_type, middleware_function in global_middlewares: server.add_global_middleware(middleware_type, middleware_function) - for http_route_type, endpoint, function in route_middlewares: - server.add_middleware_route(http_route_type, endpoint, function) + for middleware_type, endpoint, function, route_type in route_middlewares: + server.add_middleware_route(middleware_type, endpoint, function, route_type) if Events.STARTUP in event_handlers: server.add_startup_handler(event_handlers[Events.STARTUP]) diff --git a/robyn/robyn.pyi b/robyn/robyn.pyi index ca0866ce..f72c9c9b 100644 --- a/robyn/robyn.pyi +++ b/robyn/robyn.pyi @@ -345,6 +345,7 @@ class Server: middleware_type: MiddlewareType, route: str, function: FunctionInfo, + route_type: HttpMethod, ) -> None: pass def add_startup_handler(self, function: FunctionInfo) -> None: diff --git a/robyn/router.py b/robyn/router.py index cb6f57f9..6fd161f9 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -37,6 +37,7 @@ class RouteMiddleware(NamedTuple): middleware_type: MiddlewareType route: str function: FunctionInfo + route_type: HttpMethod class GlobalMiddleware(NamedTuple): @@ -276,6 +277,7 @@ def add_route( # type: ignore self, middleware_type: MiddlewareType, endpoint: str, + route_type: HttpMethod, handler: Callable, injected_dependencies: dict, ) -> Callable: @@ -295,10 +297,10 @@ def add_route( # type: ignore params, new_injected_dependencies, ) - self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function)) + self.route_middlewares.append(RouteMiddleware(middleware_type, endpoint, function, route_type)) return handler - def add_auth_middleware(self, endpoint: str): + def add_auth_middleware(self, endpoint: str, route_type: HttpMethod): """ This method adds an authentication middleware to the specified endpoint. """ @@ -320,6 +322,7 @@ def inner_handler(request: Request, *args): self.add_route( MiddlewareType.BEFORE_REQUEST, endpoint, + route_type, inner_handler, injected_dependencies, ) @@ -348,11 +351,12 @@ def inner_handler(*args, **kwargs): self.add_route( middleware_type, endpoint, + HttpMethod.GET, async_inner_handler, injected_dependencies, ) else: - self.add_route(middleware_type, endpoint, inner_handler, injected_dependencies) + self.add_route(middleware_type, endpoint, HttpMethod.GET, inner_handler, injected_dependencies) else: params = dict(inspect.signature(handler).parameters) diff --git a/src/server.rs b/src/server.rs index 5e4af0bd..309ddda1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -394,14 +394,29 @@ impl Server { middleware_type: &MiddlewareType, route: &str, function: FunctionInfo, + http_method: HttpMethod, ) { + let mut endpoint_prefixed_with_method = http_method.to_string(); + + if !route.starts_with('/') { + endpoint_prefixed_with_method.push('/'); + } + + endpoint_prefixed_with_method.push_str(route); + debug!( "MiddleWare Route added for {:?} {} ", - middleware_type, route + middleware_type, &endpoint_prefixed_with_method ); Python::with_gil(|py| { self.middleware_router - .add_route(py, middleware_type, route, function, None) + .add_route( + py, + middleware_type, + &endpoint_prefixed_with_method, + function, + None, + ) .unwrap() }); } @@ -453,13 +468,15 @@ async fn index( ) -> impl Responder { let mut request = Request::from_actix_request(&req, payload, &global_request_headers).await; + let route = format!("{}{}", req.method(), req.uri().path()); + // Before middleware // Global let mut before_middlewares = middleware_router.get_global_middlewares(&MiddlewareType::BeforeRequest); // Route specific if let Some((function, route_params)) = - middleware_router.get_route(&MiddlewareType::BeforeRequest, req.uri().path()) + middleware_router.get_route(&MiddlewareType::BeforeRequest, &route) { before_middlewares.push(function); request.path_params = route_params; @@ -529,8 +546,7 @@ async fn index( let mut after_middlewares = middleware_router.get_global_middlewares(&MiddlewareType::AfterRequest); // Route specific - if let Some((function, _)) = - middleware_router.get_route(&MiddlewareType::AfterRequest, req.uri().path()) + if let Some((function, _)) = middleware_router.get_route(&MiddlewareType::AfterRequest, &route) { after_middlewares.push(function); } diff --git a/src/types/mod.rs b/src/types/mod.rs index 73a4bc2d..c42be5d2 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -50,6 +50,13 @@ impl HttpMethod { } } +// for: https://stackoverflow.com/a/32712140/9652621 +impl std::fmt::Display for HttpMethod { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + #[pyclass] #[derive(Default, Debug, Clone)] pub struct Url {