Skip to content

Instantly share code, notes, and snippets.

@crucialfelix
Last active March 3, 2025 20:00
Show Gist options
  • Save crucialfelix/7fa53265ed11e6761531f1b2e0d1f36a to your computer and use it in GitHub Desktop.
Save crucialfelix/7fa53265ed11e6761531f1b2e0d1f36a to your computer and use it in GitHub Desktop.
Django BulkSave - batches insert, updates, deletes and m2m into the minimum number of queries
import contextlib
import hashlib
import logging
from collections import defaultdict
from decimal import Decimal
from django.db.models import DecimalField, ForeignKey
log = logging.getLogger(__name__)
class BulkSave(object):
    """
    Batches inserts, deletions and updates together to perform the minimum number of
    SQL queries.

    This can efficiently set ManyToMany relationships on many objects.

    Batched inserts, batched updates, batched deletions.

    Large numbers of operations can be stored in this BulkSave and then performed when
    save() is called.

    Usage::

        with bulk_saver() as bs:
            # add objects. Will use a single insert statement
            bs.add_insert(NewModel(value=1))
            bs.add_insert(NewModel(value=2))

            nm = NewModel(value=3)
            bs.add_insert(nm)

            # set m2m fields even though the model has not yet been saved
            bs.set_m2m(nm, 'owners', [user1, user2])
            # equivalent to:
            # nm.owners = [user1, user2]

            # update
            update_me = NewModel.objects.get(id=2)
            with bs.changing(update_me):
                update_me.value = 2

    Many-to-many rows are written last. Every model involved must have a
    primary key by then, so models queued with add_insert() only work with
    set_m2m() when ``bulk_create`` populates primary keys on the instances.
    """

    # Batch sizes handed to ``bulk_create``. Override on a subclass or an
    # instance to tune them.
    insert_batch_size = 100
    m2m_batch_size = 500

    def __init__(self):
        # {klass: {pk: {attname: value}}} as captured by take_snapshot()
        self.snapshots = defaultdict(dict)
        # {klass: {pk: {field_name: new_value}}}
        self.updates = defaultdict(dict)
        # {klass: [model, ...]}
        self.inserts = defaultdict(list)
        # {klass: [model, ...]}
        self.deletes = defaultdict(list)
        # {klass: {key: {m2m_attr: [related object or raw pk, ...]}}}
        self.m2m = defaultdict(dict)
        # {klass: {key: model}}, same keys as self.m2m
        self.m2m_models = defaultdict(dict)
        self.saved = False

    def take_snapshot(self, model):
        """
        Take a snapshot of all field values on a model,
        prior to possibly setting some of those fields.

        Afterwards call add_changed_fields with the model
        to store any changes to be committed when the BulkSave completes.
        """
        klass = model.__class__
        self.snapshots[klass][model.pk] = {
            field.get_attname(): getattr(model, field.get_attname())
            for field in klass._meta.fields
        }

    @staticmethod
    def _as_decimal(model, val, decimal_places):
        """
        Convert a numeric value to a Decimal rounded to `decimal_places`.

        `model` is only used in the error message.
        Raises Exception if `val` is not an int, float or str.
        """
        if not isinstance(val, (int, float, str)):
            raise Exception("%s %s is not a number, is: %s"
                            % (model, val, type(val)))
        return Decimal(str(round(float(val), decimal_places)))

    def _current_value(self, model, field):
        """
        Return the model's current value for `field`, normalised so that it
        can be compared with the snapshot value.

        Truthy non-Decimal values on a DecimalField are converted to Decimal
        first, so that e.g. a float assigned in Python is not reported as a
        change when it rounds to the stored Decimal.
        """
        val = getattr(model, field.get_attname())
        if val and isinstance(field, DecimalField) and not isinstance(val, Decimal):
            val = self._as_decimal(model, val, field.decimal_places)
        return val

    def add_changed_fields(self, model):
        """
        Having taken a snapshot of model
        previously, add any changed fields
        to the scheduled updates.

        Fields are compared by attname (so a foreign key is compared by id),
        but the update is stored under the field name using the value of
        ``getattr(model, field.name)`` (so a foreign key is set with the object).

        The snapshot is discarded afterwards.
        Raises Exception if no snapshot exists for the model.
        """
        klass = model.__class__
        qwargs = {}
        for field in klass._meta.fields:
            atn = field.get_attname()
            try:
                prev = self.snapshots[klass][model.pk][atn]
            except KeyError:
                raise Exception("No snapshot found for %s pk=%s attribute=%s" %
                                (klass, model.pk, atn))
            if prev != self._current_value(model, field):
                qwargs[field.name] = getattr(model, field.name)
        if qwargs:
            self.set(model, qwargs)
        del self.snapshots[klass][model.pk]

    def has_changed(self, model):
        """
        Having previously called take_snapshot(model),
        determine if any fields have been changed since then.

        Uses the same value normalisation as add_changed_fields.
        Raises KeyError if there is no snapshot for the model.
        """
        klass = model.__class__
        snapshot = self.snapshots[klass][model.pk]
        for field in klass._meta.fields:
            if snapshot[field.get_attname()] != self._current_value(model, field):
                return True
        return False

    def add_insert(self, model):
        """
        Add an unsaved model (with no pk) to be inserted.
        """
        self.inserts[model.__class__].append(model)

    def add_delete(self, model):
        """
        Add a model to be deleted.
        """
        self.deletes[model.__class__].append(model)

    def set(self, model, qwargs):
        """
        Set fields update for a model using a dict.

        Repeated calls for the same model are merged; later values win.
        """
        self.updates[model.__class__].setdefault(model.pk, {}).update(qwargs)

    def _m2m_key(self, model):
        """
        Return the key under which `model` is tracked in self.m2m.

        Saved models are keyed by pk. Unsaved models all share pk None, so
        they are keyed by object identity instead; a model first registered
        while unsaved keeps that key even if it receives a pk later.
        """
        unsaved_key = ("unsaved", id(model))
        if model.pk is None or unsaved_key in self.m2m[model.__class__]:
            return unsaved_key
        return model.pk

    def set_m2m(self, model, attname, objects):
        """
        Set many-to-many objects for a model.

        Equivalent to `model.{attname} = objects`

        But it will do this in bulk with
        one query to check for the current state (exists)
        one for all deletes
        and one for all inserts.

        `objects` may contain model instances or raw pk values. Instances
        are resolved to their pk only when the m2m rows are saved, so they
        may still be unsaved at this point.

        Raises ValueError if `objects` mixes different types.
        """
        # materialise once: `objects` may be a generator and is read twice
        objects = list(objects)
        if len(set(type(obj) for obj in objects)) > 1:
            raise ValueError(
                "Mixed models supplied to set_m2m: {} {}.{} = {}".format(
                    model, type(model), attname, objects))
        klass = model.__class__
        key = self._m2m_key(model)
        self.m2m[klass].setdefault(key, {})[attname] = objects
        # keeping the model referenced also keeps id(model) stable for the key
        self.m2m_models[klass][key] = model

    @contextlib.contextmanager
    def changing(self, obj):
        """
        Set fields on an object regardless of whether
        it is updating or inserting the object.

        usage::

            with bulk_save.changing(model):
                model.value = 1

        If a pk exists (updating the model)
        then it does snapshot then add_changed_fields.

        If no pk (creating) then it does add_insert.

        If the body raises, nothing is scheduled.
        """
        creating = obj.pk is None
        if not creating:
            self.take_snapshot(obj)
        yield
        if creating:
            self.add_insert(obj)
        else:
            self.add_changed_fields(obj)

    def save(self):
        """
        Perform all updates/inserts/deletes and m2m changes.

        Raises Exception if this BulkSave has already been saved.
        """
        if self.saved:
            raise Exception("BulkSave has already saved")
        self.save_inserts()
        self.save_updates()
        self.save_deletes()
        # last, so that models inserted above have their pks
        self.save_m2m()
        self.saved = True

    def save_inserts(self):
        """Bulk insert all queued models, one model class at a time."""
        for klass, models in self.inserts.items():
            self.save_inserts_for_model(klass, models)

    def save_inserts_for_model(self, klass, models):
        """
        Bulk insert `models`, which must all be instances of `klass`.

        Raises Exception (wrapping the original error) if the insert fails.
        """
        fk_fields = [field for field in klass._meta.fields
                     if isinstance(field, ForeignKey)]
        for model in models:
            for field in fk_fields:
                # if the foreign key field points to an object that was
                # unsaved when assigned but has been saved since,
                # then it still has no {fk}_id set
                # so set it now with that id
                atn = field.get_attname()
                if getattr(model, atn) is None:
                    fk = getattr(model, field.get_cache_name(), None)
                    if fk:
                        setattr(model, atn, fk.pk)
        try:
            klass.objects.bulk_create(models, self.insert_batch_size)
        except Exception as e:
            # an IntegrityError or something
            # report what the model and db error message was
            raise Exception("%r while saving models: %s" % (e, klass))

    def save_updates(self):
        """
        Batch updates where possible: models of the same class that receive
        identical field values share a single UPDATE.

        [0.030] UPDATE "nsproperties_apt" SET "available_on" = '2017-07-02'::date WHERE "nsproperties_apt"."id" = 704702
        [0.017] UPDATE "nsproperties_apt" SET "available_on" = '2017-07-05'::date WHERE "nsproperties_apt"."id" IN (704687, 704696)
        [0.023] UPDATE "nsproperties_apt" SET "available_on" = '2017-07-06'::date WHERE "nsproperties_apt"."id" IN (704683, 704691, 704692, 704693, 704694, 704697, 704698, 704704)
        """
        for klass, pk_updates in self.updates.items():
            # {dict_hash key: [(qwargs, [pk, ...]), ...]}
            # The key only narrows the search. Membership in a batch is
            # decided by dict equality, so two different updates that happen
            # to share a key are never merged.
            buckets = defaultdict(list)
            for pk, qwargs in pk_updates.items():
                candidates = buckets[dict_hash(qwargs)]
                for batch_qwargs, pks in candidates:
                    if batch_qwargs == qwargs:
                        pks.append(pk)
                        break
                else:
                    candidates.append((qwargs, [pk]))
            for candidates in buckets.values():
                for qwargs, pks in candidates:
                    pkwargs = dict(pk__in=pks) if len(pks) > 1 else dict(pk=pks[0])
                    klass.objects.filter(**pkwargs).update(**qwargs)

    def save_deletes(self):
        """Delete all queued models with one query per model class."""
        for klass, models in self.deletes.items():
            klass.objects.filter(
                pk__in=[model.pk for model in models]).delete()

    def save_m2m(self):
        """
        Write all many-to-many changes queued with set_m2m.

        self.m2m::

            {
                klass: {
                    key: {
                        m2m_attr: [related object or pk, ...]
                    }
                }
            }

        where `key` identifies the owning model (see self.m2m_models).
        """
        for klass in list(self.m2m):
            self.save_m2m_for_model(klass)

    def save_m2m_for_model(self, klass):
        """
        Write the queued many-to-many changes for every model of `klass`.

        Raises Exception if any of those models still has no pk.
        """
        models_fields_ids = self.m2m[klass]
        # {m2m_attr: set of models whose current joins must be looked up}
        fields_models_to_lookup = defaultdict(set)
        for key, fields_ids in models_fields_ids.items():
            model = self.m2m_models[klass][key]
            if model.pk is None:
                raise Exception("No pk for model %s. cannot save m2m %s"
                                % (model, fields_ids))
            for field in fields_ids:
                fields_models_to_lookup[field].add(model)

        for field, models_to_lookup in fields_models_to_lookup.items():
            self.save_m2m_for_field(
                klass, field, models_to_lookup, models_fields_ids)

    def save_m2m_for_field(self, klass, field, models_to_lookup, models_fields_ids):
        """
        Bring the join table of one many-to-many field in line with what was
        requested through set_m2m.

        One query reads the existing joins for `models_to_lookup`, then at
        most one delete and one bulk insert are issued.

        klass: the model class owning the field
        field: name of the many-to-many field
        models_to_lookup: models whose existing joins should be fetched
        models_fields_ids: {key: {m2m_attr: [related object or pk, ...]}}

        Raises Exception if a join would be created with a null id.
        """
        ff = klass._meta.get_field(field)
        # Field.rel was renamed to Field.remote_field and later removed;
        # prefer the new name and fall back for old Django versions.
        relation = getattr(ff, "remote_field", None) or ff.rel
        # apt_contacts
        join_model = relation.through
        join_objects = join_model.objects
        # field names on the join object
        # apt_id
        getr = ff.m2m_column_name()
        # contact_id
        othr = ff.m2m_reverse_name()

        # find existing joins
        # apt__in
        qwargs = {"%s__in" % ff.m2m_field_name(): list(models_to_lookup)}
        existing = defaultdict(set)
        joins = dict()
        for join in join_objects.filter(**qwargs):
            one_id = getattr(join, getr)
            two_id = getattr(join, othr)
            existing[one_id].add(two_id)
            joins[(one_id, two_id)] = join

        # Compare existing joins with what should exist
        join_pks_to_delete = []
        joins_to_add = []
        for key, fields_ids in models_fields_ids.items():
            if field not in fields_ids:
                continue
            model = self.m2m_models[klass][key]
            current = existing.get(model.pk, set())

            shoulds = set()
            for obj in fields_ids[field]:
                # resolved only now, so that related objects that were
                # unsaved when set_m2m was called have received their pk
                related_pk = obj.pk if hasattr(obj, "pk") else obj
                if related_pk is None or model.pk is None:
                    raise Exception(
                        "null id for join: %s %s"
                        % (join_model, {getr: model.pk, othr: related_pk}))
                shoulds.add(related_pk)

            join_pks_to_delete.extend(
                joins[(model.pk, r)].pk for r in current - shoulds)
            joins_to_add.extend(
                {getr: model.pk, othr: a} for a in shoulds - current)

        if join_pks_to_delete:
            join_objects.filter(pk__in=join_pks_to_delete).delete()
        if joins_to_add:
            join_objects.bulk_create(
                [join_model(**params) for params in joins_to_add],
                self.m2m_batch_size)
@contextlib.contextmanager
def bulk_saver(maybe=None):
    """
    Context manager to perform a bulk save operation.

    If no parent is passed in then this creates a BulkSave, runs any code
    inside the context and saves when the context closes.

    If a parent saver is passed in as `maybe`, it is yielded unchanged and is
    NOT saved here; saving is left to whoever created it. This lets nested
    code share one saver.

    If the body raises, the exception propagates and nothing is saved.
    """
    # Compare against None explicitly: a parent that happens to be falsy must
    # still be used, otherwise its work would go to a throwaway saver that is
    # never saved.
    saver = BulkSave() if maybe is None else maybe
    yield saver
    # only reached when the body finished without raising
    if maybe is None:
        saver.save()
def dict_hash(qwargs):
    """
    Generate a hashable key that identifies the contents of the dictionary.

    Equal dictionaries give equal keys regardless of insertion order, so the
    key can be used to group identical update dicts.

    When every value is hashable the key is a frozenset of the items. It is
    compared by equality, so different dictionaries never share a key. A bare
    ``hash()`` would not guarantee that: ``hash(-1) == hash(-2)`` in CPython.

    Nested dictionaries are not hashable, so it falls back to a SHA-1 hex
    digest of the text representation of the items. Nested dictionaries can
    be passed in when saving to a PickleField or JSONField.
    NOTE: in this fallback two different values with the same ``repr`` give
    the same key.
    """
    try:
        return frozenset(qwargs.items())
    except TypeError:
        # order by key so that insertion order does not change the digest
        items = sorted(qwargs.items(), key=lambda item: repr(item[0]))
        return hashlib.sha1(repr(items).encode("utf-8")).hexdigest()
@crucialfelix
Copy link
Author

You shouldn't call .save() at all: the context manager calls it for you when the `with` block exits without an error.

The example above is this:

        with bulk_saver() as bs:
            # add objects. Will use a single insert statement
            bs.add_insert(NewModel(value=1))
            bs.add_insert(NewModel(value=2))
            nm = NewModel(value=3)
            bs.add_insert(nm)
            # set m2m fields even though the field has not yet been saved
            bs.set_m2m(nm, 'owners', [user1, user2])
            # equivalent to:
            # nm.owners = [user1, user2]
            # update
            update_me = NewModel.objects.get(id=2)
            with bs.changing(update_me):
                update_me.value = 2

so for your case:

    with bulk_saver() as bs:
        bs.add_insert(Tag(tag="tag_1"))
        bs.add_insert(Tag(tag="tag_2"))

        nm = Tag(tag="tag_with_photo_1")
        bs.add_insert(nm)

        photo1 = Photo(name="user1")
        bs.add_insert(photo1)
        photo2 = Photo(name="user2")
        bs.add_insert(photo2)

        bs.set_m2m(nm, 'photos', [photo1, photo2])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment