Skip to content

Commit b1f3ea5

Browse files
authored
serialization/validation support django-enumfields(2) (#5)
1 parent f84e9b4 commit b1f3ea5

File tree

5 files changed

+84
-5
lines changed

5 files changed

+84
-5
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ This can be mitigated by:
154154
to ensure relationships are fully loaded
155155

156156

157+
## Serialization / validation of 3rd party field types
158+
159+
Additionally, the following 3rd party fields / types are supported if the
160+
`DjangoModelPlugin` is installed:
161+
162+
- `django-enumfields`
163+
- `django-enumfields2`
164+
157165

158166
## Contributing
159167

litestar_django/plugin.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import contextlib
12
from typing import Any
23

34
from django.db import models # type: ignore[import-untyped]
4-
from litestar.plugins.base import SerializationPlugin
5+
from litestar.config.app import AppConfig
6+
from litestar.plugins.base import SerializationPlugin, InitPlugin
57
from litestar.typing import FieldDefinition
68

79
from litestar_django.dto import DjangoModelDTO
810

911

10-
class DjangoModelPlugin(SerializationPlugin):
12+
class DjangoModelPlugin(InitPlugin, SerializationPlugin):
1113
def __init__(self) -> None:
1214
self._type_dto_map: dict[type[models.Model], type[DjangoModelDTO[Any]]] = {}
1315

@@ -35,3 +37,29 @@ def create_dto_for_type(
3537
self._type_dto_map[annotation] = dto_type = DjangoModelDTO[annotation] # type:ignore[valid-type]
3638

3739
return dto_type
40+
41+
def on_app_init(self, app_config: AppConfig) -> AppConfig:
42+
type_encoders = (
43+
dict(app_config.type_encoders) if app_config.type_encoders else {}
44+
)
45+
type_decoders = (
46+
list(app_config.type_decoders) if app_config.type_decoders else []
47+
)
48+
with contextlib.suppress(ImportError):
49+
import enumfields # type: ignore[import-untyped]
50+
51+
def _is_enumfields_enum(v: Any) -> bool:
52+
return issubclass(v, (enumfields.Enum, enumfields.IntEnum))
53+
54+
def _decode_enumfields_enum(type_: Any, value: Any) -> Any:
55+
return type_(value)
56+
57+
type_encoders[enumfields.Enum] = lambda v: v.value
58+
type_encoders[enumfields.IntEnum] = lambda v: v.value
59+
60+
type_decoders.append((_is_enumfields_enum, _decode_enumfields_enum))
61+
62+
app_config.type_encoders = type_encoders
63+
app_config.type_decoders = type_decoders
64+
65+
return app_config

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "litestar-django"
3-
version = "0.1.1"
3+
version = "0.2.0"
44
description = "Django model support for Litestar"
55
readme = "README.md"
66
license = { text = "MIT" }

tests/some_app/app/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ class ModelInvalidRegexValidator(models.Model):
131131
)
132132

133133

134+
class ModelWithCustomFields(models.Model):
135+
enum_field = enumfields.EnumField(StdEnum)
136+
enumfields_enum = enumfields.EnumField(LabelledEnum)
137+
138+
134139
class Author(models.Model):
135140
name = models.CharField(max_length=100)
136141

tests/test_dto_integration.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import secrets
2-
from typing import Annotated
2+
from typing import Annotated, Any
33

44
import pytest
55
from litestar import get, Response, Litestar, post
@@ -8,7 +8,13 @@
88

99
from litestar_django.dto import DjangoModelDTO
1010
from litestar_django.plugin import DjangoModelPlugin
11-
from tests.some_app.app.models import Author, Book, Genre, ModelWithFields
11+
from tests.some_app.app.models import (
12+
Author,
13+
Book,
14+
Genre,
15+
ModelWithFields,
16+
ModelWithCustomFields,
17+
)
1218

1319

1420
@pytest.mark.django_db(transaction=True)
@@ -156,6 +162,38 @@ def handler(data: Author) -> Author:
156162
assert res.json() == {"id": author.id, "name": author_name, "books": []}
157163

158164

165+
@pytest.mark.django_db(transaction=True)
166+
def test_enumfields() -> None:
167+
@post(
168+
"/",
169+
dto=DjangoModelDTO[Annotated[ModelWithCustomFields, DTOConfig(exclude={"id"})]],
170+
)
171+
async def post_handler(data: ModelWithCustomFields) -> dict[str, Any]:
172+
await data.asave()
173+
return {
174+
"id": data.id,
175+
"enum_field": data.enum_field.value,
176+
"enumfields_enum": data.enumfields_enum.value,
177+
}
178+
179+
@get("/{obj_id:int}", dto=DjangoModelDTO[ModelWithCustomFields])
180+
async def get_handler(obj_id: int) -> ModelWithCustomFields:
181+
return await ModelWithCustomFields.objects.aget(id=obj_id)
182+
183+
with create_test_client(
184+
[get_handler, post_handler], plugins=[DjangoModelPlugin()]
185+
) as client:
186+
res = client.post("/", json={"enum_field": "ONE", "enumfields_enum": "TWO"})
187+
assert res.status_code == 201
188+
data = res.json()
189+
model_id = data.pop("id")
190+
assert data == {"enum_field": "ONE", "enumfields_enum": "TWO"}
191+
192+
res = client.get(f"/{model_id}")
193+
assert res.status_code == 200
194+
assert res.json() == {"id": model_id, **data}
195+
196+
159197
def test_schema() -> None:
160198
@get("/", dto=DjangoModelDTO[ModelWithFields], sync_to_thread=True)
161199
def handler() -> ModelWithFields:

0 commit comments

Comments
 (0)