Skip to content

Commit 5b15c72

Browse files
authored
resolved pydantic extra forbid issues #117 (#128)
1 parent aa3a9ff commit 5b15c72

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

ninja_jwt/routers/obtain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ninja.router import Router
1+
from ninja_extra.router import Router
22

33
from ninja_jwt.schema_control import SchemaControl
44
from ninja_jwt.settings import api_settings

ninja_jwt/schema.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing
22
import warnings
3-
from typing import Any, Dict, Optional, Type, cast
3+
from typing import Any, Dict, Optional, Type, Union, cast
44

55
from django.conf import settings
66
from django.contrib.auth import authenticate, get_user_model
@@ -9,7 +9,9 @@
99
from django.utils.translation import gettext_lazy as _
1010
from ninja import ModelSchema, Schema
1111
from ninja.schema import DjangoGetter
12-
from pydantic import model_validator
12+
from ninja_extra import service_resolver
13+
from ninja_extra.context import RouteContext
14+
from pydantic import ConfigDict, model_validator
1315

1416
import ninja_jwt.exceptions as exceptions
1517
from ninja_jwt.utils import token_error
@@ -22,6 +24,26 @@
2224

2325
user_name_field = get_user_model().USERNAME_FIELD # type: ignore
2426

27+
SCHEMA_INPUT = Union[DjangoGetter, Dict]
28+
29+
30+
class SchemaInputService:
31+
def __init__(self, values: SCHEMA_INPUT, model_config: ConfigDict) -> None:
32+
self.model_config = model_config
33+
self.values = values
34+
35+
def get_request(self) -> HttpRequest:
36+
if self.model_config.get("extra") == "forbid":
37+
return service_resolver(RouteContext).request
38+
return self.values._context.get("request")
39+
40+
def get_values(self) -> Dict:
41+
if self.model_config.get("extra") == "forbid":
42+
return self.values
43+
if isinstance(self.values, DjangoGetter):
44+
return self.values._obj
45+
return self.values
46+
2547

2648
class AuthUserSchema(ModelSchema):
2749
class Config:
@@ -100,15 +122,16 @@ class Config:
100122
# extra = "allow"
101123
model = get_user_model()
102124
model_fields = ["password", user_name_field]
125+
extra = "forbid"
103126

104127
@model_validator(mode="before")
105-
def validate_inputs(cls, values: DjangoGetter) -> DjangoGetter:
106-
input_values = values._obj
107-
request = values._context.get("request")
128+
def validate_inputs(cls, values: SCHEMA_INPUT) -> dict:
129+
schema_input = SchemaInputService(values, cls.model_config)
130+
input_values = schema_input.get_values()
131+
request = schema_input.get_request()
132+
108133
if isinstance(input_values, dict):
109-
values._obj.update(
110-
cls.validate_values(request=request, values=input_values)
111-
)
134+
values.update(cls.validate_values(request=request, values=input_values))
112135
return values
113136
return values
114137

@@ -190,8 +213,9 @@ class TokenRefreshInputSchema(Schema, InputSchemaMixin):
190213
refresh: str
191214

192215
@model_validator(mode="before")
193-
def validate_schema(cls, values: DjangoGetter) -> dict:
194-
values = values._obj
216+
def validate_schema(cls, values: SCHEMA_INPUT) -> dict:
217+
schema_input = SchemaInputService(values, cls.model_config)
218+
values = schema_input.get_values()
195219

196220
if isinstance(values, dict):
197221
if not values.get("refresh"):
@@ -209,8 +233,9 @@ class TokenRefreshOutputSchema(Schema):
209233

210234
@model_validator(mode="before")
211235
@token_error
212-
def validate_schema(cls, values: DjangoGetter) -> typing.Any:
213-
values = values._obj
236+
def validate_schema(cls, values: SCHEMA_INPUT) -> typing.Any:
237+
schema_input = SchemaInputService(values, cls.model_config)
238+
values = schema_input.get_values()
214239

215240
if isinstance(values, dict):
216241
if not values.get("refresh"):
@@ -245,8 +270,9 @@ class TokenRefreshSlidingInputSchema(Schema, InputSchemaMixin):
245270
token: str
246271

247272
@model_validator(mode="before")
248-
def validate_schema(cls, values: DjangoGetter) -> dict:
249-
values = values._obj
273+
def validate_schema(cls, values: SCHEMA_INPUT) -> dict:
274+
schema_input = SchemaInputService(values, cls.model_config)
275+
values = schema_input.get_values()
250276

251277
if isinstance(values, dict):
252278
if not values.get("token"):
@@ -263,8 +289,9 @@ class TokenRefreshSlidingOutputSchema(Schema):
263289

264290
@model_validator(mode="before")
265291
@token_error
266-
def validate_schema(cls, values: DjangoGetter) -> dict:
267-
values = values._obj
292+
def validate_schema(cls, values: SCHEMA_INPUT) -> dict:
293+
schema_input = SchemaInputService(values, cls.model_config)
294+
values = schema_input.get_values()
268295

269296
if isinstance(values, dict):
270297
if not values.get("token"):
@@ -288,8 +315,9 @@ class TokenVerifyInputSchema(Schema, InputSchemaMixin):
288315

289316
@model_validator(mode="before")
290317
@token_error
291-
def validate_schema(cls, values: DjangoGetter) -> Dict:
292-
values = values._obj
318+
def validate_schema(cls, values: SCHEMA_INPUT) -> Dict:
319+
schema_input = SchemaInputService(values, cls.model_config)
320+
values = schema_input.get_values()
293321

294322
if isinstance(values, dict):
295323
if not values.get("token"):
@@ -320,7 +348,8 @@ class TokenBlacklistInputSchema(Schema, InputSchemaMixin):
320348
@model_validator(mode="before")
321349
@token_error
322350
def validate_schema(cls, values: DjangoGetter) -> dict:
323-
values = values._obj
351+
schema_input = SchemaInputService(values, cls.model_config)
352+
values = schema_input.get_values()
324353

325354
if isinstance(values, dict):
326355
if not values.get("refresh"):

0 commit comments

Comments
 (0)