Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions ninja/params/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ def resolve(
return cls()

data = cls._map_data_paths(data)
# Convert defaultdict to dict for pydantic 2.12+ compatibility
# In pydantic 2.12+, accessing missing keys in defaultdict creates nested
# defaultdicts which then fail validation
if isinstance(data, defaultdict):
data = dict(data)
return cls.model_validate(data, context={"request": request})

@classmethod
Expand All @@ -84,8 +79,7 @@ def _map_data_paths(cls, data: DictStrAny) -> DictStrAny:
cls._map_data_path(mapped_data, data[k], flatten_map[k])
else:
cls._map_data_path(mapped_data, None, flatten_map[k])

return mapped_data
return cls._convert_nested_defaultdicts(mapped_data) # type: ignore[no-any-return]

@classmethod
def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None:
Expand All @@ -95,6 +89,17 @@ def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None:
else:
cls._map_data_path(data[path[0]], value, path[1:])

@classmethod
def _convert_nested_defaultdicts(cls, value: Any) -> Any:
if isinstance(value, (defaultdict, dict)):
return {
key: cls._convert_nested_defaultdicts(item)
for key, item in value.items()
}
if isinstance(value, list):
return [cls._convert_nested_defaultdicts(item) for item in value]
return value


class QueryModel(ParamModel):
@classmethod
Expand Down
132 changes: 78 additions & 54 deletions tests/test_query_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime
from enum import IntEnum

from pydantic import Field
from pydantic import BaseModel, Field

from ninja import NinjaAPI, Query, Schema
from ninja.testing.client import TestClient


class Range(IntEnum):
Expand All @@ -12,7 +13,7 @@ class Range(IntEnum):
TWO_HUNDRED = 200


class Filter(Schema):
class Filter(BaseModel):
to_datetime: datetime = Field(alias="to")
from_datetime: datetime = Field(alias="from")
range: Range = Range.TWENTY
Expand All @@ -28,7 +29,7 @@ class Data(Schema):

@api.get("/test")
def query_params_schema(request, filters: Filter = Query(...)):
return filters.dict()
return filters.model_dump()


@api.get("/test-mixed")
Expand All @@ -39,57 +40,80 @@ def query_params_mixed_schema(
filters: Filter = Query(...),
data: Data = Query(...),
):
return dict(query1=query1, query2=query2, filters=filters.dict(), data=data.dict())


# def test_request():
# client = TestClient(api)
# response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
# print("!", response.json())
# assert response.json() == {
# "to_datetime": "1970-01-01T00:00:02Z",
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# }

# response = client.get("/test?from=1&to=2&range=21")
# assert response.status_code == 422


# def test_request_mixed():
# client = TestClient(api)
# response = client.get(
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
# )
# print(response.json())
# assert response.json() == {
# "data": {"a_float": 1.6, "an_int": 3},
# "filters": {
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# "to_datetime": "1970-01-01T00:00:02Z",
# },
# "query1": 2,
# "query2": 5,
# }

# response = client.get(
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
# )
# print(response.json())
# assert response.json() == {
# "data": {"a_float": 1.5, "an_int": 0},
# "filters": {
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# "to_datetime": "1970-01-01T00:00:02Z",
# },
# "query1": 2,
# "query2": 10,
# }

# response = client.get("/test-mixed?from=1&to=2")
# assert response.status_code == 422
return dict(
query1=query1,
query2=query2,
filters=filters.model_dump(),
data=data.model_dump(),
)


def test_request():
client = TestClient(api)
response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
print("!", response.json())
assert response.json() == {
"to_datetime": "1970-01-01T00:00:02Z",
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
}

response = client.get("/test?from=1&to=2&range=21")
assert response.status_code == 422


def test_request_mixed():
client = TestClient(api)
response = client.get(
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
)
print(response.json())
assert response.json() == {
"data": {"a_float": 1.6, "an_int": 3},
"filters": {
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
"to_datetime": "1970-01-01T00:00:02Z",
},
"query1": 2,
"query2": 5,
}

response = client.get(
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
)
print(response.json())
assert response.json() == {
"data": {"a_float": 1.5, "an_int": 0},
"filters": {
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
"to_datetime": "1970-01-01T00:00:02Z",
},
"query1": 2,
"query2": 10,
}

response = client.get("/test-mixed?from=1&to=2")
assert response.status_code == 422


def test_request_query_params_using_basemodel():
class Foo(BaseModel):
start: int
optional: int = 42

temp_api = NinjaAPI()

@temp_api.get("/foo")
def view(request, foo: Foo = Query(...)):
return foo.model_dump()

client = TestClient(temp_api)
resp = client.get("/foo?start=1")

assert resp.status_code == 200
assert resp.json() == {"start": 1, "optional": 42}


def test_schema():
Expand Down