diff --git a/cacheops/query.py b/cacheops/query.py index 2fd781e4..9ae69151 100644 --- a/cacheops/query.py +++ b/cacheops/query.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- +import copy import sys import json import threading @@ -10,7 +12,7 @@ from django.utils.encoding import smart_str, force_text from django.core.exceptions import ImproperlyConfigured, EmptyResultSet from django.db import DEFAULT_DB_ALIAS -from django.db.models import Model +from django.db.models import Model, Prefetch from django.db.models.manager import BaseManager from django.db.models.query import QuerySet from django.db.models.signals import pre_save, post_save, post_delete, m2m_changed @@ -23,7 +25,8 @@ MAX_GET_RESULTS = None from .conf import model_profile, settings, ALL_OPS -from .utils import monkey_mix, stamp_fields, func_cache_key, cached_view_fab, family_has_profile +from .utils import monkey_mix, stamp_fields, func_cache_key, cached_view_fab, \ + family_has_profile, get_model_from_lookup from .utils import md5 from .sharding import get_prefix from .redis import redis_client, handle_connection_failure, load_script @@ -251,6 +254,33 @@ def nocache(self): else: return self.cache(ops=[]) + def cache_prefetch_related(self, *lookups): + """ + Same as prefetch_related but attempts to pull relations from the cache instead + + lookups - same as for django's vanilla prefetch_related() + """ + + # If relations are already fetched there is no point to continuing + if self._prefetch_done: + return self + + prefetches = [] + + for pf in lookups: + if isinstance(pf, Prefetch): + item = copy.copy(pf) + item.queryset = item.queryset.cache(ops=['fetch']) + prefetches.append(item) + + if isinstance(pf, str): + model_class = get_model_from_lookup(self.model, pf) + prefetches.append( + Prefetch(pf, model_class._default_manager.all().cache(ops=['fetch'])) + ) + + return self.prefetch_related(*prefetches) + def cloning(self, cloning=1000): self._cloning = cloning return self diff --git a/cacheops/utils.py b/cacheops/utils.py index 38f6bfdf..4ff256ab 100644 --- a/cacheops/utils.py +++ b/cacheops/utils.py @@ -129,6 +129,31 @@ def wrapper(request, *args, **kwargs): return cached_view +def get_model_from_lookup(base_model, orm_lookup): + """ + Given a base model and an ORM lookup, follow any relations and return + the final model class of the lookup. + """ + + result = base_model + for field_name in orm_lookup.split('__'): + + if field_name.endswith('_set'): + field_name = field_name.split('_set')[0] + + try: + field = result._meta.get_field(field_name) + except models.FieldDoesNotExist: + break + + if hasattr(field, 'related_model'): + result = field.related_model + else: + break + + return result + + ### Whitespace handling for template tags from django.utils.safestring import mark_safe diff --git a/tests/test_extras.py b/tests/test_extras.py index f822cb01..caf3d1f4 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -1,12 +1,14 @@ from django.db import transaction +from django.db.models import Prefetch from django.test import TestCase, override_settings from cacheops import cached_as, no_invalidation, invalidate_obj, invalidate_model, invalidate_all from cacheops.conf import settings from cacheops.signals import cache_read, cache_invalidated +from cacheops.utils import get_model_from_lookup from .utils import BaseTestCase, make_inc -from .models import Post, Category, Local, DbAgnostic, DbBinded +from .models import Post, Category, Local, DbAgnostic, DbBinded, Brand, Label class SettingsTests(TestCase): @@ -183,3 +185,25 @@ def test_db_agnostic_disabled(self): with self.assertNumQueries(1, using='slave'): list(DbBinded.objects.cache().using('slave')) + + +class CachedPrefetchTest(BaseTestCase): + + def test_get_model_from_lookup(self): + assert get_model_from_lookup(Brand, 'labels') is Label + + def test_cache_prefetch_related(self): + qs = Brand.objects.all().cache_prefetch_related('labels') + + pf = qs._prefetch_related_lookups[0] + + assert isinstance(pf, Prefetch) + assert pf.queryset.model is Label + assert pf.queryset._cacheprofile + + def test_cache_prefetch_related_with_ops(self): + qs = Brand.objects.all().cache_prefetch_related('labels') + + pf = qs._prefetch_related_lookups[0] + + self.assertEqual(pf.queryset._cacheprofile['ops'], {'fetch'})