diff --git a/django_mock_queries/query.py b/django_mock_queries/query.py index ca39759..d58f1ed 100644 --- a/django_mock_queries/query.py +++ b/django_mock_queries/query.py @@ -127,7 +127,7 @@ def annotate(self, **kwargs): row._annotated_fields.append(key) setattr(row, key, get_attribute(row, value)[0]) - return MockSet(*results, clone=self) + return self._mockset_class()(*results, clone=self) def aggregate(self, *args, **kwargs): result = {} diff --git a/tests/test_query.py b/tests/test_query.py index a5bfc71..2b37d42 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1090,6 +1090,13 @@ def test_annotate(self): self.assertEqual(qs[2].color_or_car, 'kia') + def test_annotate_returns_current_class_instance(self): + class CustomMockSet(MockSet): + pass + + qs = CustomMockSet(Car(model='golf', id=1)) + self.assertIsInstance(qs.annotate(model=models.F('model')), CustomMockSet) + def test_query_values_raises_attribute_error_when_field_is_not_in_meta_concrete_fields(self): qs = MockSet(MockModel(foo=1), MockModel(foo=2)) self.assertRaises(FieldError, qs.values, 'bar')