Skip to content

Commit dede7fa

Browse files
committed
Add pre/post_load parameters to Field
1 parent ea26aeb commit dede7fa

File tree

3 files changed

+173
-11
lines changed

3 files changed

+173
-11
lines changed

src/marshmallow/fields.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@
8080
"Url",
8181
]
8282

83+
_ProcessorT = typing.TypeVar(
84+
"_ProcessorT",
85+
bound=typing.Union[types.PostLoadCallable, types.PreLoadCallable, types.Validator],
86+
)
87+
8388

8489
class Field(FieldABC):
8590
"""Base field from which other fields inherit.
@@ -153,6 +158,12 @@ def __init__(
153158
data_key: str | None = None,
154159
attribute: str | None = None,
155160
validate: types.Validator | typing.Iterable[types.Validator] | None = None,
161+
pre_load: types.PreLoadCallable
162+
| typing.Iterable[types.PreLoadCallable]
163+
| None = None,
164+
post_load: types.PostLoadCallable
165+
| typing.Iterable[types.PostLoadCallable]
166+
| None = None,
156167
required: bool = False,
157168
allow_none: bool | None = None,
158169
load_only: bool = False,
@@ -193,17 +204,9 @@ def __init__(
193204
self.attribute = attribute
194205
self.data_key = data_key
195206
self.validate = validate
196-
if validate is None:
197-
self.validators = []
198-
elif callable(validate):
199-
self.validators = [validate]
200-
elif utils.is_iterable_but_not_string(validate):
201-
self.validators = list(validate)
202-
else:
203-
raise ValueError(
204-
"The 'validate' parameter must be a callable "
205-
"or a collection of callables."
206-
)
207+
self.validators = self._normalize_processors(validate, param="validate")
208+
self.pre_load = self._normalize_processors(pre_load, param="pre_load")
209+
self.post_load = self._normalize_processors(post_load, param="post_load")
207210

208211
# If allow_none is None and load_default is None
209212
# None should be considered valid by default
@@ -369,10 +372,23 @@ def deserialize(
369372
if value is missing_:
370373
_miss = self.load_default
371374
return _miss() if callable(_miss) else _miss
375+
376+
# Apply pre_load functions
377+
for func in self.pre_load:
378+
if func is not None:
379+
value = func(value)
380+
372381
if self.allow_none and value is None:
373382
return None
383+
374384
output = self._deserialize(value, attr, data, **kwargs)
385+
# Apply validators
375386
self._validate(output)
387+
388+
# Apply post_load functions
389+
for func in self.post_load:
390+
if func is not None:
391+
output = func(output)
376392
return output
377393

378394
# Methods for concrete classes to override.
@@ -484,6 +500,24 @@ def missing(self, value):
484500
)
485501
self.load_default = value
486502

503+
@staticmethod
504+
def _normalize_processors(
505+
processors: _ProcessorT | typing.Iterable[_ProcessorT] | None,
506+
*,
507+
param: str,
508+
) -> list[_ProcessorT]:
509+
"""Convert processor(s) to a tuple of callables."""
510+
if processors is None:
511+
return []
512+
if callable(processors):
513+
return [typing.cast(_ProcessorT, processors)]
514+
if not utils.is_iterable_but_not_string(processors):
515+
raise ValueError(
516+
f"The '{param}' parameter must be a callable "
517+
"or an iterable of callables."
518+
)
519+
return list(processors)
520+
487521

488522
class Raw(Field):
489523
"""Field that applies no formatting."""

src/marshmallow/types.py

+6
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99

1010
import typing
1111

12+
T = typing.TypeVar("T")
13+
1214
#: A type that can be either a sequence of strings or a set of strings
1315
StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
1416

1517
#: Type for validator functions
1618
Validator = typing.Callable[[typing.Any], typing.Any]
19+
#: Type for field-level pre-load functions
20+
PreLoadCallable = typing.Callable[[typing.Any], typing.Any]
21+
#: Type for field-level post-load functions
22+
PostLoadCallable = typing.Callable[[T], T]
1723

1824

1925
class SchemaValidator(typing.Protocol):

tests/test_fields.py

+122
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,125 @@ class Family(Schema):
663663
"daughter": {"value": {"age": ["Missing data for required field."]}}
664664
}
665665
}
666+
667+
668+
class TestFieldPreAndPostLoad:
669+
def test_field_pre_load(self):
670+
class UserSchema(Schema):
671+
name = fields.Str(pre_load=str)
672+
673+
schema = UserSchema()
674+
result = schema.load({"name": 808})
675+
assert result["name"] == "808"
676+
677+
def test_field_pre_load_multiple(self):
678+
def decrement(value):
679+
return value - 1
680+
681+
def add_prefix(value):
682+
return "test_" + value
683+
684+
class UserSchema(Schema):
685+
name = fields.Str(pre_load=[decrement, str, add_prefix])
686+
687+
schema = UserSchema()
688+
result = schema.load({"name": 809})
689+
assert result["name"] == "test_808"
690+
691+
def test_field_post_load(self):
692+
class UserSchema(Schema):
693+
age = fields.Int(post_load=str)
694+
695+
schema = UserSchema()
696+
result = schema.load({"age": 42})
697+
assert result["age"] == "42"
698+
699+
def test_field_post_load_multiple(self):
700+
def to_string(value):
701+
return str(value)
702+
703+
def add_suffix(value):
704+
return value + " years"
705+
706+
class UserSchema(Schema):
707+
age = fields.Int(post_load=[to_string, add_suffix])
708+
709+
schema = UserSchema()
710+
result = schema.load({"age": 42})
711+
assert result["age"] == "42 years"
712+
713+
def test_field_pre_and_post_load(self):
714+
def multiply_by_2(value):
715+
return value * 2
716+
717+
class UserSchema(Schema):
718+
age = fields.Int(pre_load=[str.strip, int], post_load=[multiply_by_2])
719+
720+
schema = UserSchema()
721+
result = schema.load({"age": " 21 "})
722+
assert result["age"] == 42
723+
724+
def test_field_pre_load_validation_error(self):
725+
def always_fail(value):
726+
raise ValidationError("oops")
727+
728+
class UserSchema(Schema):
729+
age = fields.Int(pre_load=always_fail)
730+
731+
schema = UserSchema()
732+
with pytest.raises(ValidationError) as exc:
733+
schema.load({"age": 42})
734+
assert exc.value.messages == {"age": ["oops"]}
735+
736+
def test_field_post_load_validation_error(self):
737+
def always_fail(value):
738+
raise ValidationError("oops")
739+
740+
class UserSchema(Schema):
741+
age = fields.Int(post_load=always_fail)
742+
743+
schema = UserSchema()
744+
with pytest.raises(ValidationError) as exc:
745+
schema.load({"age": 42})
746+
assert exc.value.messages == {"age": ["oops"]}
747+
748+
def test_field_pre_load_none(self):
749+
def handle_none(value):
750+
if value is None:
751+
return 0
752+
return value
753+
754+
class UserSchema(Schema):
755+
age = fields.Int(pre_load=handle_none, allow_none=True)
756+
757+
schema = UserSchema()
758+
result = schema.load({"age": None})
759+
assert result["age"] == 0
760+
761+
def test_field_post_load_not_called_with_none_input_when_not_allowed(self):
762+
def handle_none(value):
763+
if value is None:
764+
return 0
765+
return value
766+
767+
class UserSchema(Schema):
768+
age = fields.Int(post_load=handle_none, allow_none=False)
769+
770+
schema = UserSchema()
771+
with pytest.raises(ValidationError) as exc:
772+
schema.load({"age": None})
773+
assert exc.value.messages == {"age": ["Field may not be null."]}
774+
775+
def test_invalid_type_passed_to_pre_load(self):
776+
with pytest.raises(
777+
ValueError,
778+
match="The 'pre_load' parameter must be a callable or an iterable of callables.",
779+
):
780+
fields.Int(pre_load="not_callable")
781+
782+
def test_invalid_type_passed_to_post_load(self):
783+
with pytest.raises(
784+
ValueError,
785+
match="The 'post_load' parameter must be a callable or an iterable of callables.",
786+
):
787+
fields.Int(post_load="not_callable")

0 commit comments

Comments
 (0)