Skip to content

Commit 09bb06c

Browse files
Michael Kryukovmichaelkryukov
Michael Kryukov
authored andcommitted
fix: support ExpressionField in sorting operations; fixes #55
1 parent 98873f3 commit 09bb06c

File tree

3 files changed

+58
-9
lines changed

3 files changed

+58
-9
lines changed

mongomock_motor/patches.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def _normalize_strings(obj):
7575
if isinstance(obj, list):
7676
return [_normalize_strings(v) for v in obj]
7777

78+
if isinstance(obj, tuple):
79+
return tuple(_normalize_strings(v) for v in obj)
80+
7881
if isinstance(obj, dict):
7982
return {_normalize_strings(k): _normalize_strings(v) for k, v in obj.items()}
8083

@@ -85,7 +88,41 @@ def _normalize_strings(obj):
8588
return obj
8689

8790

88-
def _patch_iter_documents(collection):
91+
def _patch_iter_documents_and_get_dataset(collection):
92+
"""
93+
When using beanie or other solutions that utilize classes inheriting from
94+
the "str" type, we need to explicitly transform these instances to plain
95+
strings in cases where internal workings of "mongomock" unable to handle
96+
custom string-like classes. Currently only beanie's "ExpressionField" is
97+
transformed to plain strings.
98+
"""
99+
100+
def _iter_documents_with_normalized_strings(fn):
101+
@wraps(fn)
102+
def wrapper(filter):
103+
return fn(_normalize_strings(filter))
104+
105+
return wrapper
106+
107+
collection._iter_documents = _iter_documents_with_normalized_strings(
108+
collection._iter_documents,
109+
)
110+
111+
def _get_dataset_with_normalized_strings(fn):
112+
@wraps(fn)
113+
def wrapper(spec, sort, fields, as_class):
114+
return fn(spec, _normalize_strings(sort), fields, as_class)
115+
116+
return wrapper
117+
118+
collection._get_dataset = _get_dataset_with_normalized_strings(
119+
collection._get_dataset,
120+
)
121+
122+
return collection
123+
124+
125+
def _patch_get_dataset(collection):
89126
"""
90127
When using beanie, keys can have "ExpressionField" type,
91128
that is inherited from "str". Looks like pymongo works ok
@@ -94,13 +131,14 @@ def _patch_iter_documents(collection):
94131

95132
def with_normalized_strings_in_filter(fn):
96133
@wraps(fn)
97-
def wrapper(filter):
98-
return fn(_normalize_strings(filter))
134+
def wrapper(spec, sort, fields, as_class):
135+
print(sort)
136+
return fn(spec, _normalize_strings(sort), fields, as_class)
99137

100138
return wrapper
101139

102-
collection._iter_documents = with_normalized_strings_in_filter(
103-
collection._iter_documents,
140+
collection._get_dataset = with_normalized_strings_in_filter(
141+
collection._get_dataset,
104142
)
105143

106144
return collection
@@ -110,7 +148,7 @@ def _patch_collection_internals(collection):
110148
if getattr(collection, '_patched_by_mongomock_motor', False):
111149
return collection
112150
collection = _patch_insert_and_ensure_uniques(collection)
113-
collection = _patch_iter_documents(collection)
151+
collection = _patch_iter_documents_and_get_dataset(collection)
114152
collection._patched_by_mongomock_motor = True
115153
return collection
116154

tests/test_beanie.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,14 @@ async def test_beanie_links():
7575
house = houses[0]
7676
await house.fetch_all_links()
7777
assert house.door.height == 2.1
78+
79+
80+
@pytest.mark.anyio
81+
async def test_beanie_sort():
82+
client = AsyncMongoMockClient()
83+
await init_beanie(database=client.beanie_test, document_models=[Door])
84+
85+
await Door.insert_many([Door(width=width) for width in [4, 2, 3, 1]])
86+
87+
doors = await Door.find().sort(Door.width).to_list()
88+
assert [door.width for door in doors] == [1, 2, 3, 4]

tests/test_mocking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pymongo.results import UpdateResult
77

88
from mongomock_motor import AsyncMongoMockClient
9-
from mongomock_motor.patches import _patch_iter_documents
9+
from mongomock_motor.patches import _patch_iter_documents_and_get_dataset
1010

1111

1212
@pytest.mark.anyio
@@ -56,8 +56,8 @@ async def test_no_multiple_patching():
5656
database = AsyncMongoMockClient()['test']
5757

5858
with patch(
59-
'mongomock_motor.patches._patch_iter_documents',
60-
wraps=_patch_iter_documents,
59+
'mongomock_motor.patches._patch_iter_documents_and_get_dataset',
60+
wraps=_patch_iter_documents_and_get_dataset,
6161
) as patch_iter_documents:
6262
for _ in range(2):
6363
collection = database['test']

0 commit comments

Comments
 (0)