Skip to content

Commit 0cdf953

Browse files
committed
rollback to multiple router instead of resolving all to single router
1 parent 7ba4f99 commit 0cdf953

File tree

7 files changed

+124
-98
lines changed

7 files changed

+124
-98
lines changed

ellar/core/routing/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_methods(cls, methods: t.Optional[t.List[str]] = None) -> t.Set[str]:
8181

8282
def matches(self, scope: TScope) -> t.Tuple[Match, TScope]:
8383
match = super().matches(scope) # type: ignore
84-
if match[0] is not Match.NONE:
84+
if match[0] is Match.FULL:
8585
version_scheme_resolver: "BaseAPIVersioningResolver" = t.cast(
8686
"BaseAPIVersioningResolver", scope[SCOPE_API_VERSIONING_RESOLVER]
8787
)

ellar/core/routing/operation_definitions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,9 @@ def trace(
256256

257257
def http_route(
258258
self,
259-
path: str,
260-
methods: t.List[str],
259+
path: str = "/",
261260
*,
261+
methods: t.List[str],
262262
name: str = None,
263263
include_in_schema: bool = True,
264264
response: t.Union[

ellar/core/routing/route.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def build_route_operation( # type:ignore
9797
self.endpoint_parameter_model.build_model()
9898
self.include_in_schema = include_in_schema
9999
if name:
100+
# necessary for route flattening
100101
self.name = f"{name}:{self.name}"
101102

102103
def _load_model(self) -> None:

ellar/core/routing/router/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from .app import ApplicationRouter
22
from .module import ModuleMount, ModuleRouter
3-
from .route_collections import ModuleRouteCollection, RouteCollection
3+
from .route_collections import RouteCollection
44

55
__all__ = [
66
"ApplicationRouter",
77
"RouteCollection",
88
"ModuleRouter",
99
"ModuleMount",
10-
"ModuleRouteCollection",
1110
]

ellar/core/routing/router/module.py

Lines changed: 75 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as t
2+
import uuid
23

3-
from starlette.routing import BaseRoute, Mount as StarletteMount, Route, Router
4+
from starlette.routing import BaseRoute, Match, Mount as StarletteMount, Route, Router
45
from starlette.types import ASGIApp
56

67
from ellar.compatible import AttributeDict
@@ -16,14 +17,14 @@
1617
from ellar.core.routing.route import RouteOperation
1718
from ellar.core.routing.websocket.route import WebsocketRouteOperation
1819
from ellar.reflect import reflect
20+
from ellar.types import TReceive, TScope, TSend
1921

2022
from ..operation_definitions import OperationDefinitions
21-
from .route_collections import ModuleRouteCollection
23+
from .route_collections import RouteCollection
2224

2325
if t.TYPE_CHECKING: # pragma: no cover
2426
from ellar.core.guard import GuardCanActivate
2527

26-
2728
__all__ = ["ModuleMount", "ModuleRouter", "controller_router_factory"]
2829

2930

@@ -33,7 +34,8 @@ def controller_router_factory(
3334
openapi = reflect.get_metadata(CONTROLLER_METADATA.OPENAPI, controller) or dict()
3435
routes = reflect.get_metadata(CONTROLLER_OPERATION_HANDLER_KEY, controller) or []
3536
app = Router()
36-
app.routes = ModuleRouteCollection(routes) # type:ignore
37+
app.routes = RouteCollection(routes) # type:ignore
38+
3739
include_in_schema = reflect.get_metadata_or_raise_exception(
3840
CONTROLLER_METADATA.INCLUDE_IN_SCHEMA, controller
3941
)
@@ -52,7 +54,7 @@ def controller_router_factory(
5254
CONTROLLER_METADATA.GUARDS, controller
5355
),
5456
include_in_schema=include_in_schema if include_in_schema is not None else True,
55-
**openapi
57+
**openapi,
5658
)
5759
return router
5860

@@ -85,6 +87,7 @@ def __init__(
8587
] = (guards or [])
8688
self._version = set(version or [])
8789
self._build: bool = False
90+
self._current_found_route_key = f"{uuid.uuid4().hex:4}_ModuleMountRoute"
8891

8992
def get_meta(self) -> t.Mapping:
9093
return self._meta
@@ -115,33 +118,32 @@ def _build_route_operation(self, route: RouteOperation) -> None:
115118
def _build_ws_route_operation(self, route: WebsocketRouteOperation) -> None:
116119
route.build_route_operation(path_prefix=self.path, name=self.name)
117120

118-
def get_flatten_routes(self) -> t.List[BaseRoute]:
119-
if not self._build:
120-
for route in self.routes:
121-
_route: RouteOperation = t.cast(RouteOperation, route)
121+
def build_child_routes(self, flatten: bool = False) -> None:
122+
for route in self.routes:
123+
_route: RouteOperation = t.cast(RouteOperation, route)
124+
125+
route_versioning = reflect.get_metadata(VERSIONING_KEY, _route.endpoint)
126+
route_guards = reflect.get_metadata(GUARDS_KEY, _route.endpoint)
122127

123-
route_versioning = reflect.get_metadata(VERSIONING_KEY, _route.endpoint)
124-
route_guards = reflect.get_metadata(GUARDS_KEY, _route.endpoint)
128+
if not route_versioning:
129+
reflect.define_metadata(
130+
VERSIONING_KEY,
131+
self._version,
132+
_route.endpoint,
133+
default_value=set(),
134+
)
135+
if not route_guards:
136+
reflect.define_metadata(
137+
GUARDS_KEY,
138+
self._route_guards,
139+
_route.endpoint,
140+
default_value=[],
141+
)
142+
if flatten:
125143
openapi = (
126144
reflect.get_metadata(OPENAPI_KEY, _route.endpoint)
127145
or AttributeDict()
128146
)
129-
130-
if not route_versioning:
131-
reflect.define_metadata(
132-
VERSIONING_KEY,
133-
self._version,
134-
_route.endpoint,
135-
default_value=set(),
136-
)
137-
if not route_guards:
138-
reflect.define_metadata(
139-
GUARDS_KEY,
140-
self._route_guards,
141-
_route.endpoint,
142-
default_value=[],
143-
)
144-
145147
if isinstance(_route, Route):
146148
if not openapi.tags and self._meta.get("tag"):
147149
tags = {self._meta.get("tag")}
@@ -151,13 +153,57 @@ def get_flatten_routes(self) -> t.List[BaseRoute]:
151153
self._build_route_operation(_route)
152154
elif isinstance(_route, WebsocketRouteOperation):
153155
self._build_ws_route_operation(_route)
156+
157+
def get_flatten_routes(self) -> t.List[BaseRoute]:
158+
if not self._build:
159+
self.build_child_routes(flatten=True)
154160
self._build = True
155161
return list(self.routes)
156162

163+
def matches(self, scope: TScope) -> t.Tuple[Match, TScope]:
164+
match, _child_scope = super().matches(scope)
165+
if match == Match.FULL:
166+
scope_copy = dict(scope)
167+
scope_copy.update(_child_scope)
168+
partial: t.Optional[RouteOperation] = None
169+
partial_scope = dict()
170+
171+
for route in self.routes:
172+
# Determine if any route matches the incoming scope,
173+
# and hand over to the matching route if found.
174+
match, child_scope = route.matches(scope_copy)
175+
if match == Match.FULL:
176+
_child_scope.update(child_scope)
177+
_child_scope[self._current_found_route_key] = route
178+
return Match.FULL, _child_scope
179+
elif (
180+
match == Match.PARTIAL
181+
and partial is None
182+
and isinstance(route, RouteOperation)
183+
):
184+
partial = route
185+
partial_scope = dict(_child_scope)
186+
partial_scope.update(child_scope)
187+
188+
if partial:
189+
partial_scope[self._current_found_route_key] = partial
190+
return Match.PARTIAL, partial_scope
191+
192+
return Match.NONE, {}
193+
194+
async def handle(self, scope: TScope, receive: TReceive, send: TSend) -> None:
195+
route = t.cast(t.Optional[Route], scope.get(self._current_found_route_key))
196+
if route:
197+
del scope[self._current_found_route_key]
198+
await route.handle(scope, receive, send)
199+
return
200+
mount_router = t.cast(Router, self.app)
201+
await mount_router.default(scope, receive, send)
202+
157203

158204
class ModuleRouter(ModuleMount):
159205
operation_definition_class: t.Type[OperationDefinitions] = OperationDefinitions
160-
routes: ModuleRouteCollection # type:ignore
206+
routes: RouteCollection # type:ignore
161207

162208
def __init__(
163209
self,
@@ -172,7 +218,7 @@ def __init__(
172218
include_in_schema: bool = True,
173219
) -> None:
174220
app = Router()
175-
app.routes = ModuleRouteCollection() # type:ignore
221+
app.routes = RouteCollection() # type:ignore
176222

177223
super(ModuleRouter, self).__init__(
178224
path=path,

ellar/core/routing/router/route_collections.py

Lines changed: 35 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,48 @@
88
from ellar.helper import generate_controller_operation_unique_id
99

1010

11-
class ModuleRouteCollection(t.Sequence[BaseRoute]):
12-
__slots__ = ("_routes",)
11+
class RouteCollection(t.Sequence[BaseRoute]):
12+
__slots__ = ("_routes", "_served_routes")
1313

1414
def __init__(self, routes: t.Optional[t.Sequence[BaseRoute]] = None) -> None:
1515
self._routes: t.Dict[int, BaseRoute] = OrderedDict()
1616
self.extend([] if routes is None else list(routes))
17+
self._served_routes: t.List[BaseRoute] = []
18+
self.sort_routes()
1719

1820
@t.no_type_check
1921
def __getitem__(self, i: int) -> BaseRoute:
20-
return list(self._routes.values()).__getitem__(i)
22+
return self._served_routes.__getitem__(i)
23+
24+
def __setitem__(self, i: int, o: BaseRoute) -> None:
25+
self._add_operation(o)
26+
self.sort_routes()
27+
28+
def __len__(self) -> int:
29+
return len(self._routes)
30+
31+
def __iter__(self) -> t.Iterator[BaseRoute]:
32+
return iter(self._served_routes)
33+
34+
def append(self, __item: t.Any) -> None:
35+
self._add_operation(__item)
36+
self.sort_routes()
37+
38+
def get_routes(self) -> t.List[BaseRoute]:
39+
return self._served_routes.copy()
40+
41+
def extend(self, routes: t.Sequence[BaseRoute]) -> "RouteCollection":
42+
for item in routes:
43+
self._add_operation(item)
44+
self.sort_routes()
45+
return self
46+
47+
def sort_routes(self) -> None:
48+
# TODO: flatten the routes for faster look up
49+
self._served_routes = list(self._routes.values())
50+
self._served_routes.sort(
51+
key=lambda e: getattr(e, "path", getattr(e, "host", ""))
52+
)
2153

2254
def _add_operation(
2355
self, operation: t.Union[RouteOperation, WebsocketRouteOperation, BaseRoute]
@@ -50,63 +82,3 @@ def _add_operation(
5082
# TODO
5183
"""TODO: log warning to user when operations with the same route is found"""
5284
self._routes[_hash] = operation
53-
54-
def __setitem__(self, i: int, o: BaseRoute) -> None:
55-
self._add_operation(o)
56-
57-
def __len__(self) -> int:
58-
return len(self._routes)
59-
60-
def __iter__(self) -> t.Iterator[BaseRoute]:
61-
return iter(self._routes.values())
62-
63-
def append(self, __item: t.Any) -> None:
64-
self._add_operation(__item)
65-
66-
def extend(self, routes: t.Sequence[BaseRoute]) -> "ModuleRouteCollection":
67-
for item in routes:
68-
self.append(item)
69-
return self
70-
71-
72-
class RouteCollection(ModuleRouteCollection):
73-
__slots__ = ("_routes", "_served_routes")
74-
75-
def __init__(self, routes: t.Optional[t.Sequence[BaseRoute]] = None) -> None:
76-
super().__init__(routes)
77-
self._served_routes: t.List[BaseRoute] = []
78-
self.sort_routes()
79-
80-
@t.no_type_check
81-
def __getitem__(self, i: int) -> BaseRoute:
82-
return self._served_routes.__getitem__(i)
83-
84-
def __setitem__(self, i: int, o: BaseRoute) -> None:
85-
super().__setitem__(i, o)
86-
self.sort_routes()
87-
88-
def __len__(self) -> int:
89-
return len(self._routes)
90-
91-
def __iter__(self) -> t.Iterator[BaseRoute]:
92-
return iter(self._served_routes)
93-
94-
def append(self, __item: t.Any) -> None:
95-
super().append(__item)
96-
self.sort_routes()
97-
98-
def get_routes(self) -> t.List[BaseRoute]:
99-
return self._served_routes.copy()
100-
101-
def extend(self, routes: t.Sequence[BaseRoute]) -> "RouteCollection":
102-
for item in routes:
103-
self._add_operation(item)
104-
self.sort_routes()
105-
return self
106-
107-
def sort_routes(self) -> None:
108-
# TODO: flatten the routes for faster look up
109-
self._served_routes = list(self._routes.values())
110-
self._served_routes.sort(
111-
key=lambda e: getattr(e, "path", getattr(e, "host", ""))
112-
)

ellar/core/routing/websocket/route.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ellar.core.context import ExecutionContext
1313
from ellar.core.params import WebsocketEndpointArgsModel
1414
from ellar.exceptions import ImproperConfiguration, WebSocketRequestValidationError
15+
from ellar.helper import get_name
1516
from ellar.reflect import reflect
1617

1718
from ..base import WebsocketRouteOperationBase
@@ -65,6 +66,7 @@ def __init__(
6566
extra_handler_type: t.Optional[t.Type[WebSocketExtraHandler]] = None,
6667
**handlers_kwargs: t.Any,
6768
) -> None:
69+
assert path.startswith("/"), "Routed paths must start with '/'"
6870
self._handlers_kwargs: t.Dict[str, t.Any] = dict(
6971
encoding=encoding,
7072
on_receive=None,
@@ -77,7 +79,13 @@ def __init__(
7779
t.Type[WebSocketExtraHandler]
7880
] = extra_handler_type
7981

80-
super().__init__(path=path, endpoint=endpoint, name=name)
82+
self.path = path
83+
self.path_regex, self.path_format, self.param_convertors = compile_path(
84+
self.path
85+
)
86+
self.endpoint = endpoint # type: ignore
87+
self.name = get_name(endpoint) if name is None else name
88+
8189
self.endpoint_parameter_model: WebsocketEndpointArgsModel = NOT_SET
8290

8391
reflect.define_metadata(CONTROLLER_OPERATION_HANDLER_KEY, self, self.endpoint)

0 commit comments

Comments
 (0)