Skip to content

Commit fae55c3

Browse files
committed
Move Annotated and Union handling to their own functions
1 parent 7eb9281 commit fae55c3

File tree

1 file changed

+43
-15
lines changed

1 file changed

+43
-15
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -664,21 +664,6 @@ def _field_for_generic_type(
664664
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
665665
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
666666

667-
if origin is Annotated:
668-
marshmallow_annotations = [
669-
arg
670-
for arg in arguments[1:]
671-
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
672-
or isinstance(arg, marshmallow.fields.Field)
673-
]
674-
if marshmallow_annotations:
675-
field = marshmallow_annotations[-1]
676-
# Got a field instance, return as is. User must know what they're doing
677-
if isinstance(field, marshmallow.fields.Field):
678-
return field
679-
680-
return field(**metadata)
681-
682667
if origin in (list, List):
683668
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
684669
list_type = cast(
@@ -728,6 +713,41 @@ def _field_for_generic_type(
728713
**metadata,
729714
)
730715

716+
return None
717+
718+
719+
def _field_for_annotated_type(
720+
typ: type,
721+
**metadata: Any,
722+
) -> Optional[marshmallow.fields.Field]:
723+
"""
724+
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
725+
"""
726+
origin = typing_extensions.get_origin(typ)
727+
arguments = typing_extensions.get_args(typ)
728+
if origin and origin is Annotated:
729+
marshmallow_annotations = [
730+
arg
731+
for arg in arguments[1:]
732+
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
733+
or isinstance(arg, marshmallow.fields.Field)
734+
]
735+
if marshmallow_annotations:
736+
field = marshmallow_annotations[-1]
737+
# Got a field instance, return as is. User must know what they're doing
738+
if isinstance(field, marshmallow.fields.Field):
739+
return field
740+
741+
return field(**metadata)
742+
return None
743+
744+
745+
def _field_for_union_type(
746+
typ: type,
747+
base_schema: Optional[Type[marshmallow.Schema]],
748+
**metadata: Any,
749+
) -> Optional[marshmallow.fields.Field]:
750+
arguments = typing_extensions.get_args(typ)
731751
if typing_inspect.is_union_type(typ):
732752
if typing_inspect.is_optional_type(typ):
733753
metadata["allow_none"] = metadata.get("allow_none", True)
@@ -887,6 +907,14 @@ def _field_for_schema(
887907
subtyp = Any
888908
return _field_for_schema(subtyp, default, metadata, base_schema)
889909

910+
annotated_field = _field_for_annotated_type(typ, **metadata)
911+
if annotated_field:
912+
return annotated_field
913+
914+
union_field = _field_for_union_type(typ, base_schema, **metadata)
915+
if union_field:
916+
return union_field
917+
890918
# Generic types
891919
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
892920
if generic_field:

0 commit comments

Comments
 (0)