@@ -75,6 +75,9 @@ def _normalize_strings(obj):
75
75
if isinstance (obj , list ):
76
76
return [_normalize_strings (v ) for v in obj ]
77
77
78
+ if isinstance (obj , tuple ):
79
+ return tuple (_normalize_strings (v ) for v in obj )
80
+
78
81
if isinstance (obj , dict ):
79
82
return {_normalize_strings (k ): _normalize_strings (v ) for k , v in obj .items ()}
80
83
@@ -85,7 +88,41 @@ def _normalize_strings(obj):
85
88
return obj
86
89
87
90
88
- def _patch_iter_documents (collection ):
91
+ def _patch_iter_documents_and_get_dataset (collection ):
92
+ """
93
+ When using beanie or other solutions that utilize classes inheriting from
94
+ the "str" type, we need to explicitly transform these instances to plain
95
+ strings in cases where internal workings of "mongomock" unable to handle
96
+ custom string-like classes. Currently only beanie's "ExpressionField" is
97
+ transformed to plain strings.
98
+ """
99
+
100
+ def _iter_documents_with_normalized_strings (fn ):
101
+ @wraps (fn )
102
+ def wrapper (filter ):
103
+ return fn (_normalize_strings (filter ))
104
+
105
+ return wrapper
106
+
107
+ collection ._iter_documents = _iter_documents_with_normalized_strings (
108
+ collection ._iter_documents ,
109
+ )
110
+
111
+ def _get_dataset_with_normalized_strings (fn ):
112
+ @wraps (fn )
113
+ def wrapper (spec , sort , fields , as_class ):
114
+ return fn (spec , _normalize_strings (sort ), fields , as_class )
115
+
116
+ return wrapper
117
+
118
+ collection ._get_dataset = _get_dataset_with_normalized_strings (
119
+ collection ._get_dataset ,
120
+ )
121
+
122
+ return collection
123
+
124
+
125
+ def _patch_get_dataset (collection ):
89
126
"""
90
127
When using beanie, keys can have "ExpressionField" type,
91
128
that is inherited from "str". Looks like pymongo works ok
@@ -94,13 +131,14 @@ def _patch_iter_documents(collection):
94
131
95
132
def with_normalized_strings_in_filter (fn ):
96
133
@wraps (fn )
97
- def wrapper (filter ):
98
- return fn (_normalize_strings (filter ))
134
+ def wrapper (spec , sort , fields , as_class ):
135
+ print (sort )
136
+ return fn (spec , _normalize_strings (sort ), fields , as_class )
99
137
100
138
return wrapper
101
139
102
- collection ._iter_documents = with_normalized_strings_in_filter (
103
- collection ._iter_documents ,
140
+ collection ._get_dataset = with_normalized_strings_in_filter (
141
+ collection ._get_dataset ,
104
142
)
105
143
106
144
return collection
@@ -110,7 +148,7 @@ def _patch_collection_internals(collection):
110
148
if getattr (collection , '_patched_by_mongomock_motor' , False ):
111
149
return collection
112
150
collection = _patch_insert_and_ensure_uniques (collection )
113
- collection = _patch_iter_documents (collection )
151
+ collection = _patch_iter_documents_and_get_dataset (collection )
114
152
collection ._patched_by_mongomock_motor = True
115
153
return collection
116
154
0 commit comments