Skip to content

Commit f6b8a8e

Browse files
authored
Add Annotated support (#257)
* Add Annotated support and therefore set minimum python version as 3.9 * Add support back for Python3.8 * Fix Python 3.8 and add tox config for cross version testing * Fix style issues. Rebase did not trigger pre-commit. * Remove enum from pre-commit * re-add but deprecate the newtype documentation * Move Annotated and Union handling to their own functions * Remove tox from requirements * Add coverage report and remove virtualenv-pyenv * Add warning when multiple Field annotations have bene detected * Remove tox and document Annotated for python 3.8 * fix: line-endings
1 parent 5504913 commit f6b8a8e

File tree

9 files changed

+187
-35
lines changed

9 files changed

+187
-35
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ venv.bak/
9595
# Rope project settings
9696
.ropeproject
9797

98+
# VSCode project settings
99+
.vscode
100+
98101
# mkdocs documentation
99102
/site
100103

.pre-commit-config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ repos:
33
rev: v3.3.1
44
hooks:
55
- id: pyupgrade
6-
args: ["--py36-plus"]
6+
# I've kept it on py3.7 so that it doesn't replace `Dict` with `dict`
7+
args: ["--py37-plus"]
78
- repo: https://github.com/python/black
89
rev: 23.1.0
910
hooks:
@@ -19,7 +20,7 @@ repos:
1920
rev: v1.1.1
2021
hooks:
2122
- id: mypy
22-
additional_dependencies: [marshmallow-enum,typeguard,marshmallow]
23+
additional_dependencies: [typeguard,marshmallow]
2324
args: [--show-error-codes]
2425
- repo: https://github.com/asottile/blacken-docs
2526
rev: 1.13.0

CONTRIBUTING.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ Every commit is checked with pre-commit hooks for :
2121
- type safety with [mypy](http://mypy-lang.org/)
2222
- test conformance by running [tests](./tests) with [pytest](https://docs.pytest.org/en/latest/)
2323
- You can run `pytest` from the command line.
24+

README.md

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,47 @@ class Sample:
242242

243243
See [marshmallow's documentation about extending `Schema`](https://marshmallow.readthedocs.io/en/stable/extending.html).
244244

245-
### Custom NewType declarations
245+
### Custom type aliases
246+
247+
This library allows you to specify [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class) using python's Annoted type [PEP-593](https://peps.python.org/pep-0593/).
248+
249+
```python
250+
from typing import Annotated
251+
import marshmallow.fields as mf
252+
import marshmallow.validate as mv
253+
254+
IPv4 = Annotated[str, mf.String(validate=mv.Regexp(r"^([0-9]{1,3}\\.){3}[0-9]{1,3}$"))]
255+
```
256+
257+
You can also pass a marshmallow field class.
258+
259+
```python
260+
from typing import Annotated
261+
import marshmallow
262+
from marshmallow_dataclass import NewType
263+
264+
Email = Annotated[str, marshmallow.fields.Email]
265+
```
266+
267+
For convenience, some custom types are provided:
268+
269+
```python
270+
from marshmallow_dataclass.typing import Email, Url
271+
```
272+
273+
When using Python 3.8, you must import `Annotated` from the typing_extensions package
274+
275+
```python
276+
# Version agnostic import code:
277+
if sys.version_info >= (3, 9):
278+
from typing import Annotated
279+
else:
280+
from typing_extensions import Annotated
281+
```
282+
283+
### Custom NewType declarations [__deprecated__]
284+
285+
> NewType is deprecated in favor or type aliases using Annotated, as described above.
246286
247287
This library exports a `NewType` function to create types that generate [customized marshmallow fields](https://marshmallow.readthedocs.io/en/stable/custom_fields.html#creating-a-field-class).
248288

@@ -266,12 +306,6 @@ from marshmallow_dataclass import NewType
266306
Email = NewType("Email", str, field=marshmallow.fields.Email)
267307
```
268308

269-
For convenience, some custom types are provided:
270-
271-
```python
272-
from marshmallow_dataclass.typing import Email, Url
273-
```
274-
275309
Note: if you are using `mypy`, you will notice that `mypy` throws an error if a variable defined with
276310
`NewType` is used in a type annotation. To resolve this, add the `marshmallow_dataclass.mypy` plugin
277311
to your `mypy` configuration, e.g.:

marshmallow_dataclass/__init__.py

Lines changed: 86 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class User:
3434
})
3535
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636
"""
37+
3738
import collections.abc
3839
import dataclasses
3940
import inspect
@@ -47,11 +48,13 @@ class User:
4748
Any,
4849
Callable,
4950
Dict,
51+
FrozenSet,
5052
Generic,
5153
List,
5254
Mapping,
5355
NewType as typing_NewType,
5456
Optional,
57+
Sequence,
5558
Set,
5659
Tuple,
5760
Type,
@@ -60,24 +63,23 @@ class User:
6063
cast,
6164
get_type_hints,
6265
overload,
63-
Sequence,
64-
FrozenSet,
6566
)
6667

6768
import marshmallow
69+
import typing_extensions
6870
import typing_inspect
6971

7072
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
7173

74+
if sys.version_info >= (3, 9):
75+
from typing import Annotated
76+
else:
77+
from typing_extensions import Annotated
7278

7379
if sys.version_info >= (3, 11):
7480
from typing import dataclass_transform
75-
elif sys.version_info >= (3, 7):
76-
from typing_extensions import dataclass_transform
7781
else:
78-
# @dataclass_transform() only helps us with mypy>=1.1 which is only available for python>=3.7
79-
def dataclass_transform(**kwargs):
80-
return lambda cls: cls
82+
from typing_extensions import dataclass_transform
8183

8284

8385
__all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"]
@@ -511,7 +513,15 @@ def _internal_class_schema(
511513
base_schema: Optional[Type[marshmallow.Schema]] = None,
512514
) -> Type[marshmallow.Schema]:
513515
schema_ctx = _schema_ctx_stack.top
514-
schema_ctx.seen_classes[clazz] = clazz.__name__
516+
517+
if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10):
518+
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
519+
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
520+
else:
521+
class_name = clazz.__name__
522+
523+
schema_ctx.seen_classes[clazz] = class_name
524+
515525
try:
516526
# noinspection PyDataclass
517527
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
@@ -546,9 +556,18 @@ def _internal_class_schema(
546556
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)
547557

548558
# Update the schema members to contain marshmallow fields instead of dataclass fields
549-
type_hints = get_type_hints(
550-
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
551-
)
559+
560+
if sys.version_info >= (3, 9):
561+
type_hints = get_type_hints(
562+
clazz,
563+
globalns=schema_ctx.globalns,
564+
localns=schema_ctx.localns,
565+
include_extras=True,
566+
)
567+
else:
568+
type_hints = get_type_hints(
569+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
570+
)
552571
attributes.update(
553572
(
554573
field.name,
@@ -639,8 +658,8 @@ def _field_for_generic_type(
639658
"""
640659
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
641660
"""
642-
origin = typing_inspect.get_origin(typ)
643-
arguments = typing_inspect.get_args(typ, True)
661+
origin = typing_extensions.get_origin(typ)
662+
arguments = typing_extensions.get_args(typ)
644663
if origin:
645664
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
646665
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
@@ -694,6 +713,46 @@ def _field_for_generic_type(
694713
**metadata,
695714
)
696715

716+
return None
717+
718+
719+
def _field_for_annotated_type(
720+
typ: type,
721+
**metadata: Any,
722+
) -> Optional[marshmallow.fields.Field]:
723+
"""
724+
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
725+
"""
726+
origin = typing_extensions.get_origin(typ)
727+
arguments = typing_extensions.get_args(typ)
728+
if origin and origin is Annotated:
729+
marshmallow_annotations = [
730+
arg
731+
for arg in arguments[1:]
732+
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
733+
or isinstance(arg, marshmallow.fields.Field)
734+
]
735+
if marshmallow_annotations:
736+
if len(marshmallow_annotations) > 1:
737+
warnings.warn(
738+
"Multiple marshmallow Field annotations found. Using the last one."
739+
)
740+
741+
field = marshmallow_annotations[-1]
742+
# Got a field instance, return as is. User must know what they're doing
743+
if isinstance(field, marshmallow.fields.Field):
744+
return field
745+
746+
return field(**metadata)
747+
return None
748+
749+
750+
def _field_for_union_type(
751+
typ: type,
752+
base_schema: Optional[Type[marshmallow.Schema]],
753+
**metadata: Any,
754+
) -> Optional[marshmallow.fields.Field]:
755+
arguments = typing_extensions.get_args(typ)
697756
if typing_inspect.is_union_type(typ):
698757
if typing_inspect.is_optional_type(typ):
699758
metadata["allow_none"] = metadata.get("allow_none", True)
@@ -806,6 +865,7 @@ def _field_for_schema(
806865
metadata.setdefault("allow_none", True)
807866
return marshmallow.fields.Raw(**metadata)
808867

868+
# i.e.: Literal['abc']
809869
if typing_inspect.is_literal_type(typ):
810870
arguments = typing_inspect.get_args(typ)
811871
return marshmallow.fields.Raw(
@@ -817,6 +877,7 @@ def _field_for_schema(
817877
**metadata,
818878
)
819879

880+
# i.e.: Final[str] = 'abc'
820881
if typing_inspect.is_final_type(typ):
821882
arguments = typing_inspect.get_args(typ)
822883
if arguments:
@@ -851,6 +912,14 @@ def _field_for_schema(
851912
subtyp = Any
852913
return _field_for_schema(subtyp, default, metadata, base_schema)
853914

915+
annotated_field = _field_for_annotated_type(typ, **metadata)
916+
if annotated_field:
917+
return annotated_field
918+
919+
union_field = _field_for_union_type(typ, base_schema, **metadata)
920+
if union_field:
921+
return union_field
922+
854923
# Generic types
855924
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
856925
if generic_field:
@@ -869,14 +938,8 @@ def _field_for_schema(
869938
)
870939

871940
# enumerations
872-
if issubclass(typ, Enum):
873-
try:
874-
return marshmallow.fields.Enum(typ, **metadata)
875-
except AttributeError:
876-
# Remove this once support for python 3.6 is dropped.
877-
import marshmallow_enum
878-
879-
return marshmallow_enum.EnumField(typ, **metadata)
941+
if inspect.isclass(typ) and issubclass(typ, Enum):
942+
return marshmallow.fields.Enum(typ, **metadata)
880943

881944
# Nested marshmallow dataclass
882945
# it would be just a class name instead of actual schema util the schema is not ready yet
@@ -939,7 +1002,8 @@ def NewType(
9391002
field: Optional[Type[marshmallow.fields.Field]] = None,
9401003
**kwargs,
9411004
) -> Callable[[_U], _U]:
942-
"""NewType creates simple unique types
1005+
"""DEPRECATED: Use typing.Annotated instead.
1006+
NewType creates simple unique types
9431007
to which you can attach custom marshmallow attributes.
9441008
All the keyword arguments passed to this function will be transmitted
9451009
to the marshmallow field constructor.

marshmallow_dataclass/typing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import sys
2+
13
import marshmallow.fields
2-
from . import NewType
34

4-
Url = NewType("Url", str, field=marshmallow.fields.Url)
5-
Email = NewType("Email", str, field=marshmallow.fields.Email)
5+
if sys.version_info >= (3, 9):
6+
from typing import Annotated
7+
else:
8+
from typing_extensions import Annotated
9+
10+
Url = Annotated[str, marshmallow.fields.Url]
11+
Email = Annotated[str, marshmallow.fields.Email]
612

713
# Aliases
814
URL = Url

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@ target-version = ['py36', 'py37', 'py38', 'py39', 'py310', 'py310']
66
filterwarnings = [
77
"error:::marshmallow_dataclass|test",
88
]
9+
10+
[tool.coverage.report]
11+
exclude_also = [
12+
'^\s*\.\.\.\s*$',
13+
'^\s*pass\s*$',
14+
]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from setuptools import setup, find_packages
1+
from setuptools import find_packages, setup
22

33
VERSION = "9.0.0"
44

tests/test_annotated.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import sys
2+
import unittest
3+
from typing import Optional
4+
5+
import marshmallow
6+
import marshmallow.fields
7+
8+
from marshmallow_dataclass import dataclass
9+
10+
if sys.version_info >= (3, 9):
11+
from typing import Annotated
12+
else:
13+
from typing_extensions import Annotated
14+
15+
16+
class TestAnnotatedField(unittest.TestCase):
17+
def test_annotated_field(self):
18+
@dataclass
19+
class AnnotatedValue:
20+
value: Annotated[str, marshmallow.fields.Email]
21+
default_string: Annotated[
22+
Optional[str], marshmallow.fields.String(load_default="Default String")
23+
] = None
24+
25+
schema = AnnotatedValue.Schema()
26+
27+
self.assertEqual(
28+
schema.load({"value": "test@test.com"}),
29+
AnnotatedValue(value="test@test.com", default_string="Default String"),
30+
)
31+
self.assertEqual(
32+
schema.load({"value": "test@test.com", "default_string": "override"}),
33+
AnnotatedValue(value="test@test.com", default_string="override"),
34+
)
35+
36+
with self.assertRaises(marshmallow.exceptions.ValidationError):
37+
schema.load({"value": "notavalidemail"})

0 commit comments

Comments
 (0)