Skip to content

Add pre/post_load parameters to Field #2799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: 3.x-line
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
"Url",
]

_ProcessorT = typing.TypeVar(
"_ProcessorT",
bound=typing.Union[types.PostLoadCallable, types.PreLoadCallable, types.Validator],
)


class Field(FieldABC):
"""Base field from which other fields inherit.
Expand Down Expand Up @@ -153,6 +158,12 @@ def __init__(
data_key: str | None = None,
attribute: str | None = None,
validate: types.Validator | typing.Iterable[types.Validator] | None = None,
pre_load: (
types.PreLoadCallable | typing.Iterable[types.PreLoadCallable] | None
) = None,
post_load: (
types.PostLoadCallable | typing.Iterable[types.PostLoadCallable] | None
) = None,
required: bool = False,
allow_none: bool | None = None,
load_only: bool = False,
Expand Down Expand Up @@ -193,17 +204,9 @@ def __init__(
self.attribute = attribute
self.data_key = data_key
self.validate = validate
if validate is None:
self.validators = []
elif callable(validate):
self.validators = [validate]
elif utils.is_iterable_but_not_string(validate):
self.validators = list(validate)
else:
raise ValueError(
"The 'validate' parameter must be a callable "
"or a collection of callables."
)
self.validators = self._normalize_processors(validate, param="validate")
self.pre_load = self._normalize_processors(pre_load, param="pre_load")
self.post_load = self._normalize_processors(post_load, param="post_load")

# If allow_none is None and load_default is None
# None should be considered valid by default
Expand Down Expand Up @@ -369,10 +372,23 @@ def deserialize(
if value is missing_:
_miss = self.load_default
return _miss() if callable(_miss) else _miss

# Apply pre_load functions
for func in self.pre_load:
if func is not None:
value = func(value)

if self.allow_none and value is None:
return None

output = self._deserialize(value, attr, data, **kwargs)
# Apply validators
self._validate(output)

# Apply post_load functions
for func in self.post_load:
if func is not None:
output = func(output)
return output

# Methods for concrete classes to override.
Expand Down Expand Up @@ -484,6 +500,23 @@ def missing(self, value):
)
self.load_default = value

@staticmethod
def _normalize_processors(
processors: _ProcessorT | typing.Iterable[_ProcessorT] | None,
*,
param: str,
) -> list[_ProcessorT]:
if processors is None:
return []
if callable(processors):
return [typing.cast(_ProcessorT, processors)]
if not utils.is_iterable_but_not_string(processors):
raise ValueError(
f"The '{param}' parameter must be a callable "
"or an iterable of callables."
)
return list(processors)


class Raw(Field):
"""Field that applies no formatting."""
Expand Down
6 changes: 6 additions & 0 deletions src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@

import typing

T = typing.TypeVar("T")

#: A type that can be either a sequence of strings or a set of strings
StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]

#: Type for validator functions
Validator = typing.Callable[[typing.Any], typing.Any]
#: Type for field-level pre-load functions
PreLoadCallable = typing.Callable[[typing.Any], typing.Any]
#: Type for field-level post-load functions
PostLoadCallable = typing.Callable[[T], T]


class SchemaValidator(typing.Protocol):
Expand Down
122 changes: 122 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,125 @@ class Family(Schema):
"daughter": {"value": {"age": ["Missing data for required field."]}}
}
}


class TestFieldPreAndPostLoad:
def test_field_pre_load(self):
class UserSchema(Schema):
name = fields.Str(pre_load=str)

schema = UserSchema()
result = schema.load({"name": 808})
assert result["name"] == "808"

def test_field_pre_load_multiple(self):
def decrement(value):
return value - 1

def add_prefix(value):
return "test_" + value

class UserSchema(Schema):
name = fields.Str(pre_load=[decrement, str, add_prefix])

schema = UserSchema()
result = schema.load({"name": 809})
assert result["name"] == "test_808"

def test_field_post_load(self):
class UserSchema(Schema):
age = fields.Int(post_load=str)

schema = UserSchema()
result = schema.load({"age": 42})
assert result["age"] == "42"

def test_field_post_load_multiple(self):
def multiply_by_2(value):
return value * 2

def decrement(value):
return value - 1

class UserSchema(Schema):
age = fields.Float(post_load=[multiply_by_2, decrement])

schema = UserSchema()
result = schema.load({"age": 21.5})
assert result["age"] == 42.0

def test_field_pre_and_post_load(self):
def multiply_by_2(value):
return value * 2

class UserSchema(Schema):
age = fields.Int(pre_load=[str.strip, int], post_load=[multiply_by_2])

schema = UserSchema()
result = schema.load({"age": " 21 "})
assert result["age"] == 42

def test_field_pre_load_validation_error(self):
def always_fail(value):
raise ValidationError("oops")

class UserSchema(Schema):
age = fields.Int(pre_load=always_fail)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": 42})
assert exc.value.messages == {"age": ["oops"]}

def test_field_post_load_validation_error(self):
def always_fail(value):
raise ValidationError("oops")

class UserSchema(Schema):
age = fields.Int(post_load=always_fail)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": 42})
assert exc.value.messages == {"age": ["oops"]}

def test_field_pre_load_none(self):
def handle_none(value):
if value is None:
return 0
return value

class UserSchema(Schema):
age = fields.Int(pre_load=handle_none, allow_none=True)

schema = UserSchema()
result = schema.load({"age": None})
assert result["age"] == 0

def test_field_post_load_not_called_with_none_input_when_not_allowed(self):
def handle_none(value):
if value is None:
return 0
return value

class UserSchema(Schema):
age = fields.Int(post_load=handle_none, allow_none=False)

schema = UserSchema()
with pytest.raises(ValidationError) as exc:
schema.load({"age": None})
assert exc.value.messages == {"age": ["Field may not be null."]}

def test_invalid_type_passed_to_pre_load(self):
with pytest.raises(
ValueError,
match="The 'pre_load' parameter must be a callable or an iterable of callables.",
):
fields.Int(pre_load="not_callable")

def test_invalid_type_passed_to_post_load(self):
with pytest.raises(
ValueError,
match="The 'post_load' parameter must be a callable or an iterable of callables.",
):
fields.Int(post_load="not_callable")
Loading