Created October 28, 2013 18:57
-
-
Save loic/7202542 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/django/db/models/query.py b/django/db/models/query.py | |
index 34e8dab..55b9ab3 100644 | |
--- a/django/db/models/query.py | |
+++ b/django/db/models/query.py | |
@@ -1620,9 +1620,10 @@ class RawQuerySet(object): | |
class Prefetch(object): | |
- def __init__(self, lookup, queryset=None): | |
+ def __init__(self, lookup, queryset=None, parent_filter=None): | |
self.lookup = lookup | |
self.queryset = queryset | |
+ self.parent_filter = parent_filter | |
def normalize_prefetch_lookups(lookups, prefix=None): | |
@@ -1646,7 +1647,7 @@ def normalize_prefetch_lookups(lookups, prefix=None): | |
# Then we apply a prefix if needed. | |
if prefix: | |
to_attr = LOOKUP_SEP.join([prefix, to_attr]) | |
- prefetch = Prefetch(LOOKUP_SEP.join([prefix, prefetch.lookup]), prefetch.queryset) | |
+ prefetch = Prefetch(LOOKUP_SEP.join([prefix, prefetch.lookup]), prefetch.queryset, prefetch.parent_filter) | |
ret.append((to_attr, prefetch)) | |
ret.sort(key=lambda x: x[0]) | |
@@ -1844,6 +1845,9 @@ def prefetch_one_level(instances, prefetcher, to_attr, lookup): | |
# The 'values to be matched' must be hashable as they will be used | |
# in a dictionary. | |
+ if lookup.parent_filter: | |
+ instances = filter(lookup.parent_filter, instances) | |
+ | |
rel_qs, rel_obj_attr, instance_attr, single, cache_name =\ | |
prefetcher.get_prefetch_queryset(instances, lookup.queryset) | |
# We have to handle the possibility that the default manager itself added | |
diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py | |
index 4fb71bb..a4f0bd2 100644 | |
--- a/tests/prefetch_related/tests.py | |
+++ b/tests/prefetch_related/tests.py | |
@@ -367,6 +367,14 @@ class CustomPrefetchTests(TestCase): | |
self.assertEquals(lst2[0].hlst[0].rooms_lst[1], self.room1_2) | |
self.assertEquals(len(lst2[1].hlst), 0) | |
+ def test_parent_filter(self): | |
+ odd_pk = lambda x: x.pk % 2 | |
+ | |
+ houses = House.objects.prefetch_related(Prefetch('rooms', parent_filter=odd_pk)) | |
+ for house in houses: | |
+ num_queries = 0 if odd_pk(house) else 1 | |
+ with self.assertNumQueries(num_queries): | |
+ list(house.rooms.all()) | |
class DefaultManagerTests(TestCase): | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.