diff --git a/django_mock_queries/constants.py b/django_mock_queries/constants.py index d857ba1..ff7f327 100644 --- a/django_mock_queries/constants.py +++ b/django_mock_queries/constants.py @@ -92,6 +92,7 @@ ) DjangoQ = locate('django.db.models.Q') +DjangoSqlQuery = locate('django.db.models.sql.Query') DjangoQuerySet = locate('django.db.models.QuerySet') DjangoDbRouter = locate('django.db.router') DjangoModelDeletionCollector = locate('django.db.models.deletion.Collector') diff --git a/django_mock_queries/query.py b/django_mock_queries/query.py index ca39759..65d7b20 100644 --- a/django_mock_queries/query.py +++ b/django_mock_queries/query.py @@ -38,6 +38,7 @@ class MockSet(MagicMock, metaclass=MockSetMeta): def __init__(self, *initial_items, **kwargs): clone = kwargs.pop('clone', None) model = kwargs.pop('model', None) + query = kwargs.pop('query', None) for x in self.RETURN_SELF_METHODS: kwargs.update({x: self._return_self}) @@ -48,6 +49,10 @@ def __init__(self, *initial_items, **kwargs): self.clone = clone self.model = getattr(clone, 'model', model) self.events = {} + self.query = ( + getattr(clone, 'query', query) + or MagicMock(spec=DjangoSqlQuery, model=self.model, clone=lambda: self.query) + ) self.add(*initial_items)