1
1
import typing
2
2
import warnings
3
- from typing import Any , Dict , Optional , Type , cast
3
+ from typing import Any , Dict , Optional , Type , Union , cast
4
4
5
5
from django .conf import settings
6
6
from django .contrib .auth import authenticate , get_user_model
9
9
from django .utils .translation import gettext_lazy as _
10
10
from ninja import ModelSchema , Schema
11
11
from ninja .schema import DjangoGetter
12
- from pydantic import model_validator
12
+ from ninja_extra import service_resolver
13
+ from ninja_extra .context import RouteContext
14
+ from pydantic import ConfigDict , model_validator
13
15
14
16
import ninja_jwt .exceptions as exceptions
15
17
from ninja_jwt .utils import token_error
22
24
23
25
user_name_field = get_user_model ().USERNAME_FIELD # type: ignore
24
26
27
+ SCHEMA_INPUT = Union [DjangoGetter , Dict ]
28
+
29
+
30
+ class SchemaInputService :
31
+ def __init__ (self , values : SCHEMA_INPUT , model_config : ConfigDict ) -> None :
32
+ self .model_config = model_config
33
+ self .values = values
34
+
35
+ def get_request (self ) -> HttpRequest :
36
+ if self .model_config .get ("extra" ) == "forbid" :
37
+ return service_resolver (RouteContext ).request
38
+ return self .values ._context .get ("request" )
39
+
40
+ def get_values (self ) -> Dict :
41
+ if self .model_config .get ("extra" ) == "forbid" :
42
+ return self .values
43
+ if isinstance (self .values , DjangoGetter ):
44
+ return self .values ._obj
45
+ return self .values
46
+
25
47
26
48
class AuthUserSchema (ModelSchema ):
27
49
class Config :
@@ -100,15 +122,16 @@ class Config:
100
122
# extra = "allow"
101
123
model = get_user_model ()
102
124
model_fields = ["password" , user_name_field ]
125
+ extra = "forbid"
103
126
104
127
@model_validator (mode = "before" )
105
- def validate_inputs (cls , values : DjangoGetter ) -> DjangoGetter :
106
- input_values = values ._obj
107
- request = values ._context .get ("request" )
128
+ def validate_inputs (cls , values : SCHEMA_INPUT ) -> dict :
129
+ schema_input = SchemaInputService (values , cls .model_config )
130
+ input_values = schema_input .get_values ()
131
+ request = schema_input .get_request ()
132
+
108
133
if isinstance (input_values , dict ):
109
- values ._obj .update (
110
- cls .validate_values (request = request , values = input_values )
111
- )
134
+ values .update (cls .validate_values (request = request , values = input_values ))
112
135
return values
113
136
return values
114
137
@@ -190,8 +213,9 @@ class TokenRefreshInputSchema(Schema, InputSchemaMixin):
190
213
refresh : str
191
214
192
215
@model_validator (mode = "before" )
193
- def validate_schema (cls , values : DjangoGetter ) -> dict :
194
- values = values ._obj
216
+ def validate_schema (cls , values : SCHEMA_INPUT ) -> dict :
217
+ schema_input = SchemaInputService (values , cls .model_config )
218
+ values = schema_input .get_values ()
195
219
196
220
if isinstance (values , dict ):
197
221
if not values .get ("refresh" ):
@@ -209,8 +233,9 @@ class TokenRefreshOutputSchema(Schema):
209
233
210
234
@model_validator (mode = "before" )
211
235
@token_error
212
- def validate_schema (cls , values : DjangoGetter ) -> typing .Any :
213
- values = values ._obj
236
+ def validate_schema (cls , values : SCHEMA_INPUT ) -> typing .Any :
237
+ schema_input = SchemaInputService (values , cls .model_config )
238
+ values = schema_input .get_values ()
214
239
215
240
if isinstance (values , dict ):
216
241
if not values .get ("refresh" ):
@@ -245,8 +270,9 @@ class TokenRefreshSlidingInputSchema(Schema, InputSchemaMixin):
245
270
token : str
246
271
247
272
@model_validator (mode = "before" )
248
- def validate_schema (cls , values : DjangoGetter ) -> dict :
249
- values = values ._obj
273
+ def validate_schema (cls , values : SCHEMA_INPUT ) -> dict :
274
+ schema_input = SchemaInputService (values , cls .model_config )
275
+ values = schema_input .get_values ()
250
276
251
277
if isinstance (values , dict ):
252
278
if not values .get ("token" ):
@@ -263,8 +289,9 @@ class TokenRefreshSlidingOutputSchema(Schema):
263
289
264
290
@model_validator (mode = "before" )
265
291
@token_error
266
- def validate_schema (cls , values : DjangoGetter ) -> dict :
267
- values = values ._obj
292
+ def validate_schema (cls , values : SCHEMA_INPUT ) -> dict :
293
+ schema_input = SchemaInputService (values , cls .model_config )
294
+ values = schema_input .get_values ()
268
295
269
296
if isinstance (values , dict ):
270
297
if not values .get ("token" ):
@@ -288,8 +315,9 @@ class TokenVerifyInputSchema(Schema, InputSchemaMixin):
288
315
289
316
@model_validator (mode = "before" )
290
317
@token_error
291
- def validate_schema (cls , values : DjangoGetter ) -> Dict :
292
- values = values ._obj
318
+ def validate_schema (cls , values : SCHEMA_INPUT ) -> Dict :
319
+ schema_input = SchemaInputService (values , cls .model_config )
320
+ values = schema_input .get_values ()
293
321
294
322
if isinstance (values , dict ):
295
323
if not values .get ("token" ):
@@ -320,7 +348,8 @@ class TokenBlacklistInputSchema(Schema, InputSchemaMixin):
320
348
@model_validator (mode = "before" )
321
349
@token_error
322
350
def validate_schema (cls , values : DjangoGetter ) -> dict :
323
- values = values ._obj
351
+ schema_input = SchemaInputService (values , cls .model_config )
352
+ values = schema_input .get_values ()
324
353
325
354
if isinstance (values , dict ):
326
355
if not values .get ("refresh" ):
0 commit comments