From 24601228695648a73bf95217d5a2dd8b402fe61b Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 7 Mar 2025 18:26:03 +0000 Subject: [PATCH] improve url matching algorithm in case of ambiguity --- starlette/routing.py | 32 +++++++++++++++++++++++++++++--- tests/test_routing.py | 26 +++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index add7df0c2..4e4910ed6 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -724,19 +724,45 @@ async def app(self, scope: Scope, receive: Receive, send: Send) -> None: return partial = None + full_match_route = None + full_match_best_number_of_path_params = None for route in self.routes: # Determine if any route matches the incoming scope, # and hand over to the matching route if found. match, child_scope = route.matches(scope) if match == Match.FULL: - scope.update(child_scope) - await route.handle(scope, receive, send) - return + number_of_path_params = len(child_scope["path_params"]) + if number_of_path_params == 0: + scope.update(child_scope) + await route.handle(scope, receive, send) + return + # in case we have 2 routes that fully match the requested URL + # we need to check which one is a stricter match. + # Example: + # we have 2 endpoints: + # /user/{user_id} + # /user/myself + # if the requested URL is /user/myself, then both endpoints are + # a full match, however the last one is a more appropriate choice + # in case of ambiguity + if full_match_route is None or ( + full_match_best_number_of_path_params is not None + and number_of_path_params < full_match_best_number_of_path_params + ): + full_match_route = route + full_match_scope = child_scope + full_match_best_number_of_path_params = number_of_path_params + elif match == Match.PARTIAL and partial is None: partial = route partial_scope = child_scope + if full_match_route is not None: + scope.update(full_match_scope) + await full_match_route.handle(scope, receive, send) + return + if partial is not None: #  Handle partial matches. These are cases where an endpoint is # able to handle the request, but is not a preferred option. diff --git a/tests/test_routing.py b/tests/test_routing.py index 933fe7c31..b263d6e41 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -33,6 +33,16 @@ def user(request: Request) -> Response: return Response(content, media_type="text/plain") +def user_greet(request: Request) -> Response: + content = "User " + request.path_params["username"] + " greeted" + return Response(content, media_type="text/plain") + + +def user_action(request: Request) -> Response: + content = "User " + request.path_params["username"] + " action " + request.path_params["action"] + return Response(content, media_type="text/plain") + + def user_me(request: Request) -> Response: content = "User fixed me" return Response(content, media_type="text/plain") @@ -124,6 +134,8 @@ async def websocket_params(session: WebSocket) -> None: Route("/", endpoint=users), Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), + Route("/{username}/greet", endpoint=user_greet), + Route("/{username}/{action}", endpoint=user_action), Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]), Route("/nomatch", endpoint=user_no_match), ], @@ -212,9 +224,21 @@ def test_router(client: TestClient) -> None: assert response.url == "http://testserver/users/tomchristie:disable" assert response.text == "User tomchristie disabled" + response = client.get("/users/notmatch") + assert response.status_code == 200 + assert response.text == "User notmatch" + response = client.get("/users/nomatch") assert response.status_code == 200 - assert response.text == "User nomatch" + assert response.text == "User fixed no match" + + response = client.get("/users/nomatch/greet") + assert response.status_code == 200 + assert response.text == "User nomatch greeted" + + response = client.get("/users/nomatch/complimented") + assert response.status_code == 200 + assert response.text == "User nomatch action complimented" response = client.get("/static/123") assert response.status_code == 200