1
1
import typing as t
2
+ import uuid
2
3
3
- from starlette .routing import BaseRoute , Mount as StarletteMount , Route , Router
4
+ from starlette .routing import BaseRoute , Match , Mount as StarletteMount , Route , Router
4
5
from starlette .types import ASGIApp
5
6
6
7
from ellar .compatible import AttributeDict
16
17
from ellar .core .routing .route import RouteOperation
17
18
from ellar .core .routing .websocket .route import WebsocketRouteOperation
18
19
from ellar .reflect import reflect
20
+ from ellar .types import TReceive , TScope , TSend
19
21
20
22
from ..operation_definitions import OperationDefinitions
21
- from .route_collections import ModuleRouteCollection
23
+ from .route_collections import RouteCollection
22
24
23
25
if t .TYPE_CHECKING : # pragma: no cover
24
26
from ellar .core .guard import GuardCanActivate
25
27
26
-
27
28
__all__ = ["ModuleMount" , "ModuleRouter" , "controller_router_factory" ]
28
29
29
30
@@ -33,7 +34,8 @@ def controller_router_factory(
33
34
openapi = reflect .get_metadata (CONTROLLER_METADATA .OPENAPI , controller ) or dict ()
34
35
routes = reflect .get_metadata (CONTROLLER_OPERATION_HANDLER_KEY , controller ) or []
35
36
app = Router ()
36
- app .routes = ModuleRouteCollection (routes ) # type:ignore
37
+ app .routes = RouteCollection (routes ) # type:ignore
38
+
37
39
include_in_schema = reflect .get_metadata_or_raise_exception (
38
40
CONTROLLER_METADATA .INCLUDE_IN_SCHEMA , controller
39
41
)
@@ -52,7 +54,7 @@ def controller_router_factory(
52
54
CONTROLLER_METADATA .GUARDS , controller
53
55
),
54
56
include_in_schema = include_in_schema if include_in_schema is not None else True ,
55
- ** openapi
57
+ ** openapi ,
56
58
)
57
59
return router
58
60
@@ -85,6 +87,7 @@ def __init__(
85
87
] = (guards or [])
86
88
self ._version = set (version or [])
87
89
self ._build : bool = False
90
+ self ._current_found_route_key = f"{ uuid .uuid4 ().hex :4} _ModuleMountRoute"
88
91
89
92
def get_meta (self ) -> t .Mapping :
90
93
return self ._meta
@@ -115,33 +118,32 @@ def _build_route_operation(self, route: RouteOperation) -> None:
115
118
def _build_ws_route_operation (self , route : WebsocketRouteOperation ) -> None :
116
119
route .build_route_operation (path_prefix = self .path , name = self .name )
117
120
118
- def get_flatten_routes (self ) -> t .List [BaseRoute ]:
119
- if not self ._build :
120
- for route in self .routes :
121
- _route : RouteOperation = t .cast (RouteOperation , route )
121
+ def build_child_routes (self , flatten : bool = False ) -> None :
122
+ for route in self .routes :
123
+ _route : RouteOperation = t .cast (RouteOperation , route )
124
+
125
+ route_versioning = reflect .get_metadata (VERSIONING_KEY , _route .endpoint )
126
+ route_guards = reflect .get_metadata (GUARDS_KEY , _route .endpoint )
122
127
123
- route_versioning = reflect .get_metadata (VERSIONING_KEY , _route .endpoint )
124
- route_guards = reflect .get_metadata (GUARDS_KEY , _route .endpoint )
128
+ if not route_versioning :
129
+ reflect .define_metadata (
130
+ VERSIONING_KEY ,
131
+ self ._version ,
132
+ _route .endpoint ,
133
+ default_value = set (),
134
+ )
135
+ if not route_guards :
136
+ reflect .define_metadata (
137
+ GUARDS_KEY ,
138
+ self ._route_guards ,
139
+ _route .endpoint ,
140
+ default_value = [],
141
+ )
142
+ if flatten :
125
143
openapi = (
126
144
reflect .get_metadata (OPENAPI_KEY , _route .endpoint )
127
145
or AttributeDict ()
128
146
)
129
-
130
- if not route_versioning :
131
- reflect .define_metadata (
132
- VERSIONING_KEY ,
133
- self ._version ,
134
- _route .endpoint ,
135
- default_value = set (),
136
- )
137
- if not route_guards :
138
- reflect .define_metadata (
139
- GUARDS_KEY ,
140
- self ._route_guards ,
141
- _route .endpoint ,
142
- default_value = [],
143
- )
144
-
145
147
if isinstance (_route , Route ):
146
148
if not openapi .tags and self ._meta .get ("tag" ):
147
149
tags = {self ._meta .get ("tag" )}
@@ -151,13 +153,57 @@ def get_flatten_routes(self) -> t.List[BaseRoute]:
151
153
self ._build_route_operation (_route )
152
154
elif isinstance (_route , WebsocketRouteOperation ):
153
155
self ._build_ws_route_operation (_route )
156
+
157
+ def get_flatten_routes (self ) -> t .List [BaseRoute ]:
158
+ if not self ._build :
159
+ self .build_child_routes (flatten = True )
154
160
self ._build = True
155
161
return list (self .routes )
156
162
163
+ def matches (self , scope : TScope ) -> t .Tuple [Match , TScope ]:
164
+ match , _child_scope = super ().matches (scope )
165
+ if match == Match .FULL :
166
+ scope_copy = dict (scope )
167
+ scope_copy .update (_child_scope )
168
+ partial : t .Optional [RouteOperation ] = None
169
+ partial_scope = dict ()
170
+
171
+ for route in self .routes :
172
+ # Determine if any route matches the incoming scope,
173
+ # and hand over to the matching route if found.
174
+ match , child_scope = route .matches (scope_copy )
175
+ if match == Match .FULL :
176
+ _child_scope .update (child_scope )
177
+ _child_scope [self ._current_found_route_key ] = route
178
+ return Match .FULL , _child_scope
179
+ elif (
180
+ match == Match .PARTIAL
181
+ and partial is None
182
+ and isinstance (route , RouteOperation )
183
+ ):
184
+ partial = route
185
+ partial_scope = dict (_child_scope )
186
+ partial_scope .update (child_scope )
187
+
188
+ if partial :
189
+ partial_scope [self ._current_found_route_key ] = partial
190
+ return Match .PARTIAL , partial_scope
191
+
192
+ return Match .NONE , {}
193
+
194
+ async def handle (self , scope : TScope , receive : TReceive , send : TSend ) -> None :
195
+ route = t .cast (t .Optional [Route ], scope .get (self ._current_found_route_key ))
196
+ if route :
197
+ del scope [self ._current_found_route_key ]
198
+ await route .handle (scope , receive , send )
199
+ return
200
+ mount_router = t .cast (Router , self .app )
201
+ await mount_router .default (scope , receive , send )
202
+
157
203
158
204
class ModuleRouter (ModuleMount ):
159
205
operation_definition_class : t .Type [OperationDefinitions ] = OperationDefinitions
160
- routes : ModuleRouteCollection # type:ignore
206
+ routes : RouteCollection # type:ignore
161
207
162
208
def __init__ (
163
209
self ,
@@ -172,7 +218,7 @@ def __init__(
172
218
include_in_schema : bool = True ,
173
219
) -> None :
174
220
app = Router ()
175
- app .routes = ModuleRouteCollection () # type:ignore
221
+ app .routes = RouteCollection () # type:ignore
176
222
177
223
super (ModuleRouter , self ).__init__ (
178
224
path = path ,
0 commit comments