Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions django_mongodb_backend/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio
node.set_source_expressions([Case(condition), *source_expressions[1:]])
else:
node = self
lhs_mql = process_lhs(node, compiler, connection)
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
if resolve_inner_expression:
return lhs_mql
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
Expand All @@ -39,9 +39,9 @@ def count(self, compiler, connection, resolve_inner_expression=False):
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
)
node.set_source_expressions([Case(condition), *source_expressions[1:]])
inner_expression = process_lhs(node, compiler, connection)
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
else:
lhs_mql = process_lhs(self, compiler, connection)
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
inner_expression = {
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
Expand All @@ -51,7 +51,7 @@ def count(self, compiler, connection, resolve_inner_expression=False):
return {"$sum": inner_expression}
# If distinct=True or resolve_inner_expression=False, sum the size of the
# set.
lhs_mql = process_lhs(self, compiler, connection)
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
# None shouldn't be counted, so subtract 1 if it's present.
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
return {"$add": [{"$size": lhs_mql}, exits_null]}
Expand All @@ -66,7 +66,7 @@ def stddev_variance(self, compiler, connection):


def register_aggregates():
Aggregate.as_mql = aggregate
Count.as_mql = count
StdDev.as_mql = stddev_variance
Variance.as_mql = stddev_variance
Aggregate.as_mql_expr = aggregate
Count.as_mql_expr = count
StdDev.as_mql_expr = stddev_variance
Variance.as_mql_expr = stddev_variance
57 changes: 55 additions & 2 deletions django_mongodb_backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import logging
import os

from django.core.exceptions import ImproperlyConfigured
from bson import Decimal128
from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import debug_transaction
Expand Down Expand Up @@ -96,6 +97,58 @@ class DatabaseWrapper(BaseDatabaseWrapper):
}
_connection_pools = {}

def _isnull_operator(field, is_null):
if is_null:
return {"$or": [{field: {"$exists": False}}, {field: None}]}
return {"$and": [{field: {"$exists": True}}, {field: {"$ne": None}}]}

def _range_operator(a, b):
conditions = []
start, end = b
if start is not None:
conditions.append({a: {"$gte": b[0]}})
if end is not None:
conditions.append({a: {"$lte": b[1]}})
if not conditions:
raise FullResultSet
if start is not None and end is not None:
# Decimal128 can't be natively compared.
if isinstance(start, Decimal128):
start = start.to_decimal()
if isinstance(end, Decimal128):
end = end.to_decimal()
if start > end:
raise EmptyResultSet
return {"$and": conditions}

def _regex_operator(field, regex, insensitive=False):
options = "i" if insensitive else ""
return {field: {"$regex": regex, "$options": options}}

mongo_operators = {
"exact": lambda a, b: {a: b},
"gt": lambda a, b: {a: {"$gt": b}},
"gte": lambda a, b: {a: {"$gte": b}},
# MongoDB considers null less than zero. Exclude null values to match
# SQL behavior.
"lt": lambda a, b: {"$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator(a, False)]},
"lte": lambda a, b: {
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator(a, False)]
},
"in": lambda a, b: {a: {"$in": tuple(b)}},
"isnull": _isnull_operator,
"range": _range_operator,
"iexact": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}$", insensitive=True),
"startswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}"),
"istartswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"^{b}", insensitive=True),
"endswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"{b}$"),
"iendswith": lambda a, b: DatabaseWrapper._regex_operator(a, f"{b}$", insensitive=True),
"contains": lambda a, b: DatabaseWrapper._regex_operator(a, b),
"icontains": lambda a, b: DatabaseWrapper._regex_operator(a, b, insensitive=True),
"regex": lambda a, b: DatabaseWrapper._regex_operator(a, b),
"iregex": lambda a, b: DatabaseWrapper._regex_operator(a, b, insensitive=True),
}

def _isnull_expr(field, is_null):
mql = {
"$or": [
Expand All @@ -112,7 +165,7 @@ def _regex_expr(field, regex_vals, insensitive=False):
options = "i" if insensitive else ""
return {"$regexMatch": {"input": field, "regex": regex, "options": options}}

mongo_operators = {
mongo_expr_operators = {
"exact": lambda a, b: {"$eq": [a, b]},
"gt": lambda a, b: {"$gt": [a, b]},
"gte": lambda a, b: {"$gte": [a, b]},
Expand Down
32 changes: 17 additions & 15 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ def _get_replace_expr(self, sub_expr, group, alias):
if getattr(sub_expr, "distinct", False):
# If the expression should return distinct values, use $addToSet to
# deduplicate.
rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True)
rhs = sub_expr.as_mql(
self, self.connection, resolve_inner_expression=True, as_expr=True
)
group[alias] = {"$addToSet": rhs}
replacing_expr = sub_expr.copy()
replacing_expr.set_source_expressions([inner_column, None])
else:
group[alias] = sub_expr.as_mql(self, self.connection)
group[alias] = sub_expr.as_mql(self, self.connection, as_expr=True)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
Expand Down Expand Up @@ -302,9 +304,7 @@ def _compound_searches_queries(self, search_replacements):
search.as_mql(self, self.connection),
{
"$addFields": {
result_col.as_mql(self, self.connection, as_path=True): {
"$meta": score_function
}
result_col.as_mql(self, self.connection): {"$meta": score_function}
}
},
]
Expand Down Expand Up @@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False):
pipeline.extend(query.get_pipeline())
# Remove the added subqueries.
self.subqueries = []
pipeline.append({"$match": {"$expr": having}})
pipeline.append({"$match": having})
self.aggregation_pipeline = pipeline
self.annotations = {
target: expr.replace_expressions(all_replacements)
Expand Down Expand Up @@ -481,11 +481,11 @@ def build_query(self, columns=None):
query.lookup_pipeline = self.get_lookup_pipeline()
where = self.get_where()
try:
expr = where.as_mql(self, self.connection) if where else {}
match_mql = where.as_mql(self, self.connection) if where else {}
except FullResultSet:
query.match_mql = {}
else:
query.match_mql = {"$expr": expr}
query.match_mql = match_mql
if extra_fields:
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
query.subqueries = self.subqueries
Expand Down Expand Up @@ -643,7 +643,9 @@ def get_combinator_queries(self):
for alias, expr in self.columns:
# Unfold foreign fields.
if isinstance(expr, Col) and expr.alias != self.collection_name:
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
ids[expr.alias][expr.target.column] = expr.as_mql(
self, self.connection, as_expr=True
)
else:
ids[alias] = f"${alias}"
# Convert defaultdict to dict so it doesn't appear as
Expand Down Expand Up @@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
# For brevity/simplicity, project {"field_name": 1}
# instead of {"field_name": "$field_name"}.
if isinstance(expr, Col) and name == expr.target.column and not force_expression
else expr.as_mql(self, self.connection)
else expr.as_mql(self, self.connection, as_expr=True)
)
except EmptyResultSet:
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
value = (
False if empty_result_set_value is NotImplemented else empty_result_set_value
)
fields[collection][name] = Value(value).as_mql(self, self.connection)
fields[collection][name] = Value(value).as_mql(self, self.connection, as_expr=True)
except FullResultSet:
fields[collection][name] = Value(True).as_mql(self, self.connection)
fields[collection][name] = Value(True).as_mql(self, self.connection, as_expr=True)
# Annotations (stored in None) and the main collection's fields
# should appear in the top-level of the fields dict.
fields.update(fields.pop(None, {}))
Expand All @@ -739,10 +741,10 @@ def _get_ordering(self):
idx = itertools.count(start=1)
for order in self.order_by_objs or []:
if isinstance(order.expression, Col):
field_name = order.as_mql(self, self.connection).removeprefix("$")
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
fields.append((order.expression.target.column, order.expression))
elif isinstance(order.expression, Ref):
field_name = order.as_mql(self, self.connection).removeprefix("$")
field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$")
else:
field_name = f"__order{next(idx)}"
fields.append((field_name, order.expression))
Expand Down Expand Up @@ -879,7 +881,7 @@ def execute_sql(self, result_type):
)
prepared = field.get_db_prep_save(value, connection=self.connection)
if hasattr(value, "as_mql"):
prepared = prepared.as_mql(self, self.connection)
prepared = prepared.as_mql(self, self.connection, as_expr=True)
values[field.column] = prepared
try:
criteria = self.build_query().match_mql
Expand Down
Loading