Last active
December 29, 2015 00:39
-
-
Save jacobg/7587826 to your computer and use it in GitHub Desktop.
The purpose of this monkey patch is to support writable fields using a dotted source in Django REST Framework. It is based on the following fork, which began the implementation but left out a few details:
https://github.com/craigds/django-rest-framework/compare/writable-dotted-field-source
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Based on this patch: https://github.com/craigds/django-rest-framework/compare/writable-dotted-field-source
# The purpose of this patch is to support writable fields that use a dotted source. It lets you write serializers like this:
# class UserProfileSerializer(ModelSerializer):
#
#     user = PrimaryKeyRelatedField(read_only=True)
#
#     active = BooleanField(source='user.is_active')
#     first_name = CharField(source='user.first_name')
#     last_name = CharField(source='user.last_name')
#     email = EmailField(source='user.email')
#     last_login = DateTimeField(source='user.last_login', read_only=True)
#
from rest_framework import compat, serializers | |
from django.core.exceptions import ObjectDoesNotExist | |
# | |
# Changes to rest_framework/fields.py | |
# | |
def set_component(obj, attr_name, value):
    """
    Set ``attr_name`` on ``obj`` to ``value``.

    Dictionaries are updated by key; any other object has the attribute
    assigned with ``setattr``. This is the write-side counterpart of
    ``get_component``, except that callable components are not supported:
    if the attribute currently holds a callable, a ``TypeError`` is raised
    instead of overwriting it.
    """
    if isinstance(obj, dict):
        obj[attr_name] = value
        return

    try:
        current = getattr(obj, attr_name, None)
    except ObjectDoesNotExist:
        # Happens for non-null FK fields that haven't been set yet; there is
        # nothing to inspect, so fall through and assign the new value.
        current = None

    if compat.six.callable(current):
        raise TypeError("%r.%s is a method; can't set it" % (obj, attr_name))
    setattr(obj, attr_name, value)


serializers.set_component = set_component
def _get_source_value(self, obj, field_name):
    """
    Resolve the value this field reads from ``obj``.

    The lookup path is ``self.source`` when set, otherwise ``field_name``,
    and may be dot-separated. Each component is resolved with
    ``serializers.get_component`` (an attribute, a dict key, or a callable
    with no arguments). As soon as a component resolves to ``None`` the
    walk stops and ``None`` is returned.
    """
    path = (self.source or field_name).split('.')
    current = obj
    for name in path:
        current = serializers.get_component(current, name)
        if current is None:
            return None
    return current


serializers.Field._get_source_value = _get_source_value
def Field_field_to_native(self, obj, field_name):
    """
    Given an object and a field name, return the value that should be
    serialized for that field.

    Returns ``self.empty`` when ``obj`` is ``None``. A source of ``'*'``
    passes the whole object to ``to_native``; otherwise the (possibly
    dotted) source is resolved with ``_get_source_value`` first.
    """
    if obj is None:
        return self.empty
    if self.source == '*':
        target = obj
    else:
        target = self._get_source_value(obj, field_name)
    return self.to_native(target)


serializers.Field.field_to_native = Field_field_to_native
def _set_source_value(self, obj, field_name, value):
    """
    Look up a field on the given object and set its value.

    Uses ``self.source`` if set, otherwise ``field_name``. The path obeys
    the same rules as ``_get_source_value``, except that the final
    component can't be a callable (``set_component`` raises ``TypeError``).

    For a dotted path, every intermediate object reached while walking it
    is recorded on ``obj._traversed_objects`` as an ``(accessor, object)``
    pair, where ``accessor`` is the dotted prefix that reached it (e.g.
    ``'user'`` and then ``'user.profile'``). Each accessor is recorded only
    once across calls. Entries are prepended, so deeper objects come
    first and ``save_object`` can save them before their parents.
    """
    parts = (self.source or field_name).split('.')
    last_source_part = parts.pop()
    item = obj

    if parts:
        if not getattr(obj, '_traversed_objects', None):
            obj._traversed_objects = []
        traversed_objects = obj._traversed_objects
        # Accessors recorded so far (possibly by earlier calls on the same
        # object); a set avoids rescanning the list for every component.
        recorded = set(accessor for accessor, _ in traversed_objects)
        for depth, component in enumerate(parts, 1):
            item = serializers.get_component(item, component)
            accessor = '.'.join(parts[:depth])
            if accessor not in recorded:
                recorded.add(accessor)
                # Prepend so that deeper objects get saved first.
                # TODO: consider a depth-first tree traversal, which might
                # cover more complex cases.
                traversed_objects.insert(0, (accessor, item))

    set_component(item, last_source_part, value)


serializers.WritableField._set_source_value = _set_source_value
# | |
# Changes to rest_framework/relations.py | |
# | |
def RelatedField_field_to_native(self, obj, field_name):
    """
    Return the serialized representation of a related field on ``obj``.

    A source of ``'*'`` serializes ``obj`` itself. Otherwise the (possibly
    dotted) source is resolved with ``_get_source_value``. ``None`` is
    returned when that resolution raises ``ObjectDoesNotExist`` or yields
    ``None``. When ``self.many`` is set the value is serialized item by
    item; otherwise it is passed to ``to_native`` as a whole.
    """
    try:
        if self.source == '*':
            return self.to_native(obj)
        value = self._get_source_value(obj, field_name)
    except ObjectDoesNotExist:
        return None

    if value is None:
        return None
    if not self.many:
        return self.to_native(value)

    # Querysets expose a no-argument ``all()``. Any other iterable (such as
    # a plain list of related items) is iterated as-is.
    if serializers.is_simple_callable(getattr(value, 'all', None)):
        items = value.all()
    else:
        items = value
    return [self.to_native(item) for item in items]


serializers.RelatedField.field_to_native = RelatedField_field_to_native
# | |
# Changes to rest_framework/serializers.py | |
# | |
def ModelSerializer_restore_object(self, attrs, instance=None):
    """
    Restore the model instance from the deserialized ``attrs`` dict.

    Relations that cannot be assigned before the instance is saved are
    popped out of ``attrs`` (so ``attrs`` is mutated) and stashed on private
    attributes of the returned instance, for ``save_object`` to apply:

    * ``_related_data``: reverse FK / one-to-one relations;
    * ``_m2m_data``: reverse and forward many-to-many relations;
    * ``_nested_forward_relations``: values of fields that are nested
      serializers, which must be saved before the parent.

    With an existing ``instance``, every remaining attribute is written via
    ``_set_source_value`` (so dotted keys reach related objects). Otherwise
    a new model instance is built from the remaining ``attrs``.
    """
    opts = self.opts.model._meta
    related_data = {}
    m2m_data = {}

    def _pop_present(names, into):
        # Move every name that is present in ``attrs`` over into ``into``.
        for name in names:
            if name in attrs:
                into[name] = attrs.pop(name)

    # Reverse fk or one-to-one relations
    _pop_present(
        (rel.field.related_query_name()
         for rel, _ in opts.get_all_related_objects_with_model()),
        related_data)
    # Reverse m2m relations
    _pop_present(
        (rel.get_accessor_name()
         for rel, _ in opts.get_all_related_m2m_objects_with_model()),
        m2m_data)
    # Forward m2m relations
    _pop_present((field.name for field in opts.many_to_many), m2m_data)

    # Nested forward relations need to be saved before the parent model
    # instance, so remember which attrs belong to nested serializer fields.
    nested_forward_relations = dict(
        (name, attrs[name])
        for name in attrs
        if isinstance(self.fields.get(name, None), serializers.Serializer)
    )

    if instance is not None:
        # Update the existing instance in place.
        for key, val in attrs.items():
            self._set_source_value(instance, key, val)
    else:
        # NOTE(review): a dotted key such as 'user.first_name' is passed
        # straight to the model constructor here, which presumably rejects
        # it; dotted sources look usable only on update. Confirm.
        instance = self.opts.model(**attrs)

    # Relations that cannot be set until the model has been saved are
    # hidden away on these private attributes and handled in save_object.
    instance._related_data = related_data
    instance._m2m_data = m2m_data
    instance._nested_forward_relations = nested_forward_relations
    return instance


serializers.ModelSerializer.restore_object = ModelSerializer_restore_object
def ModelSerializer_save_object(self, obj, **kwargs):
    """
    Save the deserialized object together with its pending relations.

    Order of operations:

    1. Objects reached through dotted sources (recorded on
       ``obj._traversed_objects`` by ``_set_source_value``, deepest first)
       are saved and re-assigned on the object that owns them.
    2. Nested forward relations are saved and assigned on ``obj``.
    3. ``obj.save(**kwargs)`` is called.
    4. Many-to-many data and reverse relations stashed by
       ``restore_object`` are applied. Nested reverse FK / one-to-one
       objects are pointed at ``obj`` and saved, and items listed in a
       nested reverse FK list's ``_deleted`` are deleted.

    ``_traversed_objects``, ``_m2m_data`` and ``_related_data`` are removed
    from ``obj`` once consumed.
    """
    traversed_objects = getattr(obj, '_traversed_objects', None)
    if traversed_objects:
        # _set_source_value records every dotted prefix of a path, so the
        # owner of a nested accessor (e.g. 'user' for 'user.profile') is
        # always present in this mapping.
        objects_by_accessor = dict(traversed_objects)
        for accessor_name, sub_object in traversed_objects:
            if sub_object:
                self.save_object(sub_object)
            # Assign on the object owning the last component. A dotted name
            # passed straight to setattr(obj, ...) would only create a
            # bogus attribute such as 'user.profile' on obj.
            owner_path, _, attr_name = accessor_name.rpartition('.')
            owner = objects_by_accessor[owner_path] if owner_path else obj
            setattr(owner, attr_name, sub_object)
        # Consumed; clear it so a later save doesn't reuse stale entries.
        del obj._traversed_objects

    nested_forward_relations = getattr(obj, '_nested_forward_relations', None)
    if nested_forward_relations:
        # Nested relationships need to be saved before we can save the
        # parent instance.
        for field_name, sub_object in nested_forward_relations.items():
            if sub_object:
                self.save_object(sub_object)
            setattr(obj, field_name, sub_object)

    obj.save(**kwargs)

    m2m_data = getattr(obj, '_m2m_data', None)
    if m2m_data:
        for accessor_name, object_list in m2m_data.items():
            setattr(obj, accessor_name, object_list)
        del obj._m2m_data

    related_data = getattr(obj, '_related_data', None)
    if related_data:
        for accessor_name, related in related_data.items():
            if isinstance(related, serializers.RelationsList):
                # Nested reverse fk relationship
                for related_item in related:
                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
                    setattr(related_item, fk_field, obj)
                    self.save_object(related_item)
                # Delete any removed objects
                if related._deleted:
                    for removed_item in related._deleted:
                        self.delete_object(removed_item)
            elif isinstance(related, serializers.models.Model):
                # Nested reverse one-one relationship
                fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
                setattr(related, fk_field, obj)
                self.save_object(related)
            else:
                # Reverse FK or reverse one-one
                setattr(obj, accessor_name, related)
        del obj._related_data


serializers.ModelSerializer.save_object = ModelSerializer_save_object
def BaseSerializer_restore_object(self, attrs, instance=None):
    """
    Turn a dictionary of deserialized attributes into an object instance.

    When ``instance`` is given, each attribute is written onto it through
    ``_set_source_value`` (so dotted keys reach nested objects) and the
    instance is returned. Without an instance, ``attrs`` itself is returned.
    Override this method to control how deserialized objects are
    instantiated.
    """
    if instance is None:
        return attrs
    for key, val in attrs.items():
        self._set_source_value(instance, key, val)
    return instance


serializers.BaseSerializer.restore_object = BaseSerializer_restore_object
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment