|
import functools |
|
from django.db import models |
|
from django.db.models.constants import LOOKUP_SEP |
|
from django.db.models.query import ModelIterable |
|
from polymorphic.models import PolymorphicModel |
|
|
|
|
|
def rgetattr(obj, attr, separator='.'):
    """Resolve a *separator*-delimited attribute path starting at *obj*.

    ``rgetattr(a, 'b.c')`` evaluates ``a.b.c``. A missing attribute anywhere
    along the path raises ``AttributeError``, just like a plain ``getattr``.
    """
    current = obj
    for name in attr.split(separator):
        current = getattr(current, name)
    return current
|
|
|
|
|
def rsetattr(obj, attr, value, separator='.'):
    """Assign *value* to the last attribute of a *separator*-delimited path.

    Every segment but the last is resolved with ``getattr``. The last segment
    is then set on the object reached, so ``rsetattr(a, 'b.c', v)`` performs
    ``a.b.c = v``.
    """
    *intermediate, leaf = attr.split(separator)
    target = obj
    for name in intermediate:
        target = getattr(target, name)
    setattr(target, leaf, value)
|
|
|
def is_polymorphic_subclass(super_cls, sub_cls):
    """Return whether *sub_cls* is a model that lists *super_cls* among its parents.

    *sub_cls* must subclass ``models.Model``. It must not be ``models.Model``,
    ``PolymorphicModel`` or *super_cls* itself. *super_cls* must appear as a
    key in ``sub_cls._meta.parents``.

    Any ``AttributeError`` raised during the check (e.g. a class without
    ``_meta``) yields ``False``.
    """
    excluded = (models.Model, super_cls, PolymorphicModel)
    try:
        if not issubclass(sub_cls, models.Model):
            return False
        if sub_cls in excluded:
            return False
        return super_cls in sub_cls._meta.parents
    except AttributeError:
        return False
|
|
|
|
|
def polymorphic_iterator(*fields):
    """Build an iterable class that downcasts related polymorphic objects.

    The returned class is meant to be assigned to ``QuerySet._iterable_class``.
    It wraps Django's ``ModelIterable``. For every object it yields, each
    related object named in *fields* is replaced by its real subclass instance.
    *fields* are ``LOOKUP_SEP``-separated paths such as ``"owner__pet"``.

    The subclass instance is taken from the related object's
    ``_state.fields_cache``. It is therefore only available when it was loaded
    through ``select_related``.

    A related object is left untouched when:
    * the relation (or any intermediate relation in a nested path) is ``None``;
    * no cached subclass instance exists. For example, the row's real class is
      the base class itself, or the subclass was not ``select_related``.
    """

    def _resolve(obj, path):
        # Like rgetattr(), but returns None at the first null relation instead
        # of raising AttributeError on ``None``.
        for part in path.split(LOOKUP_SEP):
            if obj is None:
                return None
            obj = getattr(obj, part)
        return obj

    class PolymorphicModelIterable:

        def __init__(self, *args, **kwargs):
            self.iterable = ModelIterable(*args, **kwargs)

        def __iter__(self):
            for obj in self.iterable:
                for field in fields:
                    instance = _resolve(obj, field)
                    if instance is None:
                        # Nullable relation with no related row: nothing to cast.
                        continue
                    # NOTE(review): assumes the subclass instance is cached under
                    # the default reverse parent-link accessor (the lowercase model
                    # name stored in polymorphic_ctype.model). A parent link with a
                    # custom related_name would be cached under another key — TODO confirm.
                    real_instance_name = instance.polymorphic_ctype.model
                    fields_cache = instance._state.fields_cache
                    real_instance = fields_cache.pop(real_instance_name, None)
                    if real_instance is None:
                        # No subclass row was loaded; keep the base instance rather
                        # than failing with KeyError.
                        continue
                    # Hand the base instance's cache to the real instance, else
                    # additional data from select_related would be lost.
                    real_instance._state.fields_cache = fields_cache
                    # Recursive set so nested relation paths work too.
                    rsetattr(obj, field, real_instance, separator=LOOKUP_SEP)
                yield obj

    return PolymorphicModelIterable
|
|
|
|
|
class PolymorphicRelatedQuerySetMixin:
    """QuerySet mixin for models that have relations to polymorphic models.

    Combine it with ``models.QuerySet``. It adds
    ``select_polymorphic_related()``, which loads related polymorphic objects
    as their real subclasses.
    """

    def _get_nested_base_model(self, field):
        """Return the model reached by following the relation path *field*.

        *field* is a ``LOOKUP_SEP``-separated path starting at ``self.model``.
        Each segment is resolved with ``Model._meta.get_field`` and replaced by
        that field's ``related_model``. Forward relations and reverse relations
        are therefore handled the same way.
        """
        model = self.model
        for part in field.split(LOOKUP_SEP):
            model = model._meta.get_field(part).related_model
        return model

    def select_polymorphic_related(self, *fields):
        """
        Specify fields that should be cast to the real polymorphic class.

        Each entry in *fields* is a relation path (as accepted by
        ``select_related``) to a polymorphic model. The direct concrete
        subclasses' parent links and the ``polymorphic_ctype`` relation are
        joined in via ``select_related``. The returned queryset's iterable then
        replaces each related object with its real subclass instance.

        Returns a new queryset; ``self`` is not modified. With no *fields*, an
        unmodified copy is returned.
        """
        if not fields:
            # select_related() with no arguments would follow every non-null
            # foreign key, which is not what "no polymorphic fields" means.
            return self.all()

        related_lookups = []
        for field in fields:
            field_class = self._get_nested_base_model(field)
            # This is somewhat a replication of
            # PolymorphicModel._get_inheritance_relation_fields_and_models(), but it
            # avoids instantiating a base-model instance for every query.
            # Could be cached.
            # __subclasses__() only yields direct subclasses, so deeper levels of
            # inheritance are not joined here.
            for sub_cls in field_class.__subclasses__():
                if not is_polymorphic_subclass(field_class, sub_cls):
                    continue
                parent_link = sub_cls._meta.parents[field_class]
                if parent_link is None:
                    # Django stores None for proxy models: they share the parent's
                    # table, so there is no parent link to join through.
                    continue
                accessor = parent_link.remote_field.related_name or sub_cls.__name__.lower()
                # The lookup must be relative to self.model, hence the field prefix
                # (also when a custom related_name is used).
                related_lookups.append(LOOKUP_SEP.join((field, accessor)))
            # Also join the polymorphic content type, which the iterator reads.
            polymorphic_ctype_field_name = field_class.polymorphic_internal_model_fields[0]
            related_lookups.append(LOOKUP_SEP.join((field, polymorphic_ctype_field_name)))

        # Configure the clone, not self, so the caller's queryset keeps its iterable.
        clone = self.select_related(*related_lookups)
        clone._iterable_class = polymorphic_iterator(*fields)
        return clone
|
|
|
|
|
class PolymorphicRelatedQuerySet(PolymorphicRelatedQuerySetMixin, models.QuerySet):
    """``models.QuerySet`` with ``select_polymorphic_related()`` mixed in."""
|
|