Skip to content

Commit 5897983

Browse files
fix: update enable_gridfs_integration; improved test for bulk_write
1 parent fa36978 commit 5897983

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

mongomock_motor/__init__.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,47 @@ def __hash__(self):
339339

340340
@contextmanager
341341
def enabled_gridfs_integration():
342+
Database = (PyMongoDatabase, MongoMockDatabase)
343+
Collection = (PyMongoDatabase, MongoMockCollection)
344+
342345
with ExitStack() as stack:
343-
stack.enter_context(
344-
patch('gridfs.Database', (PyMongoDatabase, MongoMockDatabase))
345-
)
346-
stack.enter_context(
347-
patch('gridfs.grid_file.Collection', (PyMongoDatabase, MongoMockCollection))
348-
)
349-
stack.enter_context(patch('gridfs.GridOutCursor', _create_grid_out_cursor))
346+
try:
347+
stack.enter_context(
348+
patch(
349+
'gridfs.synchronous.grid_file.Database',
350+
Database,
351+
)
352+
)
353+
stack.enter_context(
354+
patch(
355+
'gridfs.synchronous.grid_file.Collection',
356+
Collection,
357+
)
358+
)
359+
stack.enter_context(
360+
patch(
361+
'gridfs.synchronous.grid_file.GridOutCursor',
362+
_create_grid_out_cursor,
363+
)
364+
)
365+
except (AttributeError, ModuleNotFoundError):
366+
stack.enter_context(
367+
patch(
368+
'gridfs.Database',
369+
Database,
370+
)
371+
)
372+
stack.enter_context(
373+
patch(
374+
'gridfs.grid_file.Collection',
375+
Collection,
376+
)
377+
)
378+
stack.enter_context(
379+
patch(
380+
'gridfs.GridOutCursor',
381+
_create_grid_out_cursor,
382+
)
383+
)
384+
350385
yield

tests/test_workflow.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from datetime import datetime, timezone
22

33
import bson
4+
import bson.tz_util
45
import pytest
5-
from pymongo import ReplaceOne
6+
from pymongo import DeleteMany, InsertOne, ReplaceOne, UpdateOne
67

78
from mongomock_motor import AsyncMongoMockClient
89

@@ -58,7 +59,30 @@ async def test_tz_awareness():
5859
@pytest.mark.anyio
5960
async def test_bulk_write():
6061
collection = AsyncMongoMockClient()['tests']['test']
61-
result = await collection.bulk_write(
62-
[ReplaceOne(filter={'_id': 1}, replacement={'_id': 1}, upsert=True)]
62+
63+
write_result = await collection.bulk_write(
64+
[
65+
InsertOne({'_id': 1}),
66+
DeleteMany({}),
67+
InsertOne({'_id': 1}),
68+
InsertOne({'_id': 2}),
69+
InsertOne({'_id': 3}),
70+
UpdateOne({'_id': 1}, {'$set': {'foo': 'bar'}}),
71+
UpdateOne({'_id': 4}, {'$inc': {'j': 1}}, upsert=True),
72+
ReplaceOne({'j': 1}, {'j': 2}),
73+
],
6374
)
64-
assert result.bulk_api_result['nUpserted'] == 1
75+
76+
assert write_result.bulk_api_result['nInserted'] == 4
77+
assert write_result.bulk_api_result['nMatched'] == 2
78+
assert write_result.bulk_api_result['nModified'] == 2
79+
assert write_result.bulk_api_result['nUpserted'] == 1
80+
81+
documents = await collection.find({}).to_list(None)
82+
83+
assert documents == [
84+
{'_id': 1, 'foo': 'bar'},
85+
{'_id': 2},
86+
{'_id': 3},
87+
{'_id': 4, 'j': 2},
88+
]

0 commit comments

Comments
 (0)