diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 0f0e735a8..0582a1b37 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -80,6 +80,11 @@ "Url", ] +_ProcessorT = typing.TypeVar( + "_ProcessorT", + bound=typing.Union[types.PostLoadCallable, types.PreLoadCallable, types.Validator], +) + class Field(FieldABC): """Base field from which other fields inherit. @@ -153,6 +158,12 @@ def __init__( data_key: str | None = None, attribute: str | None = None, validate: types.Validator | typing.Iterable[types.Validator] | None = None, + pre_load: ( + types.PreLoadCallable | typing.Iterable[types.PreLoadCallable] | None + ) = None, + post_load: ( + types.PostLoadCallable | typing.Iterable[types.PostLoadCallable] | None + ) = None, required: bool = False, allow_none: bool | None = None, load_only: bool = False, @@ -193,17 +204,9 @@ def __init__( self.attribute = attribute self.data_key = data_key self.validate = validate - if validate is None: - self.validators = [] - elif callable(validate): - self.validators = [validate] - elif utils.is_iterable_but_not_string(validate): - self.validators = list(validate) - else: - raise ValueError( - "The 'validate' parameter must be a callable " - "or a collection of callables." - ) + self.validators = self._normalize_processors(validate, param="validate") + self.pre_load = self._normalize_processors(pre_load, param="pre_load") + self.post_load = self._normalize_processors(post_load, param="post_load") # If allow_none is None and load_default is None # None should be considered valid by default @@ -369,10 +372,23 @@ def deserialize( if value is missing_: _miss = self.load_default return _miss() if callable(_miss) else _miss + + # Apply pre_load functions + for func in self.pre_load: + if func is not None: + value = func(value) + if self.allow_none and value is None: return None + output = self._deserialize(value, attr, data, **kwargs) + # Apply validators self._validate(output) + + # Apply post_load functions + for func in self.post_load: + if func is not None: + output = func(output) return output # Methods for concrete classes to override. @@ -484,6 +500,23 @@ def missing(self, value): ) self.load_default = value + @staticmethod + def _normalize_processors( + processors: _ProcessorT | typing.Iterable[_ProcessorT] | None, + *, + param: str, + ) -> list[_ProcessorT]: + if processors is None: + return [] + if callable(processors): + return [typing.cast(_ProcessorT, processors)] + if not utils.is_iterable_but_not_string(processors): + raise ValueError( + f"The '{param}' parameter must be a callable " + "or an iterable of callables." + ) + return list(processors) + class Raw(Field): """Field that applies no formatting.""" diff --git a/src/marshmallow/types.py b/src/marshmallow/types.py index 599f6b49e..054fa70cf 100644 --- a/src/marshmallow/types.py +++ b/src/marshmallow/types.py @@ -9,11 +9,17 @@ import typing +T = typing.TypeVar("T") + #: A type that can be either a sequence of strings or a set of strings StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]] #: Type for validator functions Validator = typing.Callable[[typing.Any], typing.Any] +#: Type for field-level pre-load functions +PreLoadCallable = typing.Callable[[typing.Any], typing.Any] +#: Type for field-level post-load functions +PostLoadCallable = typing.Callable[[T], T] class SchemaValidator(typing.Protocol): diff --git a/tests/test_fields.py b/tests/test_fields.py index a2763db35..749153a27 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -663,3 +663,125 @@ class Family(Schema): "daughter": {"value": {"age": ["Missing data for required field."]}} } } + + +class TestFieldPreAndPostLoad: + def test_field_pre_load(self): + class UserSchema(Schema): + name = fields.Str(pre_load=str) + + schema = UserSchema() + result = schema.load({"name": 808}) + assert result["name"] == "808" + + def test_field_pre_load_multiple(self): + def decrement(value): + return value - 1 + + def add_prefix(value): + return "test_" + value + + class UserSchema(Schema): + name = fields.Str(pre_load=[decrement, str, add_prefix]) + + schema = UserSchema() + result = schema.load({"name": 809}) + assert result["name"] == "test_808" + + def test_field_post_load(self): + class UserSchema(Schema): + age = fields.Int(post_load=str) + + schema = UserSchema() + result = schema.load({"age": 42}) + assert result["age"] == "42" + + def test_field_post_load_multiple(self): + def multiply_by_2(value): + return value * 2 + + def decrement(value): + return value - 1 + + class UserSchema(Schema): + age = fields.Float(post_load=[multiply_by_2, decrement]) + + schema = UserSchema() + result = schema.load({"age": 21.5}) + assert result["age"] == 42.0 + + def test_field_pre_and_post_load(self): + def multiply_by_2(value): + return value * 2 + + class UserSchema(Schema): + age = fields.Int(pre_load=[str.strip, int], post_load=[multiply_by_2]) + + schema = UserSchema() + result = schema.load({"age": " 21 "}) + assert result["age"] == 42 + + def test_field_pre_load_validation_error(self): + def always_fail(value): + raise ValidationError("oops") + + class UserSchema(Schema): + age = fields.Int(pre_load=always_fail) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": 42}) + assert exc.value.messages == {"age": ["oops"]} + + def test_field_post_load_validation_error(self): + def always_fail(value): + raise ValidationError("oops") + + class UserSchema(Schema): + age = fields.Int(post_load=always_fail) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": 42}) + assert exc.value.messages == {"age": ["oops"]} + + def test_field_pre_load_none(self): + def handle_none(value): + if value is None: + return 0 + return value + + class UserSchema(Schema): + age = fields.Int(pre_load=handle_none, allow_none=True) + + schema = UserSchema() + result = schema.load({"age": None}) + assert result["age"] == 0 + + def test_field_post_load_not_called_with_none_input_when_not_allowed(self): + def handle_none(value): + if value is None: + return 0 + return value + + class UserSchema(Schema): + age = fields.Int(post_load=handle_none, allow_none=False) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": None}) + assert exc.value.messages == {"age": ["Field may not be null."]} + + def test_invalid_type_passed_to_pre_load(self): + with pytest.raises( + ValueError, + match="The 'pre_load' parameter must be a callable or an iterable of callables.", + ): + fields.Int(pre_load="not_callable") + + def test_invalid_type_passed_to_post_load(self): + with pytest.raises( + ValueError, + match="The 'post_load' parameter must be a callable or an iterable of callables.", + ): + fields.Int(post_load="not_callable")