Last active
March 3, 2025 20:00
-
-
Save crucialfelix/7fa53265ed11e6761531f1b2e0d1f36a to your computer and use it in GitHub Desktop.
Django BulkSave - batches inserts, updates, deletes and many-to-many (m2m) changes into the minimum number of queries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import hashlib
import logging
from collections import defaultdict
from decimal import Decimal

from django.db.models import DecimalField, ForeignKey
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class BulkSave(object):
    """
    Collect inserts, updates, deletions and many-to-many assignments and
    perform them with as few SQL queries as possible when save() is called.

    Usage::

        with bulk_saver() as bs:
            # queued objects are written with bulk_create
            bs.add_insert(NewModel(value=1))
            bs.add_insert(NewModel(value=2))

            nm = NewModel(value=3)
            bs.add_insert(nm)
            # set m2m fields even though the model has not yet been saved
            bs.set_m2m(nm, 'owners', [user1, user2])
            # equivalent to:
            # nm.owners = [user1, user2]

            # update
            update_me = NewModel.objects.get(id=2)
            with bs.changing(update_me):
                update_me.value = 2
    """

    # Non-Decimal types accepted for a DecimalField value. ``type(u"")`` is
    # ``unicode`` on Python 2 and ``str`` on Python 3, so no NameError either way.
    _NUMBER_LIKE = (int, float, str, type(u""))

    def __init__(self):
        # {klass: {pk: {attname: value}}} taken by take_snapshot()
        self.snapshots = defaultdict(dict)
        # {klass: {pk: {field_name: new_value}}}
        self.updates = defaultdict(dict)
        # {klass: [model, ...]}
        self.inserts = defaultdict(list)
        # {klass: [model, ...]}
        self.deletes = defaultdict(list)
        # {klass: {key: {m2m_attname: [related_pk, ...]}}}, key from _m2m_key()
        self.m2m = defaultdict(dict)
        # {klass: {key: model}}, same keys as self.m2m
        self.m2m_models = defaultdict(dict)
        self.saved = False

    @staticmethod
    def _m2m_key(model):
        """
        Return the key under which a model's m2m assignments are stored.

        Saved models are keyed by pk. Unsaved models all have pk None, so they
        are keyed by object identity instead; otherwise every unsaved model of
        a class would overwrite the previous one.
        """
        if model.pk is not None:
            return model.pk
        return ("unsaved", id(model))

    def _comparable_value(self, model, field, val):
        """
        Return val in the form used to compare against a snapshot value.

        Truthy, non-Decimal values of a DecimalField are converted to a
        Decimal rounded to field.decimal_places. Everything else is returned
        unchanged.

        Raises Exception if the value is not an int, float or string.
        """
        if not val or not isinstance(field, DecimalField):
            return val
        if isinstance(val, Decimal):
            return val
        if not isinstance(val, self._NUMBER_LIKE):
            raise Exception("%s %s is not a number, is: %s"
                            % (model, val, type(val)))
        return Decimal(str(round(float(val), field.decimal_places)))

    def take_snapshot(self, model):
        """
        Take a snapshot of all field values on a model,
        prior to possibly setting some of those fields.

        Afterwards call add_changed_fields with the model
        to store any changes to be committed when the BulkSave completes.
        """
        klass = model.__class__
        self.snapshots[klass][model.pk] = {
            field.get_attname(): getattr(model, field.get_attname())
            for field in klass._meta.fields
        }

    def add_changed_fields(self, model):
        """
        Having taken a snapshot of model previously, add any changed
        fields to the scheduled updates and discard the snapshot.

        Values are compared by attname (so foreign keys compare by id), but
        the update is recorded under field.name with the value of that
        attribute (for a foreign key, the related object).

        Raises Exception if no snapshot exists for the model.
        """
        klass = model.__class__
        changed = {}
        for field in klass._meta.fields:
            atn = field.get_attname()
            try:
                prev = self.snapshots[klass][model.pk][atn]
            except KeyError:
                raise Exception("No snapshot found for %s pk=%s attribute=%s" %
                                (klass, model.pk, atn))
            val = self._comparable_value(model, field, getattr(model, atn))
            if prev != val:
                changed[field.name] = getattr(model, field.name)
        if changed:
            self.set(model, changed)
        del self.snapshots[klass][model.pk]

    def has_changed(self, model):
        """
        Having previously called take_snapshot(model),
        determine if any fields have been changed since then.

        Raises KeyError if no snapshot exists for the model.
        """
        klass = model.__class__
        for field in klass._meta.fields:
            atn = field.get_attname()
            prev = self.snapshots[klass][model.pk][atn]
            val = self._comparable_value(model, field, getattr(model, atn))
            if prev != val:
                return True
        return False

    def add_insert(self, model):
        """
        Add an unsaved model (with no pk) to be inserted.
        """
        self.inserts[model.__class__].append(model)

    def add_delete(self, model):
        """
        Add a model to be deleted.
        """
        self.deletes[model.__class__].append(model)

    def set(self, model, qwargs):
        """
        Schedule field updates for a model using a dict of
        {field_name: value}. Repeated calls for the same model are merged.
        """
        self.updates[model.__class__].setdefault(model.pk, {}).update(qwargs)

    def set_m2m(self, model, attname, objects):
        """
        Set many-to-many objects for a model.

        Equivalent to `model.{attname} = objects`
        but performed in bulk when the BulkSave is saved, with
        one query to check for the current state (exists),
        one for all deletes
        and one for all inserts.

        `objects` may contain model instances or integer pks, but not a mix
        of types. Raises AssertionError if mixed types are supplied.
        """
        klass = model.__class__
        # Materialise once: `objects` is iterated twice below and may be
        # a generator.
        objects = list(objects)
        if len({type(obj) for obj in objects}) > 1:
            raise AssertionError(
                "Mixed models supplied to set_m2m: {} {}.{} = {}".format(
                    model, type(model), attname, objects))
        key = self._m2m_key(model)
        self.m2m[klass].setdefault(key, {})[attname] = [
            obj if isinstance(obj, int) else obj.pk
            for obj in objects
        ]
        self.m2m_models[klass][key] = model

    @contextlib.contextmanager
    def changing(self, obj):
        """
        Set fields on an object regardless of whether
        it is updating or inserting the object.

        usage::

            with bulk_save.changing(model):
                model.value = 1

        If a pk exists (updating the model)
        then it does take_snapshot then add_changed_fields.
        If no pk (creating) then it does add_insert.
        """
        if obj.pk is None:
            yield
            self.add_insert(obj)
        else:
            self.take_snapshot(obj)
            yield
            self.add_changed_fields(obj)

    def save(self):
        """
        Perform all inserts, updates, deletes and m2m changes, in that order.

        Raises Exception if this BulkSave has already been saved.
        """
        if self.saved:
            raise Exception("BulkSave has already saved")
        self.save_inserts()
        self.save_updates()
        self.save_deletes()
        self.save_m2m()
        self.saved = True

    def save_inserts(self):
        """
        Insert all queued models, one bulk_create call per model class.
        """
        for klass, models in self.inserts.items():
            self.save_inserts_for_model(klass, models)

    def save_inserts_for_model(self, klass, models):
        """
        Insert `models` of class `klass` with bulk_create in batches of 100.

        Raises Exception (wrapping the original error) if the insert fails.
        """
        fk_fields = [
            field for field in klass._meta.fields
            if isinstance(field, ForeignKey)
        ]
        for model in models:
            for field in fk_fields:
                # If the foreign key was assigned an object that was unsaved
                # at the time but has been saved since, {fk}_id is still None,
                # so copy the pk over now.
                atn = field.get_attname()
                if getattr(model, atn) is not None:
                    continue
                related = getattr(model, field.get_cache_name(), None)
                if related:
                    setattr(model, atn, related.pk)
        try:
            klass.objects.bulk_create(models, 100)
        except Exception as e:
            # an IntegrityError or something:
            # report what the model and db error message was
            raise Exception("%r while saving models: %s" % (e, klass))

    def save_updates(self):
        """
        Batch updates where possible: all rows of a class that receive
        exactly the same field values share one UPDATE statement.

        [0.030] UPDATE "nsproperties_apt" SET "available_on" = '2017-07-02'::date WHERE "nsproperties_apt"."id" = 704702
        [0.017] UPDATE "nsproperties_apt" SET "available_on" = '2017-07-05'::date WHERE "nsproperties_apt"."id" IN (704687, 704696)
        """
        for klass, models in self.updates.items():
            # {hash: [(qwargs, [pk, ...]), ...]}
            # The hash only narrows the search. Different dicts can share a
            # hash (e.g. hash(-1) == hash(-2)), so equality is checked before
            # rows are merged into one UPDATE.
            buckets = defaultdict(list)
            for pk, qwargs in models.items():
                bucket = buckets[dict_hash(qwargs)]
                for batch_qwargs, batch_pks in bucket:
                    if batch_qwargs == qwargs:
                        batch_pks.append(pk)
                        break
                else:
                    bucket.append((qwargs, [pk]))
            for bucket in buckets.values():
                for qwargs, pks in bucket:
                    pkwargs = dict(pk__in=pks) if len(pks) > 1 else dict(pk=pks[0])
                    klass.objects.filter(**pkwargs).update(**qwargs)

    def save_deletes(self):
        """
        Delete all queued models with one DELETE query per model class.
        """
        for klass, models in self.deletes.items():
            klass.objects.filter(
                pk__in=[model.pk for model in models]).delete()

    def save_m2m(self):
        """
        Apply all queued many-to-many assignments.

        self.m2m::

            {
                klass: {
                    key: {
                        m2m_attr: [id, id, ...]
                    }
                }
            }

        where key identifies the model (see _m2m_key) and
        self.m2m_models[klass][key] is the model itself.
        """
        for klass in list(self.m2m.keys()):
            self.save_m2m_for_model(klass)

    def save_m2m_for_model(self, klass):
        """
        Apply the queued m2m assignments for every model of `klass`,
        one m2m field at a time.

        Raises Exception if any of the models still has no pk.
        """
        models_fields_ids = self.m2m[klass]
        # {m2m_attname: set(models whose current joins must be looked up)}
        fields_models_to_lookup = defaultdict(set)
        for key, fields_ids in models_fields_ids.items():
            model = self.m2m_models[klass][key]
            if model.pk is None:
                raise Exception("No pk for model %s. cannot save m2m %s"
                                % (model, fields_ids))
            for field in fields_ids:
                fields_models_to_lookup[field].add(model)

        for field, models_to_lookup in fields_models_to_lookup.items():
            self.save_m2m_for_field(
                klass, field, models_to_lookup, models_fields_ids)

    def save_m2m_for_field(self, klass, field, models_to_lookup, models_fields_ids):
        """
        Bring the join table of m2m field `field` in line with the queued
        assignments: one query to read the existing joins, one to delete
        the surplus joins and one bulk_create for the missing ones.

        Raises AssertionError if a join would be created with a null id.
        """
        ff = klass._meta.get_field(field)
        # `rel` was renamed `remote_field` in Django 1.9 and removed in 2.0.
        rel = getattr(ff, "remote_field", None) or ff.rel
        # e.g. apt_contacts
        join_model = rel.through
        join_objects = join_model.objects
        # e.g. apt__in
        filter_in = "%s__in" % ff.m2m_field_name()
        # names on the join object, e.g. apt_id and contact_id
        getr = ff.m2m_column_name()
        othr = ff.m2m_reverse_name()

        # {model_pk: set(related_pk)}
        existing = defaultdict(set)
        # {(model_pk, related_pk): join}
        joins = {}
        for join in join_objects.filter(**{filter_in: models_to_lookup}):
            one_id = getattr(join, getr)
            two_id = getattr(join, othr)
            existing[one_id].add(two_id)
            joins[(one_id, two_id)] = join

        # Compare existing joins with what should exist
        join_pks_to_delete = []
        joins_to_add = []
        for key, fields_ids in models_fields_ids.items():
            if field not in fields_ids:
                continue
            model = self.m2m_models[klass][key]
            current = existing.get(model.pk, set())
            shoulds = set(fields_ids[field])
            for related_pk in current - shoulds:
                join_pks_to_delete.append(joins[(model.pk, related_pk)].pk)
            for related_pk in shoulds - current:
                join_params = {getr: model.pk, othr: related_pk}
                if not (related_pk and model.pk):
                    raise AssertionError("null id for join: %s %s"
                                         % (join_model, join_params))
                joins_to_add.append(join_model(**join_params))

        if join_pks_to_delete:
            join_objects.filter(pk__in=join_pks_to_delete).delete()
        if joins_to_add:
            join_model.objects.bulk_create(joins_to_add, 500)
@contextlib.contextmanager
def bulk_saver(maybe=None):
    """
    Context manager to perform a bulk save operation.

    If no parent BulkSave is passed in then this creates one, runs the code
    inside the context and calls save() when the context exits without error.

    If a parent is passed in, it is yielded unchanged and NOT saved here;
    saving is left to whoever created it.

    If the body raises, the exception propagates and nothing is saved.
    """
    saver = maybe or BulkSave()
    yield saver
    # Only reached when the body did not raise: an exception is re-raised
    # at the yield above and skips the save.
    if maybe is None:
        saver.save()
def dict_hash(qwargs):
    """
    Return a hashable key for grouping equal dictionaries together.

    Equal dictionaries with hashable values always produce the same key.
    The key is NOT guaranteed to be unique: different dictionaries can share
    a key (for example hash(-1) == hash(-2)), so callers that must not mix
    up different dictionaries have to compare them for equality as well.

    Unhashable values (such as the nested dictionaries passed when saving to
    a PickleField or JSONField) cannot go into a frozenset, so it falls back
    to the SHA-1 hex digest of the repr of the items, sorted by key so that
    insertion order does not matter at the top level.
    """
    try:
        return hash(frozenset(qwargs.items()))
    except TypeError:
        # Sort by repr of the key so that mixed key types cannot raise.
        ordered = sorted(qwargs.items(), key=lambda item: repr(item[0]))
        return hashlib.sha1(repr(ordered).encode("utf-8")).hexdigest()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You shouldn't call .save() at all
The example above is this:
so for your case: