From e14c94d9ee3f8a9d6e970d27f3b021c7456a8bd1 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 6 Jul 2024 16:14:39 +0200 Subject: [PATCH 1/4] Add support for deeply nested models This goes beyond the first level of nesting like List[Model], Tuple[Model]. It makes stuff like List[Tuple[str, Tuple[int, Model, bool]]] possible --- README.md | 32 ++ pydantic_redis/_shared/lua_scripts.py | 101 +++--- pydantic_redis/_shared/model/base.py | 316 ++++++++++++------- pydantic_redis/_shared/model/insert_utils.py | 201 +++++------- pydantic_redis/_shared/model/prop_utils.py | 8 - pydantic_redis/_shared/model/select_utils.py | 23 +- test/test_async_pydantic_redis.py | 2 +- test/test_pydantic_redis.py | 2 +- 8 files changed, 388 insertions(+), 297 deletions(-) diff --git a/README.md b/README.md index f64ee66e..15094d58 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,38 @@ benchmark_bulk_insert[redis_store] 721.2247 (6.19) 6 --------------------------------------------------------------------------------------------------------------------- ``` +# >=v0.7 (with fully-fledged nested models) + +``` +--------------------------------------------------- benchmark: 22 tests ---------------------------------------------------- +Name (time in us) Mean Min Max +---------------------------------------------------------------------------------------------------------------------------- +test_benchmark_delete[redis_store-Wuthering Heights] 124.5440 (1.01) 109.3710 (1.0) 579.7810 (1.39) +test_benchmark_bulk_delete[redis_store] 122.9285 (1.0) 113.7120 (1.04) 492.2730 (1.18) +test_benchmark_select_columns_for_one_id[redis_store-book1] 182.3891 (1.48) 154.9150 (1.42) 441.2820 (1.06) +test_benchmark_select_columns_for_one_id[redis_store-book2] 183.2679 (1.49) 156.6830 (1.43) 462.6000 (1.11) +test_benchmark_select_columns_for_one_id[redis_store-book0] 181.6972 (1.48) 157.2330 (1.44) 459.2930 (1.10) +test_benchmark_select_columns_for_one_id[redis_store-book3] 183.0834 (1.49) 160.1250 (1.46) 416.8570 (1.0) +test_benchmark_select_all_for_one_id[redis_store-book1] 203.9491 (1.66) 183.3080 (1.68) 469.4700 (1.13) +test_benchmark_select_all_for_one_id[redis_store-book2] 206.7124 (1.68) 184.1920 (1.68) 490.6700 (1.18) +test_benchmark_select_all_for_one_id[redis_store-book0] 207.3341 (1.69) 184.2210 (1.68) 443.9260 (1.06) +test_benchmark_select_all_for_one_id[redis_store-book3] 210.6874 (1.71) 185.0600 (1.69) 696.9330 (1.67) +test_benchmark_select_columns_for_some_items[redis_store] 236.5783 (1.92) 215.7490 (1.97) 496.0540 (1.19) +test_benchmark_select_columns_paginated[redis_store] 248.5335 (2.02) 218.3450 (2.00) 522.1270 (1.25) +test_benchmark_update[redis_store-Wuthering Heights-data0] 282.1803 (2.30) 239.5410 (2.19) 541.5220 (1.30) +test_benchmark_select_some_items[redis_store] 298.2036 (2.43) 264.0860 (2.41) 599.3010 (1.44) +test_benchmark_single_insert[redis_store-book0] 316.0245 (2.57) 269.8110 (2.47) 596.0940 (1.43) +test_benchmark_single_insert[redis_store-book2] 314.1899 (2.56) 270.9780 (2.48) 560.5280 (1.34) +test_benchmark_select_default_paginated[redis_store] 305.2798 (2.48) 277.8170 (2.54) 550.5110 (1.32) +test_benchmark_single_insert[redis_store-book1] 312.5839 (2.54) 279.5660 (2.56) 578.7070 (1.39) +test_benchmark_single_insert[redis_store-book3] 316.9207 (2.58) 284.8630 (2.60) 567.0120 (1.36) +test_benchmark_select_columns[redis_store] 369.1538 (3.00) 331.5770 (3.03) 666.0470 (1.60) +test_benchmark_select_default[redis_store] 553.9420 (4.51) 485.3700 (4.44) 1,235.8540 (2.96) 
+test_benchmark_bulk_insert[redis_store] 777.4058 (6.32) 730.4280 (6.68) 1,012.7780 (2.43) +---------------------------------------------------------------------------------------------------------------------------- + +``` + ## Contributions Contributions are welcome. The docs have to maintained, the code has to be made cleaner, more idiomatic and faster, diff --git a/pydantic_redis/_shared/lua_scripts.py b/pydantic_redis/_shared/lua_scripts.py index 433c8646..5b052fee 100644 --- a/pydantic_redis/_shared/lua_scripts.py +++ b/pydantic_redis/_shared/lua_scripts.py @@ -39,13 +39,18 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested + value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) @@ -95,13 +100,18 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested + value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) @@ -158,13 +168,18 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested + value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) @@ -217,13 +232,17 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) @@ -288,13 +307,18 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested + value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) @@ -367,13 +391,18 @@ for i, k in ipairs(value) do if not (i % 2 == 0) then if 
startswith(k, '___') or startswith(k, '____') then - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) + if value[i + 1] == 'null' then + value[i + 1] = 'null' + else + local nested = {} + + for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do + table_insert(nested, get_obj(v)) + end + + value[i + 1] = nested end - - value[i + 1] = nested + value[i] = trim_dunder(k) elseif startswith(k, '__') then value[i + 1] = get_obj(value[i + 1]) diff --git a/pydantic_redis/_shared/model/base.py b/pydantic_redis/_shared/model/base.py index 2d52d501..b7b7c580 100644 --- a/pydantic_redis/_shared/model/base.py +++ b/pydantic_redis/_shared/model/base.py @@ -1,12 +1,17 @@ """Exposes the Base `Model` common to both async and sync APIs +Attributes: + NESTED_MODEL_PREFIX (str): the prefix for fields with single nested models + NESTED_MODEL_LIST_FIELD_PREFIX (str): the prefix for fields with lists of nested models + NESTED_MODEL_TUPLE_FIELD_PREFIX (str): the prefix for fields with tuples of nested models + NESTED_MODEL_DICT_FIELD_PREFIX (str): the prefix for fields with dicts of nested models """ +import enum import typing from typing import Dict, Tuple, Any, Type, Union, List, Optional from pydantic import ConfigDict, BaseModel -from pydantic.fields import ModelPrivateAttr from pydantic_redis._shared.utils import ( typing_get_origin, @@ -20,6 +25,34 @@ from ..store import AbstractStore +NESTED_MODEL_PREFIX = "__" +NESTED_MODEL_LIST_FIELD_PREFIX = "___" +NESTED_MODEL_TUPLE_FIELD_PREFIX = "____" +NESTED_MODEL_DICT_FIELD_PREFIX = "_____" + + +class NestingType(int, enum.Enum): + """The type of nesting that can happen especially for nested models""" + + ON_ROOT = 0 + IN_LIST = 1 + IN_TUPLE = 2 + IN_DICT = 3 + IN_UNION = 4 + + +# the type describing a tree for traversing types that form an aggregate type with a possibility +# of nested types and models +# Note: AbstractModel types are treated special. The first item in the tuple declares if +# the type on that tree node has a nested model, and of which type of nesting +# +# Note: (None, (Any)) corresponds to a type that is not a nested model +# (IN_LIST, (AbstractModel)) corresponds to List[AbstractModel] +# (None, (str)) corresponds to str +# (IN_TUPLE, (str, AbstractModel)) corresponds to Tuple[str, AbstractModel] +# (IN_LIST, (IN_TUPLE, (str, AbstractModel))) corresponds to List[Tuple[str, AbstractModel]] +AggTypeTree = Tuple[Optional[NestingType], Tuple[Union[Type["AbstractModel"], Any], ...]] # type: ignore + class AbstractModel(BaseModel): """A base class for all Models, sync and async alike. @@ -32,21 +65,18 @@ class AbstractModel(BaseModel): _field_types (Dict[str, Any]): a mapping of the fields and their types for the current model _store (AbstractStore): the Store in which the current model is registered. - _nested_model_tuple_fields (Dict[str, Tuple[Any, ...]]): a mapping of - fields and their types for fields that have tuples of nested models - _nested_model_list_fields (Dict[str, Type["AbstractModel"]]): a mapping of - fields and their associated nested models for fields that have - lists of nested models - _nested_model_fields (Dict[str, Type["AbstractModel"]]): a mapping of - fields and their associated nested models for fields that have nested models + _field_type_trees (Dict[str, Optional[AggTypeTree]]): a mapping of + fields and their associated trees of types forming their aggregate types + _strict (bool): Whether the model should be very strict on its types. 
By default, a + moderate level of strictness is imposed """ _primary_key_field: str _field_types: Dict[str, Any] = {} _store: AbstractStore - _nested_model_tuple_fields: Dict[str, Tuple[Any, ...]] = {} - _nested_model_list_fields: Dict[str, Type["AbstractModel"]] = {} - _nested_model_fields: Dict[str, Type["AbstractModel"]] = {} + _field_type_trees: Dict[str, Optional[AggTypeTree]] = {} + _field_typed_keys: Dict[str, str] = {} + _strict: bool = False model_config = ConfigDict(arbitrary_types_allowed=True) @classmethod @@ -59,32 +89,13 @@ def get_store(cls) -> AbstractStore: return cls._store @classmethod - def get_nested_model_tuple_fields(cls): - """Gets the mapping for fields that have tuples of nested models. + def get_field_type_trees(cls): + """Gets the mapping for fields and the trees of the types that form their aggregate types. Returns: - The mapping of field name and field type of a form similar to - `Tuple[str, Book, date]` + The mapping of field name and type trees of their aggregate types """ - return cls._nested_model_tuple_fields - - @classmethod - def get_nested_model_list_fields(cls): - """Gets the mapping for fields that have lists of nested models. - - Returns: - The mapping of field name and model class nested in that field. - """ - return cls._nested_model_list_fields - - @classmethod - def get_nested_model_fields(cls): - """Gets the mapping for fields that have nested models. - - Returns: - The mapping of field name and model class nested in that field. - """ - return cls._nested_model_fields + return cls._field_type_trees @classmethod def get_primary_key_field(cls): @@ -107,6 +118,15 @@ def get_field_types(cls) -> Dict[str, Any]: """ return cls._field_types + @classmethod + def get_field_typed_keys(cls) -> Dict[str, Any]: + """Gets the mapping of field and their type-aware key names for current Model. + + Returns: + the mapping of field and the type-aware key names for current Model + """ + return cls._field_typed_keys + @classmethod def initialize(cls): """Initializes class-wide variables for performance's reasons. 
@@ -116,48 +136,16 @@ def initialize(cls): """ cls._field_types = typing.get_type_hints(cls) - cls._nested_model_list_fields = {} - cls._nested_model_tuple_fields = {} - cls._nested_model_fields = {} + cls._field_type_trees = { + field: _generate_field_type_tree(field_type, strict=cls._strict) + for field, field_type in cls._field_types.items() + if not field.startswith("_") + } - for field, field_type in cls._field_types.items(): - try: - # In case the annotation is Optional, an alias of Union[X, None], extract the X - is_generic = hasattr(field_type, "__origin__") - if ( - is_generic - and typing_get_origin(field_type) == Union - and typing_get_args(field_type)[-1] == None.__class__ - ): - field_type = typing_get_args(field_type)[0] - is_generic = hasattr(field_type, "__origin__") - - if ( - is_generic - and typing_get_origin(field_type) in (List, list) - and issubclass(typing_get_args(field_type)[0], AbstractModel) - ): - cls._nested_model_list_fields[field] = typing_get_args(field_type)[ - 0 - ] - - elif ( - is_generic - and typing_get_origin(field_type) in (Tuple, tuple) - and any( - [ - issubclass(v, AbstractModel) - for v in typing_get_args(field_type) - ] - ) - ): - cls._nested_model_tuple_fields[field] = typing_get_args(field_type) - - elif issubclass(field_type, AbstractModel): - cls._nested_model_fields[field] = field_type - - except (TypeError, AttributeError): - pass + cls._field_typed_keys = { + field: _get_typed_field_key(field, type_tree=type_tree) + for field, type_tree in cls._field_type_trees.items() + } @classmethod def serialize_partially(cls, data: Optional[Dict[str, Any]]) -> Dict[str, Any]: @@ -192,70 +180,174 @@ def deserialize_partially( parsed_dict = {} - nested_model_list_fields = cls.get_nested_model_list_fields() - nested_model_tuple_fields = cls.get_nested_model_tuple_fields() - nested_model_fields = cls.get_nested_model_fields() + field_type_trees = cls.get_field_type_trees() for i in range(0, len(data), 2): key = from_bytes_to_str(data[i]) field_type = cls._field_types.get(key) value = from_str_or_bytes_to_any(value=data[i + 1], field_type=field_type) + type_tree = field_type_trees.get(key) - if key in nested_model_list_fields and value is not None: - value = _cast_lists(value, nested_model_list_fields[key]) - - elif key in nested_model_tuple_fields and value is not None: - value = _cast_tuples(value, nested_model_tuple_fields[key]) - - elif key in nested_model_fields and value is not None: - value = _cast_to_model(value=value, model=nested_model_fields[key]) - - parsed_dict[key] = value + parsed_dict[key] = _cast_by_type_tree(value=value, type_tree=type_tree) return parsed_dict -def _cast_lists(value: List[Any], _type: Type[AbstractModel]) -> List[AbstractModel]: - """Casts a list of flattened key-value lists into a list of _type. +def _generate_field_type_tree(field_type: Any, strict: bool = False) -> AggTypeTree: + """Gets the tree of types for the given aggregate type of the field Args: - _type: the type to cast the records to. 
-        value: the value to convert
+        field_type: the type of the field
+        strict: whether to raise an error if a given generic type is not supported; default=False

     Returns:
-        a list of records of the given _type
+        the type of nested model or None if not a nested model instance,
+        and a tuple of the types of its constituent types
     """
-    return [_type(**_type.deserialize_partially(item)) for item in value]
+    try:
+        nesting_type = None
+        generic_cls = typing_get_origin(field_type)
+
+        if generic_cls is None:
+            if issubclass(field_type, AbstractModel):
+                return NestingType.ON_ROOT, (field_type,)
+            return None, (field_type,)
+
+        type_args = typing_get_args(field_type)
+
+        if generic_cls is Union:
+            nesting_type = NestingType.IN_UNION
+
+        elif generic_cls in (List, list):
+            nesting_type = NestingType.IN_LIST
+
+        elif generic_cls in (Tuple, tuple):
+            nesting_type = NestingType.IN_TUPLE
+
+        elif generic_cls in (Dict, dict):
+            nesting_type = NestingType.IN_DICT
+
+        elif strict:
+            raise NotImplementedError(
+                f"Generic class type: {generic_cls} not supported for nested models"
+            )
+
+        return nesting_type, tuple(
+            [_generate_field_type_tree(v, strict) for v in type_args]
+        )
+
+    except AttributeError:
+        return None, (field_type,)

-def _cast_tuples(value: List[Any], _type: Tuple[Any, ...]) -> Tuple[Any, ...]:
-    """Casts a list of flattened key-value lists into a list of tuple of _type,.
+def _cast_by_type_tree(value: Any, type_tree: Optional[AggTypeTree]) -> Any:
+    """Casts a given value into a value based on the tree of its aggregate type

     Args:
-        _type: the tuple signature type to cast the records to
-            e.g. Tuple[str, Book, int]
-        value: the value to convert
+        value: the value to be cast based on the type tree
+        type_tree: the tree representing the nested hierarchy of types for the aggregate
+            type that the value is to be cast into

     Returns:
-        a list of records of tuple signature specified by `_type`
+        the parsed value
     """
-    items = []
-    for field_type, value in zip(_type, value):
-        if issubclass(field_type, AbstractModel) and value is not None:
-            value = field_type(**field_type.deserialize_partially(value))
-        items.append(value)
+    if value is None or type_tree is None:
+        # return the value as is because it cannot be cast
+        return value

-    return tuple(items)
+    nesting_type, type_args = type_tree

+    if nesting_type is NestingType.ON_ROOT:
+        _type = type_args[0]
+        return _type(**_type.deserialize_partially(value))

-def _cast_to_model(value: List[Any], model: Type[AbstractModel]) -> AbstractModel:
-    """Converts a list of flattened key-value lists into a list of models,.
+ if nesting_type is NestingType.IN_LIST: + _type = type_args[0] + return [_cast_by_type_tree(item, _type) for item in value] + + if nesting_type is NestingType.IN_TUPLE: + return tuple( + [_cast_by_type_tree(item, _type) for _type, item in zip(type_args, value)] + ) + + if nesting_type is NestingType.IN_DICT: + _, value_type = type_args + return {k: _cast_by_type_tree(v, value_type) for k, v in value.items()} + + if nesting_type is NestingType.IN_UNION: + # the value can be any of the types in type_args + for _type in type_args: + try: + parsed_value = _cast_by_type_tree(value, _type) + # return the first successfully parsed value + # that is not equal to the original value + if parsed_value != value: + return parsed_value + except Exception: + pass + + # return the value without any parsing + return value + + +def _get_typed_field_key( + field: str, type_tree: AggTypeTree, initial_prefix: str = "" +) -> str: + """Returns the key for the given field with extra information of its type Args: - model: the model class to cast to - value: the value to cast + field: the original key + type_tree: the tree of types that form the aggregate tree for the given key + initial_prefix: the initial_prefix to add to the field Returns: - a list of model instances of type `model` + the key with extra type information in its string """ - return model(**model.deserialize_partially(value)) + if type_tree is None: + return field + + nesting_type, type_args = type_tree + + if nesting_type is NestingType.ON_ROOT: + # we have found a NestedModel, so we stop recursion + if initial_prefix: + return f"{initial_prefix}{field}" + return f"{NESTED_MODEL_PREFIX}{field}" + + if nesting_type is NestingType.IN_LIST: + _type = type_args[0] + return _get_typed_field_key( + field, _type, initial_prefix=NESTED_MODEL_LIST_FIELD_PREFIX + ) + + if nesting_type is NestingType.IN_TUPLE: + for _type in type_args: + key = _get_typed_field_key( + field, _type, initial_prefix=NESTED_MODEL_TUPLE_FIELD_PREFIX + ) + if key != field: + return key + + if nesting_type is NestingType.IN_DICT: + _, value_type = type_args + return _get_typed_field_key( + field, value_type, initial_prefix=NESTED_MODEL_DICT_FIELD_PREFIX + ) + + if nesting_type is NestingType.IN_UNION: + # the value can be any of the types in type_args + for _type in type_args: + try: + key = _get_typed_field_key(field, type_tree=_type) + # return the first successful value + # that is not equal to the original value + # FIXME: Note that this is not comprehensive enough because + # it is possible to have Union[AbstractModel, List[AbstractModel]] + # but that would be a complicated type to parse/serialize + # Moral of story: Don't use it :-) + if key != field: + return key + except Exception: + pass + + return field diff --git a/pydantic_redis/_shared/model/insert_utils.py b/pydantic_redis/_shared/model/insert_utils.py index aae524bc..6d87122e 100644 --- a/pydantic_redis/_shared/model/insert_utils.py +++ b/pydantic_redis/_shared/model/insert_utils.py @@ -3,7 +3,7 @@ """ from datetime import datetime -from typing import Union, Optional, Any, Dict, Tuple, List, Type +from typing import Union, Optional, Any, Dict, Type from redis.asyncio.client import Pipeline as AioPipeline from redis.client import Pipeline @@ -11,12 +11,14 @@ from .prop_utils import ( get_redis_key, get_model_index_key, - NESTED_MODEL_PREFIX, - NESTED_MODEL_LIST_FIELD_PREFIX, - NESTED_MODEL_TUPLE_FIELD_PREFIX, ) -from .base import AbstractModel +from .base import ( + AbstractModel, + NestingType, + AggTypeTree, + 
NESTED_MODEL_LIST_FIELD_PREFIX, +) def insert_on_pipeline( @@ -75,16 +77,16 @@ def _serialize_nested_models( In order to make the record serializable, all nested models including those in lists and tuples of nested models are converted to their primary keys, - after being their insert operations have been added to the pipeline. + after their insert operations have been added to the pipeline. A few cleanups it does include: - Upserting any nested records in `record` - Replacing the keys of nested records with their NESTED_MODEL_PREFIX prefixed versions e.g. `__author` instead of author - Replacing the keys of lists of nested records with their NESTED_MODEL_LIST_FIELD_PREFIX prefixed versions - e.g. `__%&l_author` instead of author + e.g. `___author` instead of author - Replacing the keys of tuples of nested records with their NESTED_MODEL_TUPLE_FIELD_PREFIX prefixed versions - e.g. `__%&l_author` instead of author + e.g. `____author` instead of author - Replacing the values of nested records with their foreign keys Args: @@ -97,136 +99,97 @@ def _serialize_nested_models( the partially serialized dict that has no nested models """ data = record.items() if isinstance(record, dict) else record - new_data = {} - - nested_model_list_fields = model.get_nested_model_list_fields() - nested_model_tuple_fields = model.get_nested_model_tuple_fields() - nested_model_fields = model.get_nested_model_fields() + field_type_trees = model.get_field_type_trees() + field_typed_keys = model.get_field_typed_keys() + new_data = {} for k, v in data: - key, value = k, v - - if key in nested_model_list_fields: - key, value = _serialize_list( - key=key, value=value, pipeline=pipeline, life_span=life_span - ) - elif key in nested_model_tuple_fields: - key, value = _serialize_tuple( - key=key, - value=value, - pipeline=pipeline, - life_span=life_span, - tuple_fields=nested_model_tuple_fields, - ) - elif key in nested_model_fields: - key, value = _serialize_model( - key=key, value=value, pipeline=pipeline, life_span=life_span - ) + type_tree = field_type_trees.get(k) + key = field_typed_keys.get(k, k) + new_data[key] = _serialize_by_type_tree( + value=v, type_tree=type_tree, pipeline=pipeline, life_span=life_span + ) - new_data[key] = value return new_data -def _serialize_tuple( - key: str, - value: Tuple[AbstractModel], +def _serialize_by_type_tree( + value: Any, + type_tree: Optional[AggTypeTree], pipeline: Union[Pipeline, AioPipeline], life_span: Optional[Union[float, int]], - tuple_fields: Dict[str, Tuple[Any, ...]], -) -> Tuple[str, List[Any]]: - """Replaces models in a tuple with strings. - - It adds insert operations for the records in the tuple onto the pipeline - and returns the tuple with the models replaced by their primary keys as value. 
- - Returns: - key: the original `key` prefixed with NESTED_MODEL_TUPLE_FIELD_PREFIX - value: tthe tuple with the models replaced by their primary keys - """ - try: - field_types = tuple_fields.get(key, ()) - value = [ - ( - insert_on_pipeline( - model=field_type, - _id=None, - pipeline=pipeline, - record=item, - life_span=life_span, - ) - if issubclass(field_type, AbstractModel) - else item - ) - for field_type, item in zip(field_types, value) - ] - key = f"{NESTED_MODEL_TUPLE_FIELD_PREFIX}{key}" - except TypeError: - # In case the value is None, just ignore - pass - - return key, value - +) -> Any: + """Transforms a given value into a value basing on the tree of its aggregate type -def _serialize_list( - key: str, - value: List[AbstractModel], - pipeline: Union[Pipeline, AioPipeline], - life_span: Optional[Union[float, int]], -) -> Tuple[str, List[Any]]: - """Casts a list of models into a list of strings + Nested models are inserted into the redis database and their positions in the data + replaced by their primary keys - It adds insert operations for the records in the list onto the pipeline - and returns a list of their primary keys as value. + Args: + value: the value to be serialized basing on the type tree + type_tree: the tree representing the nested hierarchy of types for the aggregate + type that the value is to be cast into + pipeline: the redis pipeline on which the redis operations are to be done. + life_span: the time-to-live in seconds for the given record. Returns: - key: the original `key` prefixed with NESTED_MODEL_LIST_FIELD_PREFIX - value: the list of primary keys of the records to be inserted + the serialized value """ - try: - value = [ - insert_on_pipeline( - model=item.__class__, - _id=None, - pipeline=pipeline, - record=item, - life_span=life_span, - ) - for item in value - ] - key = f"{NESTED_MODEL_LIST_FIELD_PREFIX}{key}" - except TypeError: - # In case the value is None, just ignore - pass - - return key, value - - -def _serialize_model( - key: str, - value: AbstractModel, - pipeline: Union[Pipeline, AioPipeline], - life_span: Optional[Union[float, int]], -) -> Tuple[str, str]: - """Casts a model into a string + if type_tree is None: + # return the value as is because it cannot be serialized + return value - It adds an insert operation for the given model onto the pipeline - and returns its primary key as value. 
+ nesting_type, type_args = type_tree - Returns: - key: the original `key` prefixed with NESTED_MODEL_PREFIX - value: the primary key of the model - """ - try: - value = insert_on_pipeline( + if nesting_type is NestingType.ON_ROOT: + return insert_on_pipeline( model=value.__class__, _id=None, pipeline=pipeline, record=value, life_span=life_span, ) - key = f"{NESTED_MODEL_PREFIX}{key}" - except TypeError: - # In case the value is None, just ignore - pass - return key, value + if nesting_type is NestingType.IN_LIST: + _type = type_args[0] + return [ + _serialize_by_type_tree( + value=item, type_tree=_type, pipeline=pipeline, life_span=life_span + ) + for item in value + ] + + if nesting_type is NestingType.IN_TUPLE: + return tuple( + [ + _serialize_by_type_tree( + value=item, type_tree=_type, pipeline=pipeline, life_span=life_span + ) + for _type, item in zip(type_args, value) + ] + ) + + if nesting_type is NestingType.IN_DICT: + _, value_type = type_args + return { + k: _serialize_by_type_tree( + value=v, type_tree=value_type, pipeline=pipeline, life_span=life_span + ) + for k, v in value.items() + } + + if nesting_type is NestingType.IN_UNION: + # the value can be any of the types in type_args + for _type in type_args: + try: + serialized_value = _serialize_by_type_tree( + value=value, type_tree=_type, pipeline=pipeline, life_span=life_span + ) + # return the first successfully serialized value + # that is not equal to the original value + if serialized_value != value: + return serialized_value + except Exception: + pass + + # return the value without any serializing + return value diff --git a/pydantic_redis/_shared/model/prop_utils.py b/pydantic_redis/_shared/model/prop_utils.py index 9256a247..b1adcf00 100644 --- a/pydantic_redis/_shared/model/prop_utils.py +++ b/pydantic_redis/_shared/model/prop_utils.py @@ -1,19 +1,11 @@ """Exposes utils for getting properties of the Model -Attributes: - NESTED_MODEL_PREFIX (str): the prefix for fields with single nested models - NESTED_MODEL_LIST_FIELD_PREFIX (str): the prefix for fields with lists of nested models - NESTED_MODEL_TUPLE_FIELD_PREFIX (str): the prefix for fields with tuples of nested models """ from typing import Type, Any from .base import AbstractModel -NESTED_MODEL_PREFIX = "__" -NESTED_MODEL_LIST_FIELD_PREFIX = "___" -NESTED_MODEL_TUPLE_FIELD_PREFIX = "____" - def get_redis_key(model: Type[AbstractModel], primary_key_value: Any): """Gets the key used internally in redis for the `primary_key_value` of `model`. diff --git a/pydantic_redis/_shared/model/select_utils.py b/pydantic_redis/_shared/model/select_utils.py index 33e0f842..f509e74a 100644 --- a/pydantic_redis/_shared/model/select_utils.py +++ b/pydantic_redis/_shared/model/select_utils.py @@ -5,9 +5,6 @@ from typing import List, Any, Type, Union, Awaitable, Optional from pydantic_redis._shared.model.prop_utils import ( - NESTED_MODEL_PREFIX, - NESTED_MODEL_LIST_FIELD_PREFIX, - NESTED_MODEL_TUPLE_FIELD_PREFIX, get_redis_keys_regex, get_redis_key_prefix, get_model_index_key, @@ -17,7 +14,7 @@ from .base import AbstractModel -def get_select_fields(model: Type[AbstractModel], columns: List[str]) -> List[str]: +def get_select_fields(model: Type[AbstractModel], columns: List[str] = []) -> List[str]: """Gets the fields to be used for selecting HMAP fields in Redis. 
It replaces any fields in `columns` that correspond to nested records with their @@ -30,22 +27,8 @@ def get_select_fields(model: Type[AbstractModel], columns: List[str]) -> List[st Returns: the fields for selecting, with nested fields being given appropriate prefixes. """ - fields = [] - nested_model_list_fields = model.get_nested_model_list_fields() - nested_model_tuple_fields = model.get_nested_model_tuple_fields() - nested_model_fields = model.get_nested_model_fields() - - for col in columns: - - if col in nested_model_fields: - col = f"{NESTED_MODEL_PREFIX}{col}" - elif col in nested_model_list_fields: - col = f"{NESTED_MODEL_LIST_FIELD_PREFIX}{col}" - elif col in nested_model_tuple_fields: - col = f"{NESTED_MODEL_TUPLE_FIELD_PREFIX}{col}" - - fields.append(col) - return fields + typed_keys = model.get_field_typed_keys() + return [typed_keys.get(col, col) for col in columns] def select_all_fields_all_ids( diff --git a/test/test_async_pydantic_redis.py b/test/test_async_pydantic_redis.py index a90e19fa..a7cee471 100644 --- a/test/test_async_pydantic_redis.py +++ b/test/test_async_pydantic_redis.py @@ -5,7 +5,7 @@ import pytest -from pydantic_redis._shared.model.prop_utils import NESTED_MODEL_PREFIX # noqa +from pydantic_redis._shared.model.base import NESTED_MODEL_PREFIX # noqa from pydantic_redis._shared.utils import strip_leading # noqa from pydantic_redis.asyncio import Model, RedisConfig, Store from test.conftest import ( diff --git a/test/test_pydantic_redis.py b/test/test_pydantic_redis.py index a3ddca41..a174936e 100644 --- a/test/test_pydantic_redis.py +++ b/test/test_pydantic_redis.py @@ -7,7 +7,7 @@ from pydantic_redis import Store from pydantic_redis.config import RedisConfig # noqa -from pydantic_redis._shared.model.prop_utils import NESTED_MODEL_PREFIX # noqa +from pydantic_redis._shared.model.base import NESTED_MODEL_PREFIX # noqa from pydantic_redis._shared.utils import strip_leading # noqa from pydantic_redis.syncio.model import Model from test.conftest import ( From a3876ba6442b463031dbaa6d71229d2fe123c8e5 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 20 Jul 2024 21:11:05 +0200 Subject: [PATCH 2/4] Fix errors in deeply nested models implementation --- .gitignore | 4 +- CHANGELOG.md | 11 + README.md | 55 ++-- pydantic_redis/_shared/lua_scripts.py | 288 ++++++++----------- pydantic_redis/_shared/model/base.py | 42 ++- pydantic_redis/_shared/model/prop_utils.py | 7 +- pydantic_redis/_shared/model/select_utils.py | 72 +++-- pydantic_redis/_shared/utils.py | 27 +- test/conftest.py | 4 +- test/test_pydantic_redis.py | 35 ++- 10 files changed, 314 insertions(+), 231 deletions(-) diff --git a/.gitignore b/.gitignore index fb3d7f30..18f6d6c5 100644 --- a/.gitignore +++ b/.gitignore @@ -145,4 +145,6 @@ cython_debug/ env3.7/ /lua_scripts/ -.DS_Store \ No newline at end of file +.DS_Store + +*.lua \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c0476c..3440ecee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +### Added + +### Changed + +- Added support for deeply nested models beyond level-1 deep including: + - dictionaries of lists of ... of nested models + - lists of tuples of lists .... 
of nested models + +### Fixed + + ## [0.6.0] - 2024-07-01 ### Added diff --git a/README.md b/README.md index 15094d58..6a459927 100644 --- a/README.md +++ b/README.md @@ -88,36 +88,35 @@ benchmark_bulk_insert[redis_store] 721.2247 (6.19) 6 --------------------------------------------------------------------------------------------------------------------- ``` -# >=v0.7 (with fully-fledged nested models) +# >=v0.7 (with deeply nested models) ``` ---------------------------------------------------- benchmark: 22 tests ---------------------------------------------------- -Name (time in us) Mean Min Max ----------------------------------------------------------------------------------------------------------------------------- -test_benchmark_delete[redis_store-Wuthering Heights] 124.5440 (1.01) 109.3710 (1.0) 579.7810 (1.39) -test_benchmark_bulk_delete[redis_store] 122.9285 (1.0) 113.7120 (1.04) 492.2730 (1.18) -test_benchmark_select_columns_for_one_id[redis_store-book1] 182.3891 (1.48) 154.9150 (1.42) 441.2820 (1.06) -test_benchmark_select_columns_for_one_id[redis_store-book2] 183.2679 (1.49) 156.6830 (1.43) 462.6000 (1.11) -test_benchmark_select_columns_for_one_id[redis_store-book0] 181.6972 (1.48) 157.2330 (1.44) 459.2930 (1.10) -test_benchmark_select_columns_for_one_id[redis_store-book3] 183.0834 (1.49) 160.1250 (1.46) 416.8570 (1.0) -test_benchmark_select_all_for_one_id[redis_store-book1] 203.9491 (1.66) 183.3080 (1.68) 469.4700 (1.13) -test_benchmark_select_all_for_one_id[redis_store-book2] 206.7124 (1.68) 184.1920 (1.68) 490.6700 (1.18) -test_benchmark_select_all_for_one_id[redis_store-book0] 207.3341 (1.69) 184.2210 (1.68) 443.9260 (1.06) -test_benchmark_select_all_for_one_id[redis_store-book3] 210.6874 (1.71) 185.0600 (1.69) 696.9330 (1.67) -test_benchmark_select_columns_for_some_items[redis_store] 236.5783 (1.92) 215.7490 (1.97) 496.0540 (1.19) -test_benchmark_select_columns_paginated[redis_store] 248.5335 (2.02) 218.3450 (2.00) 522.1270 (1.25) -test_benchmark_update[redis_store-Wuthering Heights-data0] 282.1803 (2.30) 239.5410 (2.19) 541.5220 (1.30) -test_benchmark_select_some_items[redis_store] 298.2036 (2.43) 264.0860 (2.41) 599.3010 (1.44) -test_benchmark_single_insert[redis_store-book0] 316.0245 (2.57) 269.8110 (2.47) 596.0940 (1.43) -test_benchmark_single_insert[redis_store-book2] 314.1899 (2.56) 270.9780 (2.48) 560.5280 (1.34) -test_benchmark_select_default_paginated[redis_store] 305.2798 (2.48) 277.8170 (2.54) 550.5110 (1.32) -test_benchmark_single_insert[redis_store-book1] 312.5839 (2.54) 279.5660 (2.56) 578.7070 (1.39) -test_benchmark_single_insert[redis_store-book3] 316.9207 (2.58) 284.8630 (2.60) 567.0120 (1.36) -test_benchmark_select_columns[redis_store] 369.1538 (3.00) 331.5770 (3.03) 666.0470 (1.60) -test_benchmark_select_default[redis_store] 553.9420 (4.51) 485.3700 (4.44) 1,235.8540 (2.96) -test_benchmark_bulk_insert[redis_store] 777.4058 (6.32) 730.4280 (6.68) 1,012.7780 (2.43) ----------------------------------------------------------------------------------------------------------------------------- - +------------------------------------------------- benchmark: 22 tests ------------------------------------------------- +Name (time in us) Mean Min Max +----------------------------------------------------------------------------------------------------------------------- +benchmark_delete[redis_store-Wuthering Heights] 123.2946 (1.02) 107.9690 (1.0) 502.6140 (1.33) +benchmark_bulk_delete[redis_store] 120.5815 (1.0) 111.9320 (1.04) 378.8660 (1.0) 
+benchmark_select_columns_for_one_id[redis_store-book2] 208.2612 (1.73) 180.4660 (1.67) 470.9860 (1.24) +benchmark_select_columns_for_one_id[redis_store-book1] 207.9143 (1.72) 180.6440 (1.67) 489.6890 (1.29) +benchmark_select_columns_for_one_id[redis_store-book0] 204.2471 (1.69) 183.4360 (1.70) 485.2500 (1.28) +benchmark_select_columns_for_one_id[redis_store-book3] 209.5764 (1.74) 189.5780 (1.76) 462.5650 (1.22) +benchmark_select_all_for_one_id[redis_store-book0] 226.4569 (1.88) 207.4920 (1.92) 499.9470 (1.32) +benchmark_select_all_for_one_id[redis_store-book3] 241.5488 (2.00) 210.5230 (1.95) 504.5150 (1.33) +benchmark_select_all_for_one_id[redis_store-book1] 234.4014 (1.94) 210.6420 (1.95) 501.2470 (1.32) +benchmark_select_all_for_one_id[redis_store-book2] 228.9277 (1.90) 212.0090 (1.96) 509.5740 (1.34) +benchmark_update[redis_store-Wuthering Heights-data0] 276.3908 (2.29) 238.3390 (2.21) 704.9450 (1.86) +benchmark_single_insert[redis_store-book3] 311.0476 (2.58) 262.2940 (2.43) 589.3940 (1.56) +benchmark_select_columns_for_some_items[redis_store] 291.2779 (2.42) 266.0960 (2.46) 564.3510 (1.49) +benchmark_select_columns_paginated[redis_store] 300.4108 (2.49) 269.4740 (2.50) 552.8510 (1.46) +benchmark_single_insert[redis_store-book1] 304.5771 (2.53) 274.1740 (2.54) 547.5210 (1.45) +benchmark_single_insert[redis_store-book2] 317.2681 (2.63) 275.6170 (2.55) 641.5440 (1.69) +benchmark_single_insert[redis_store-book0] 313.0004 (2.60) 277.3190 (2.57) 558.2160 (1.47) +benchmark_select_some_items[redis_store] 343.2569 (2.85) 311.9140 (2.89) 624.6600 (1.65) +benchmark_select_default_paginated[redis_store] 359.8463 (2.98) 325.8310 (3.02) 623.2360 (1.65) +benchmark_select_columns[redis_store] 486.6047 (4.04) 429.3250 (3.98) 867.8780 (2.29) +benchmark_select_default[redis_store] 631.3835 (5.24) 584.7630 (5.42) 1,033.5990 (2.73) +benchmark_bulk_insert[redis_store] 761.0832 (6.31) 724.1240 (6.71) 1,034.2950 (2.73) +----------------------------------------------------------------------------------------------------------------------- ``` ## Contributions diff --git a/pydantic_redis/_shared/lua_scripts.py b/pydantic_redis/_shared/lua_scripts.py index 5b052fee..8eee732b 100644 --- a/pydantic_redis/_shared/lua_scripts.py +++ b/pydantic_redis/_shared/lua_scripts.py @@ -1,5 +1,8 @@ """Exposes the redis lua scripts to be used in select queries. +These scripts always return a list of tuples of [record, index] where the index is a flat list of nested models +for that record + Attributes: SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT: the script for selecting all records from redis PAGINATED_SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT: the script for selecting a slice of all records from redis, @@ -14,6 +17,8 @@ but returning only a subset of the fields in each record. 
""" +# What if instead of constructing tables, we return obj as a JSON string + SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT = """ local s_find = string.find local s_gmatch = string.gmatch @@ -29,48 +34,40 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end repeat local result = redis_call('SCAN', cursor, 'MATCH', ARGV[1]) for _, key in ipairs(result[2]) do if redis_call('TYPE', key).ok == 'hash' then - table_insert(filtered, get_obj(key)) + local value, idx = get_obj_and_index(key) + if type(value) == 'table' then + table_insert(filtered, {value, idx}) + end end end cursor = result[1] @@ -90,41 +87,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end local table_index_key = ARGV[1] @@ -135,9 +120,9 @@ local ids = redis_call('ZRANGE', table_index_key, start, stop) for _, key in ipairs(ids) do - local value = get_obj(key) + local value, idx = get_obj_and_index(key) if type(value) == 'table' then - table_insert(result, value) + table_insert(result, {value, idx}) end end @@ -158,47 +143,35 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in 
s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for _, key in ipairs(KEYS) do - local value = get_obj(key) + local value, idx = get_obj_and_index(key) if type(value) == 'table' then - table_insert(result, value) + table_insert(result, {value, idx}) end end @@ -222,40 +195,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for i, k in ipairs(ARGV) do @@ -270,13 +232,20 @@ if redis_call('TYPE', key).ok == 'hash' then local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end - - table_insert(filtered, parsed_data) + + table_insert(filtered, {parsed_data, index}) end end cursor = result[1] @@ -297,41 +266,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] 
= trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end local result = {} @@ -351,15 +308,23 @@ for _, key in ipairs(ids) do local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do if v then - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end end - table_insert(result, parsed_data) + table_insert(result, {parsed_data, index}) + end return result @@ -381,41 +346,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for _, k in ipairs(ARGV) do @@ -425,15 +378,22 @@ for _, key in ipairs(KEYS) do local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do if v then - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end end - table_insert(result, parsed_data) + table_insert(result, {parsed_data, index}) end return result """ diff --git a/pydantic_redis/_shared/model/base.py b/pydantic_redis/_shared/model/base.py index b7b7c580..00c1f545 100644 --- a/pydantic_redis/_shared/model/base.py +++ b/pydantic_redis/_shared/model/base.py @@ -20,6 +20,7 @@ from_dict_to_key_value_list, from_bytes_to_str, from_str_or_bytes_to_any, + groups_of_n, ) @@ -161,19 +162,23 @@ def serialize_partially(cls, data: Optional[Dict[str, Any]]) -> Dict[str, Any]: @classmethod def deserialize_partially( - cls, data: Union[List[Any], Dict[Any, Any]] = () + cls, data: Union[List[Any], Dict[Any, Any]] = (), index: Dict[Any, Any] = None ) -> Dict[str, Any]: - """Casts str or bytes in a dict or flattened key-value list to expected data types. + """Casts str or bytes in a dict to expected data types. Converts str or bytes to their expected data types Args: data: flattened list of key-values or dictionary of data to cast. 
Keeping it as potentially a dictionary ensures backward compatibility. + index: dictionary of the index of nested models potentially present Returns: the dictionary of properly parsed key-values. """ + if index is None: + index = {} + if isinstance(data, dict): # for backward compatibility data = from_dict_to_key_value_list(data) @@ -182,13 +187,16 @@ def deserialize_partially( field_type_trees = cls.get_field_type_trees() - for i in range(0, len(data), 2): - key = from_bytes_to_str(data[i]) + for k, v in groups_of_n(data, 2): + # remove the dunders for nested model fields + key = from_bytes_to_str(k).lstrip("_") field_type = cls._field_types.get(key) - value = from_str_or_bytes_to_any(value=data[i + 1], field_type=field_type) + value = from_str_or_bytes_to_any(value=v, field_type=field_type) type_tree = field_type_trees.get(key) - parsed_dict[key] = _cast_by_type_tree(value=value, type_tree=type_tree) + parsed_dict[key] = _cast_by_type_tree( + value=value, type_tree=type_tree, index=index + ) return parsed_dict @@ -240,13 +248,16 @@ def _generate_field_type_tree(field_type: Any, strict: bool = False) -> AggTypeT return None, (field_type,) -def _cast_by_type_tree(value: Any, type_tree: Optional[AggTypeTree]) -> Any: +def _cast_by_type_tree( + value: Any, type_tree: Optional[AggTypeTree], index: Dict[Any, Any] = None +) -> Any: """Casts a given value into a value basing on the tree of its aggregate type Args: value: the value to be cast basing on the type tree type_tree: the tree representing the nested hierarchy of types for the aggregate type that the value is to be cast into + index: dictionary of the index of nested models potentially present Returns: the parsed value @@ -259,26 +270,33 @@ def _cast_by_type_tree(value: Any, type_tree: Optional[AggTypeTree]) -> Any: if nesting_type is NestingType.ON_ROOT: _type = type_args[0] - return _type(**_type.deserialize_partially(value)) + nested_model_data = value + if isinstance(value, str): + # load the nested model if it is not yet loaded + nested_model_data = index.get(value, value) + return _type(**_type.deserialize_partially(nested_model_data)) if nesting_type is NestingType.IN_LIST: _type = type_args[0] - return [_cast_by_type_tree(item, _type) for item in value] + return [_cast_by_type_tree(item, _type, index) for item in value] if nesting_type is NestingType.IN_TUPLE: return tuple( - [_cast_by_type_tree(item, _type) for _type, item in zip(type_args, value)] + [ + _cast_by_type_tree(item, _type, index) + for _type, item in zip(type_args, value) + ] ) if nesting_type is NestingType.IN_DICT: _, value_type = type_args - return {k: _cast_by_type_tree(v, value_type) for k, v in value.items()} + return {k: _cast_by_type_tree(v, value_type, index) for k, v in value.items()} if nesting_type is NestingType.IN_UNION: # the value can be any of the types in type_args for _type in type_args: try: - parsed_value = _cast_by_type_tree(value, _type) + parsed_value = _cast_by_type_tree(value, _type, index) # return the first successfully parsed value # that is not equal to the original value if parsed_value != value: diff --git a/pydantic_redis/_shared/model/prop_utils.py b/pydantic_redis/_shared/model/prop_utils.py index b1adcf00..8eb01544 100644 --- a/pydantic_redis/_shared/model/prop_utils.py +++ b/pydantic_redis/_shared/model/prop_utils.py @@ -2,11 +2,16 @@ """ +import re from typing import Type, Any from .base import AbstractModel +NESTED_MODEL_SEPARATOR = "_%&_" +NESTED_MODEL_VALUE_REGEX = re.compile(f"^([\\w_]+{NESTED_MODEL_SEPARATOR}[\\w_]+)$") + + 
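As an editorial aside, not part of the diff: the separator and regex introduced here are the glue between the Python helpers and the pattern matching in the Lua scripts. A minimal sketch of the key shape involved, assuming `get_redis_key` below simply appends the primary key to the prefix returned by `get_redis_key_prefix`, and using a hypothetical `Book` model:

```python
# Editorial illustration only, not part of the patch. Assumes get_redis_key()
# appends the primary key to the prefix built by get_redis_key_prefix() below.
NESTED_MODEL_SEPARATOR = "_%&_"

model_name = "Book".lower()                        # hypothetical Book model
prefix = f"{model_name}{NESTED_MODEL_SEPARATOR}"   # "book_%&_"
redis_key = f"{prefix}Oliver Twist"                # "book_%&_Oliver Twist"

# Values shaped like "<model>_%&_<primary key>" are what the Lua scripts above
# treat as references to other hashes when building the nested-model index.
assert redis_key == "book_%&_Oliver Twist"
```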
def get_redis_key(model: Type[AbstractModel], primary_key_value: Any): """Gets the key used internally in redis for the `primary_key_value` of `model`. @@ -30,7 +35,7 @@ def get_redis_key_prefix(model: Type[AbstractModel]): the prefix of the all the redis keys that are associated with this model """ model_name = model.__name__.lower() - return f"{model_name}_%&_" + return f"{model_name}{NESTED_MODEL_SEPARATOR}" def get_redis_keys_regex(model: Type[AbstractModel]): diff --git a/pydantic_redis/_shared/model/select_utils.py b/pydantic_redis/_shared/model/select_utils.py index f509e74a..114cf94c 100644 --- a/pydantic_redis/_shared/model/select_utils.py +++ b/pydantic_redis/_shared/model/select_utils.py @@ -2,7 +2,7 @@ """ -from typing import List, Any, Type, Union, Awaitable, Optional +from typing import List, Any, Type, Union, Awaitable, Optional, Dict, Tuple from pydantic_redis._shared.model.prop_utils import ( get_redis_keys_regex, @@ -10,11 +10,14 @@ get_model_index_key, ) - from .base import AbstractModel +from ..utils import groups_of_n + + +RawRedisSelectData = List[Tuple[List[Any], List[Any]]] -def get_select_fields(model: Type[AbstractModel], columns: List[str] = []) -> List[str]: +def get_select_fields(model: Type[AbstractModel], columns: List[str] = ()) -> List[str]: """Gets the fields to be used for selecting HMAP fields in Redis. It replaces any fields in `columns` that correspond to nested records with their @@ -35,7 +38,7 @@ def select_all_fields_all_ids( model: Type[AbstractModel], skip: int = 0, limit: Optional[int] = None, -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves all records of the given model in the redis database. Args: @@ -44,7 +47,7 @@ def select_all_fields_all_ids( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ if isinstance(limit, int): @@ -58,7 +61,7 @@ def select_all_fields_all_ids( def select_all_fields_some_ids( model: Type[AbstractModel], ids: List[str] -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves some records from redis. Args: @@ -66,7 +69,7 @@ def select_all_fields_some_ids( ids: the list of primary keys of the records to be retrieved. Returns: - the list of records where each record is a flattened key-value list. + list of tuple of [record, index-of-nested-models] with each record is a flattened key-value list. In case we are using async, an Awaitable of that list is returned instead. """ table_prefix = get_redis_key_prefix(model=model) @@ -80,7 +83,7 @@ def select_some_fields_all_ids( fields: List[str], skip: int = 0, limit: Optional[int] = None, -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves records of model from redis, each as with a subset of the fields. Args: @@ -90,7 +93,7 @@ def select_some_fields_all_ids( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. 
In case we are using async, an Awaitable of that list is returned instead. """ columns = get_select_fields(model=model, columns=fields) @@ -108,7 +111,7 @@ def select_some_fields_all_ids( def select_some_fields_some_ids( model: Type[AbstractModel], fields: List[str], ids: List[str] -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves some records of current model from redis, each as with a subset of the fields. Args: @@ -117,7 +120,7 @@ def select_some_fields_some_ids( ids: the list of primary keys of the records to be retrieved. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ table_prefix = get_redis_key_prefix(model=model) @@ -128,7 +131,7 @@ def select_some_fields_some_ids( def parse_select_response( - model: Type[AbstractModel], response: List[List], as_models: bool + model: Type[AbstractModel], response: RawRedisSelectData, as_models: bool ): """Casts a list of flattened key-value lists into a list of models or dicts. @@ -150,17 +153,23 @@ def parse_select_response( if as_models: return [ - model(**model.deserialize_partially(record)) - for record in response - if record != [] + model( + **model.deserialize_partially(record, index=_construct_index(raw_index)) + ) + for record, raw_index in response + if len(response) != 0 ] - return [model.deserialize_partially(record) for record in response if record != []] + return [ + model.deserialize_partially(record, index=_construct_index(raw_index)) + for record, raw_index in response + if len(response) != 0 + ] def _select_all_ids_all_fields_paginated( model: Type[AbstractModel], limit: int, skip: Optional[int] -): +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves a slice of all records of the given model in the redis database. Args: @@ -169,7 +178,7 @@ def _select_all_ids_all_fields_paginated( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ if skip is None: @@ -182,7 +191,7 @@ def _select_all_ids_all_fields_paginated( def _select_some_fields_all_ids_paginated( model: Type[AbstractModel], columns: List[str], limit: int, skip: int -): +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves a slice of all records of model from redis, each as with a subset of the fields. Args: @@ -192,7 +201,7 @@ def _select_some_fields_all_ids_paginated( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. 
""" if skip is None: @@ -201,3 +210,26 @@ def _select_some_fields_all_ids_paginated( args = [table_index_key, skip, limit, *columns] store = model.get_store() return store.paginated_select_some_fields_for_all_ids_script(args=args) + + +def _construct_index(index_list: List[Any]) -> Dict[str, Any]: + """Constructs the index dict from the index list of nested models returned from redis + + Args: + index_list: the flat list of the index of nested models, with key followed by [model, index] tuple + [key1, [model1_flat_list, index1_flat_list], key2, [model2_flat_list, index2_flat_list]...] + + Returns: + the index as a dict + """ + index = {} + for k, model_and_index in groups_of_n(index_list, 2): + model_as_list, index_as_list = model_and_index + model_index = _construct_index(index_as_list) + index[k] = { + # remove the dunders for nested model fields + key.lstrip("_"): model_index.get(value, value) + for key, value in groups_of_n(model_as_list, 2) + } + + return index diff --git a/pydantic_redis/_shared/utils.py b/pydantic_redis/_shared/utils.py index 8fa95db7..7c1e87dd 100644 --- a/pydantic_redis/_shared/utils.py +++ b/pydantic_redis/_shared/utils.py @@ -3,10 +3,12 @@ """ import typing -from typing import Any, Tuple, Optional, Union, Dict, Type, List +from typing import Any, Tuple, Optional, Union, Dict, Type, List, Iterable, TypeVar import orjson +T = TypeVar("T") + def strip_leading(word: str, substring: str) -> str: """Strips the leading substring if it exists. @@ -96,8 +98,12 @@ def from_str_or_bytes_to_any(value: Any, field_type: Type) -> Any: elif not isinstance(value, str): return value - # JSON parse all other values that are str - return orjson.loads(value) + try: + # JSON parse all other values that are str + return orjson.loads(value) + except orjson.JSONDecodeError: + # try to be as fault-tolerant as sanely possible + return value def from_any_to_valid_redis_type(value: Any) -> Union[str, bytes, List[Any]]: @@ -155,3 +161,18 @@ def from_dict_to_key_value_list(data: Dict[str, Any]) -> List[Any]: parsed_list.append(v) return parsed_list + + +def groups_of_n(items: Iterable[T], n: int) -> Iterable[Tuple[T, ...]]: + """Returns an iterable of tuples of size n from the given list of items + + Note that it might ignore the last items if n does not fit nicely into the items list + + Args: + items: the list of items from which to extract the tuples + n: the size of the tuples + + Returns: + the iterable of tuples of n size from the list of items + """ + return zip(*[iter(items)] * n) diff --git a/test/conftest.py b/test/conftest.py index 98356cdc..3263a7f2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,6 @@ import socket from datetime import date -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, Dict import pytest import pytest_asyncio @@ -56,6 +56,8 @@ class Library(syn.Model): lost: Optional[List[Book]] = None popular: Optional[Tuple[Book, Book]] = None new: Optional[Tuple[Book, Author, Book, int]] = None + list_of_tuples: Optional[List[Tuple[str, Book]]] = None + dict_of_models: Optional[Dict[str, Book]] = None class AsyncLibrary(asy.Model): diff --git a/test/test_pydantic_redis.py b/test/test_pydantic_redis.py index a174936e..c1b05cb9 100644 --- a/test/test_pydantic_redis.py +++ b/test/test_pydantic_redis.py @@ -6,6 +6,7 @@ import pytest from pydantic_redis import Store +from pydantic_redis._shared.model.prop_utils import NESTED_MODEL_SEPARATOR from pydantic_redis.config import RedisConfig # noqa from 
pydantic_redis._shared.model.base import NESTED_MODEL_PREFIX # noqa from pydantic_redis._shared.utils import strip_leading # noqa @@ -199,6 +200,38 @@ def test_update_optional_nested_tuple_of_models(store: Store): assert got == expected +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_list_of_tuples_of_nested_models(store: Store): + list_of_tuples = [("some book", books[0]), ("book2", books[2])] + data = [Library(name="Babel Library", address="In a book", list_of_tuples=list_of_tuples)] + Library.insert(data) + # the tuple of nested models is automatically inserted + got = sorted(Book.select(), key=lambda x: x.title) + expected_books = [book for _, book in list_of_tuples] + expected = sorted(expected_books, key=lambda x: x.title) + assert expected == got + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_dict_of_models(store: Store): + dict_of_models = {"some book": books[0], "book2": books[2]} + data = [Library(name="Babel Library", address="In a book", dict_of_models=dict_of_models)] + Library.insert(data) + # the tuple of nested models is automatically inserted + got = sorted(Book.select(), key=lambda x: x.title) + expected_books = [book for _, book in dict_of_models.items()] + expected = sorted(expected_books, key=lambda x: x.title) + assert expected == got + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + @pytest.mark.parametrize("store", redis_store_fixture) def test_select_default(store: Store): """Selecting without arguments returns all the book models""" @@ -432,7 +465,7 @@ def test_delete_multiple(store: Store): def __deserialize_book_data(raw_book_data: Dict[str, Any]) -> Book: """Deserializes the raw book data returning a book instance""" author_id = raw_book_data.pop(f"{NESTED_MODEL_PREFIX}author") - author_id = strip_leading(author_id, "author_%&_") + author_id = strip_leading(author_id, f"author{NESTED_MODEL_SEPARATOR}") data = Book.deserialize_partially(raw_book_data) From 955b7e1609b010ab382b614e88d552e3997f0b58 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 20 Jul 2024 21:19:37 +0200 Subject: [PATCH 3/4] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3440ecee..57814eb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed +- Added support for dictionaries of nested models - Added support for deeply nested models beyond level-1 deep including: - dictionaries of lists of ... of nested models - lists of tuples of lists .... of nested models From 0c24b2adef7ce054b1398d1fb3d260d373c3eec4 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 20 Jul 2024 21:36:13 +0200 Subject: [PATCH 4/4] Add tests for AttributeError NoneType 'get_primary_key_field' --- CHANGELOG.md | 2 ++ test/conftest.py | 1 + test/test_pydantic_redis.py | 42 ++++++++++++++++++++++++++++++++++--- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57814eb5..3fdb5bec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). 
### Fixed +- Fixed `AttributeError: 'get_primary_key_field'` when None is passed to a field with an optional nested model + ## [0.6.0] - 2024-07-01 diff --git a/test/conftest.py b/test/conftest.py index 3263a7f2..c7810186 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -58,6 +58,7 @@ class Library(syn.Model): new: Optional[Tuple[Book, Author, Book, int]] = None list_of_tuples: Optional[List[Tuple[str, Book]]] = None dict_of_models: Optional[Dict[str, Book]] = None + optional_nested: Optional[Book] = None class AsyncLibrary(asy.Model): diff --git a/test/test_pydantic_redis.py b/test/test_pydantic_redis.py index c1b05cb9..5756c79a 100644 --- a/test/test_pydantic_redis.py +++ b/test/test_pydantic_redis.py @@ -203,7 +203,11 @@ def test_update_optional_nested_tuple_of_models(store: Store): @pytest.mark.parametrize("store", redis_store_fixture) def test_update_list_of_tuples_of_nested_models(store: Store): list_of_tuples = [("some book", books[0]), ("book2", books[2])] - data = [Library(name="Babel Library", address="In a book", list_of_tuples=list_of_tuples)] + data = [ + Library( + name="Babel Library", address="In a book", list_of_tuples=list_of_tuples + ) + ] Library.insert(data) # the tuple of nested models is automatically inserted got = sorted(Book.select(), key=lambda x: x.title) @@ -219,9 +223,13 @@ def test_update_list_of_tuples_of_nested_models(store: Store): @pytest.mark.parametrize("store", redis_store_fixture) def test_update_dict_of_models(store: Store): dict_of_models = {"some book": books[0], "book2": books[2]} - data = [Library(name="Babel Library", address="In a book", dict_of_models=dict_of_models)] + data = [ + Library( + name="Babel Library", address="In a book", dict_of_models=dict_of_models + ) + ] Library.insert(data) - # the tuple of nested models is automatically inserted + # the dict of nested models is automatically inserted got = sorted(Book.select(), key=lambda x: x.title) expected_books = [book for _, book in dict_of_models.items()] expected = sorted(expected_books, key=lambda x: x.title) @@ -232,6 +240,34 @@ def test_update_dict_of_models(store: Store): assert got == expected +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_filled_optional_nested_model(store: Store): + data = [ + Library(name="Babel Library", address="In a book", optional_nested=books[0]) + ] + Library.insert(data) + + got = sorted(Book.select(), key=lambda x: x.title) + assert [books[0]] == got + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_unfilled_optional_nested_model(store: Store): + data = [Library(name="Babel Library", address="In a book")] + Library.insert(data) + + got = Book.select() + assert got is None + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + @pytest.mark.parametrize("store", redis_store_fixture) def test_select_default(store: Store): """Selecting without arguments returns all the book models"""
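
Taken together, the new tests above boil down to the following usage pattern for deeply nested fields. This is a hedged sketch rather than code from the patch: it assumes the public API shown in the project README (`Store`, `RedisConfig`, `Model`, `store.register_model`) and a reachable redis instance, and the model fields and sample data are invented.

```python
from typing import Dict, List, Optional, Tuple

from pydantic_redis import Model, RedisConfig, Store


class Book(Model):
    _primary_key_field: str = "title"
    title: str


class Library(Model):
    _primary_key_field: str = "name"
    name: str
    # deeply nested shapes supported from v0.7
    list_of_tuples: Optional[List[Tuple[str, Book]]] = None
    dict_of_models: Optional[Dict[str, Book]] = None
    optional_nested: Optional[Book] = None


store = Store(name="sample", redis_config=RedisConfig(db=5), life_span_in_seconds=3600)
store.register_model(Book)
store.register_model(Library)

Library.insert(
    [
        Library(
            name="Babel Library",
            list_of_tuples=[("fiction", Book(title="oliver_twist"))],
            dict_of_models={"classic": Book(title="wuthering_heights")},
            optional_nested=Book(title="jane_eyre"),
        )
    ]
)

# the nested Book records are inserted automatically and can be selected directly
print(Book.select())
print(Library.select())
```
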
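Under the hood, the round trip rests on two small pieces introduced by this patch: `groups_of_n`, which splits a flat redis reply into fixed-size tuples, and `_construct_index`, which folds the flat index list into a nested dict keyed by the nested records' redis keys. The following is a standalone sketch in plain Python of how the two compose; the keys and values are invented sample data.

```python
from typing import Any, Dict, Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def groups_of_n(items: Iterable[T], n: int) -> Iterable[Tuple[T, ...]]:
    # zip over n copies of one iterator yields non-overlapping n-tuples;
    # any trailing remainder that does not fill a tuple is dropped
    return zip(*[iter(items)] * n)


def construct_index(index_list: List[Any]) -> Dict[str, Any]:
    # index_list is flat: key1, [model1_flat_list, index1_flat_list], key2, ...
    index: Dict[str, Any] = {}
    for key, (model_as_list, index_as_list) in groups_of_n(index_list, 2):
        nested = construct_index(index_as_list)
        index[key] = {
            # drop the dunder prefix on nested-model field names and replace
            # reference keys with the records they point to
            field.lstrip("_"): nested.get(value, value)
            for field, value in groups_of_n(model_as_list, 2)
        }
    return index


# a book that nests an author; the author nests nothing, so its index is empty
raw_index = [
    "book_%&_oliver_twist",
    [
        ["title", "oliver_twist", "__author", "author_%&_charles_dickens"],
        ["author_%&_charles_dickens", [["name", "charles_dickens"], []]],
    ],
]
print(construct_index(raw_index))
# {'book_%&_oliver_twist': {'title': 'oliver_twist',
#                           'author': {'name': 'charles_dickens'}}}
```
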