Last active
December 18, 2015 15:09
-
-
Save dcramer/5802587 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EverythingCollector(Collector):
    """
    More or less identical to the default Django collector except we always
    return relations (even when they shouldn't matter).
    """

    def collect(self, objs, source=None, nullable=False, collect_related=True,
                source_attr=None, reverse_dependency=False):
        """
        Add ``objs`` to the collection, together with their parent-model
        instances and (when ``collect_related`` is true) the objects that
        relate to them.

        ``source``, ``nullable``, ``source_attr`` and ``reverse_dependency``
        are accepted so the signature matches the recursive calls below; this
        override does not forward them to ``self.add``.

        Returns ``None``; results accumulate on the collector itself.
        """
        new_objs = self.add(objs)
        if not new_objs:
            # Nothing new was added, so there is nothing further to walk.
            return

        model = new_objs[0].__class__

        # Recursively collect concrete model's parent models, but not their
        # related objects. These will be found by meta.get_all_related_objects()
        concrete_model = model._meta.concrete_model
        # Iterate the mapping's values, not its (key, value) pairs: the body
        # reads ``ptr.name`` and ``ptr.rel``, which a tuple does not provide.
        for ptr in concrete_model._meta.parents.values():
            if ptr:
                # FIXME: This seems to be buggy and execute a query for each
                # parent object fetch. We have the parent data in the obj,
                # but we don't have a nice way to turn that data into parent
                # object instance.
                parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
                self.collect(parent_objs, source=model,
                             source_attr=ptr.rel.related_name,
                             collect_related=False,
                             reverse_dependency=True)

        if collect_related:
            for related in model._meta.get_all_related_objects(
                    include_hidden=True, include_proxy_eq=True):
                # Related rows are only added, not recursed into.
                sub_objs = self.related_objects(related, new_objs)
                self.add(sub_objs)

            # TODO This entire block is only needed as a special case to
            # support cascade-deletes for GenericRelation. It should be
            # removed/fixed when the ORM gains a proper abstraction for virtual
            # or composite fields, and GFKs are reworked to fit into that.
            for relation in model._meta.many_to_many:
                if not relation.rel.through:
                    sub_objs = relation.bulk_related_objects(new_objs, self.using)
                    self.collect(sub_objs,
                                 source=model,
                                 source_attr=relation.rel.related_name,
                                 nullable=True)
def merge_into(self, other, callback=lambda x: x, using='default'):
    """
    Collects objects related to ``self`` and updates their foreign keys to
    point to ``other``.

    If ``callback`` is specified, it will be executed on each collected chunk
    before any changes are made, and should return a modified list of results
    that still need to be updated.

    NOTE: Duplicates (unique constraints) which exist and are bound to ``other``
    are preserved, and relations on ``self`` are discarded.

    Raises ``TypeError`` if a collected model (other than ``type(self)``) has
    no ``ForeignKey`` to ``type(self)``, and ``ValueError`` if a collected row
    (other than ``self``) has no such key currently pointing at ``self``.
    """
    # Local import keeps this function self-contained; the debug output used
    # to be a bare ``print`` to stdout.
    import logging

    logger = logging.getLogger(__name__)

    # TODO: proper support for database routing
    s_model = type(self)

    # Find all the objects that hold a relation to ``self``.
    collector = EverythingCollector(using=using)
    collector.collect([self])

    for model, objects in collector.data.items():
        # find all potential keys which match our type
        fields = set(
            f.name for f in model._meta.fields
            if isinstance(f, ForeignKey) and f.rel.to == s_model
        )
        logger.debug('merge_into: model=%r objects=%r fields=%r',
                     model, objects, fields)

        if not fields:
            # the collector pulls in the self reference, so if it's our model
            # we actually assume it's probably not related to itself, and it's
            # perfectly ok
            if model == s_model:
                continue
            raise TypeError('Unable to determine related keys on %r' % model)

        # Let the caller filter the chunk before anything is modified, as the
        # docstring promises; the default callback returns it unchanged.
        objects = callback(objects)

        # Loop-invariant: auto-created (implicit through) models get no signals.
        send_signals = not model._meta.auto_created
        # Route the writes to the same database as the collector and savepoints.
        rows = model.objects.using(using)

        for obj in objects:
            # find fields which need changed
            update_kwargs = {}
            for f_name in fields:
                if getattr(obj, f_name) == self:
                    update_kwargs[f_name] = other

            if not update_kwargs:
                # as before, if we're referencing ourself, this is ok
                if obj == self:
                    continue
                raise ValueError('Mismatched row present in related results')

            signal_kwargs = {
                'sender': model,
                'instance': obj,
                'using': using,
                'migrated': True,
            }

            # The old relation is reported as deleted before the row is
            # re-pointed and reported as (re)created.
            if send_signals:
                pre_delete.send(**signal_kwargs)
                post_delete.send(**signal_kwargs)

            for k, v in update_kwargs.items():
                setattr(obj, k, v)

            if send_signals:
                pre_save.send(created=True, **signal_kwargs)

            sid = transaction.savepoint(using=using)
            try:
                rows.filter(pk=obj.pk).update(**update_kwargs)
            except IntegrityError:
                # duplicate key exists, destroy the relations
                transaction.savepoint_rollback(sid, using=using)
                rows.filter(pk=obj.pk).delete()
            else:
                transaction.savepoint_commit(sid, using=using)

            # NOTE(review): post_save is also sent when the row was deleted in
            # the IntegrityError branch above -- confirm receivers expect that.
            if send_signals:
                post_save.send(created=True, **signal_kwargs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MergeIntoTest(TestCase):
    """Exercises ``merge_into`` across owned teams, projects and group membership."""

    def test_all_the_things(self):
        """Merge the 'original' user into the 'new' one and check every relation."""
        source_user = User.objects.create(username='original')
        target_user = User.objects.create(username='new')

        source_team = Team.objects.create(owner=source_user)
        target_team = Team.objects.create(owner=target_user)

        source_project = Project.objects.create(owner=source_user, team=source_team)
        target_project = Project.objects.create(owner=target_user, team=target_team)

        # Both users belong to the same group, so the merge hits a duplicate.
        group = AccessGroup.objects.create(team=target_team)
        for member in (source_user, target_user):
            group.members.add(member)

        merge_into(source_user, target_user)

        # Every team and project must now be owned by the merge target.
        for team in (source_team, target_team):
            assert Team.objects.get(id=team.id).owner == target_user
        for project in (source_project, target_project):
            assert Project.objects.get(id=project.id).owner == target_user

        assert list(group.members.all()) == [target_user]

        # make sure we didn't remove the instance
        assert User.objects.filter(id=source_user.id).exists()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment