Skip to content

Commit 8797b2b

Browse files
committed
Rename our is_generic_type and only utilize where absolutely necessary
1 parent 78fcd4a commit 8797b2b

File tree

3 files changed

+43
-16
lines changed

3 files changed

+43
-16
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ class User:
7272
UnboundTypeVarError,
7373
get_generic_dataclass_fields,
7474
is_generic_alias,
75-
is_generic_type,
7675
)
7776
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
7877

@@ -586,7 +585,7 @@ def _internal_class_schema(
586585

587586
# Update the schema members to contain marshmallow fields instead of dataclass fields
588587
type_hints = {}
589-
if not is_generic_type(clazz):
588+
if not typing_inspect.is_generic_type(clazz):
590589
type_hints = _get_type_hints(clazz, schema_ctx)
591590

592591
attributes.update(
@@ -595,7 +594,7 @@ def _internal_class_schema(
595594
_field_for_schema(
596595
(
597596
type_hints[field.name]
598-
if not is_generic_type(clazz)
597+
if not typing_inspect.is_generic_type(clazz)
599598
else _resolve_forward_type_refs(field.type, schema_ctx)
600599
),
601600
_get_field_default(field),
@@ -758,7 +757,10 @@ def _field_for_annotated_type(
758757
for arg in arguments[1:]
759758
if _is_marshmallow_field(arg)
760759
# Support `CustomGenericField[mf.String]`
761-
or (is_generic_type(arg) and _is_marshmallow_field(get_origin(arg)))
760+
or (
761+
typing_inspect.is_generic_type(arg)
762+
and _is_marshmallow_field(get_origin(arg))
763+
)
762764
]
763765
if marshmallow_annotations:
764766
if len(marshmallow_annotations) > 1:
@@ -1069,7 +1071,7 @@ class X:
10691071

10701072

10711073
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
1072-
if not is_generic_type(clazz):
1074+
if not typing_inspect.is_generic_type(clazz):
10731075
return dataclasses.fields(clazz)
10741076

10751077
else:

marshmallow_dataclass/generic_resolver.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,18 @@ class A(Generic[T]):
7979
8080
``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*).
8181
"""
82-
is_generic = is_generic_type(clazz)
82+
is_generic = typing_inspect.is_generic_type(clazz)
8383
type_arguments = get_args(clazz)
8484
return is_generic and len(type_arguments) > 0
8585

8686

87-
def is_generic_type(clazz: type) -> bool:
87+
def may_contain_typevars(clazz: type) -> bool:
8888
"""
89-
typing_inspect.is_generic_type explicitly ignores Union and Tuple
89+
Check if the class can contain typevars. This includes Special Forms.
90+
91+
Different from typing_inspect.is_generic_type as that explicitly ignores Union and Tuple.
92+
93+
We still need to resolve typevars for Union and Tuple
9094
"""
9195
origin = get_origin(clazz)
9296
return origin is not Annotated and (
@@ -141,14 +145,18 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
141145
def _replace_typevars(
142146
clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None
143147
) -> type:
144-
if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz):
148+
if (
149+
not resolved_generics
150+
or inspect.isclass(clazz)
151+
or not may_contain_typevars(clazz)
152+
):
145153
return clazz
146154

147155
return clazz.copy_with( # type: ignore
148156
tuple(
149157
(
150158
_replace_typevars(arg, resolved_generics)
151-
if is_generic_type(arg)
159+
if may_contain_typevars(arg)
152160
else (
153161
resolved_generics[arg].result() if arg in resolved_generics else arg
154162
)
@@ -179,7 +187,7 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
179187
# Either the first time we see this field, or it got overridden
180188
# If it's a class we handle it later as a Nested. Nothing to resolve now.
181189
new_field = field
182-
if not inspect.isclass(field.type) and is_generic_type(field.type):
190+
if not inspect.isclass(field.type) and may_contain_typevars(field.type):
183191
new_field = copy.copy(field)
184192
new_field.type = _replace_typevars(
185193
field.type, resolved_typevars[subclass]

tests/test_generics.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import typing
55
import unittest
6+
from typing_inspect import is_generic_type
67

78
import marshmallow.fields
89
from marshmallow import ValidationError
@@ -14,7 +15,6 @@
1415
dataclass,
1516
is_generic_alias_of_dataclass,
1617
)
17-
from marshmallow_dataclass.generic_resolver import is_generic_type
1818

1919
if sys.version_info >= (3, 9):
2020
from typing import Annotated
@@ -319,12 +319,29 @@ def test_generic_dataclass_with_forwardref(self):
319319
T = typing.TypeVar("T")
320320

321321
@dataclasses.dataclass
322-
class SimpleGeneric(typing.Generic[T]):
322+
class ForwardGeneric(typing.Generic[T]):
323323
data: T
324324

325-
schema_s = class_schema(SimpleGeneric["str"])()
326-
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
327-
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
325+
schema_s = class_schema(ForwardGeneric["str"])()
326+
self.assertEqual(ForwardGeneric(data="a"), schema_s.load({"data": "a"}))
327+
self.assertEqual(schema_s.dump(ForwardGeneric(data="a")), {"data": "a"})
328+
with self.assertRaises(ValidationError):
329+
schema_s.load({"data": 2})
330+
331+
def test_generic_dataclass_with_optional(self):
332+
T = typing.TypeVar("T")
333+
334+
@dataclasses.dataclass
335+
class OptionalGeneric(typing.Generic[T]):
336+
data: typing.Optional[T]
337+
338+
schema_s = class_schema(OptionalGeneric["str"])()
339+
self.assertEqual(OptionalGeneric(data="a"), schema_s.load({"data": "a"}))
340+
self.assertEqual(schema_s.dump(OptionalGeneric(data="a")), {"data": "a"})
341+
342+
self.assertEqual(OptionalGeneric(data=None), schema_s.load({}))
343+
self.assertEqual(schema_s.dump(OptionalGeneric(data=None)), {"data": None})
344+
328345
with self.assertRaises(ValidationError):
329346
schema_s.load({"data": 2})
330347

0 commit comments

Comments
 (0)