Skip to content

Commit ee0f66d

Browse files
Michael Kryukovmichaelkryukov
Michael Kryukov
authored andcommitted
fix: avoid patching collections multiple times; closes #49
1 parent e11fd2a commit ee0f66d

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

mongomock_motor/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ async def to_list(self, *args, **kwargs):
193193
class AsyncMongoMockCollection:
194194
def __init__(self, database, collection):
195195
self.database = database
196-
self.__collection = collection
196+
self.__collection = _patch_collection_internals(collection)
197197

198198
def get_io_loop(self):
199199
return self.database.get_io_loop()
@@ -248,9 +248,7 @@ def get_io_loop(self):
248248
def get_collection(self, *args, **kwargs):
249249
return AsyncMongoMockCollection(
250250
self,
251-
_patch_collection_internals(
252-
self.__database.get_collection(*args, **kwargs),
253-
),
251+
self.__database.get_collection(*args, **kwargs),
254252
)
255253

256254
def aggregate(self, *args, **kwargs) -> AsyncCursor:

mongomock_motor/patches.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ def wrapper(data, *args, **kwargs):
6161

6262
return wrapper
6363

64-
collection._insert = with_enriched_duplicate_key_error(collection._insert)
64+
collection._insert = with_enriched_duplicate_key_error(
65+
collection._insert,
66+
)
6567
collection._ensure_uniques = with_enriched_duplicate_key_error(
66-
collection._ensure_uniques
68+
collection._ensure_uniques,
6769
)
6870

6971
return collection
@@ -89,19 +91,27 @@ def _patch_iter_documents(collection):
8991
that is inherited from "str". Looks like pymongo works ok
9092
with that, so we should be too.
9193
"""
92-
_iter_documents = collection._iter_documents
9394

94-
def iter_documents(filter):
95-
return _iter_documents(_normalize_strings(filter))
95+
def with_normalized_strings_in_filter(fn):
96+
@wraps(fn)
97+
def wrapper(filter):
98+
return fn(_normalize_strings(filter))
9699

97-
collection._iter_documents = iter_documents
100+
return wrapper
101+
102+
collection._iter_documents = with_normalized_strings_in_filter(
103+
collection._iter_documents,
104+
)
98105

99106
return collection
100107

101108

102109
def _patch_collection_internals(collection):
110+
if getattr(collection, '_patched_by_mongomock_motor', False):
111+
return collection
103112
collection = _patch_insert_and_ensure_uniques(collection)
104113
collection = _patch_iter_documents(collection)
114+
collection._patched_by_mongomock_motor = True
105115
return collection
106116

107117

tests/test_mocking.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import pytest
44
from bson import ObjectId
5+
from pymongo.read_preferences import Primary
56
from pymongo.results import UpdateResult
67

78
from mongomock_motor import AsyncMongoMockClient
9+
from mongomock_motor.patches import _patch_iter_documents
810

911

1012
@pytest.mark.anyio
@@ -47,3 +49,33 @@ async def test_patch_object():
4749
):
4850
with pytest.raises(RuntimeError):
4951
await sample_function(collection)
52+
53+
54+
@pytest.mark.anyio
55+
async def test_no_multiple_patching():
56+
database = AsyncMongoMockClient()['test']
57+
58+
with patch(
59+
'mongomock_motor.patches._patch_iter_documents',
60+
wraps=_patch_iter_documents,
61+
) as patch_iter_documents:
62+
for _ in range(2):
63+
collection = database['test']
64+
assert collection
65+
66+
assert patch_iter_documents.call_count == 1
67+
68+
for _ in range(2):
69+
collection = database.get_collection(
70+
'test',
71+
read_preference=Primary,
72+
)
73+
assert collection
74+
75+
assert patch_iter_documents.call_count == 3
76+
77+
for _ in range(2):
78+
collection = database.get_collection('test')
79+
assert collection
80+
81+
assert patch_iter_documents.call_count == 3

0 commit comments

Comments
 (0)