Skip to content

Instantly share code, notes, and snippets.

@dcramer
Created March 13, 2013 20:25
Show Gist options
  • Select an option

  • Save dcramer/5155822 to your computer and use it in GitHub Desktop.

Select an option

Save dcramer/5155822 to your computer and use it in GitHub Desktop.
MySQL create_or_update for Django ORM
from __future__ import unicode_literals, division
import operator
from django.db import connections, transaction, DEFAULT_DB_ALIAS
from django.db.models.expressions import ExpressionNode, F
def create_or_update(model, using=DEFAULT_DB_ALIAS, **kwargs):
    """
    Insert a row for ``model`` or, on a duplicate key, update it in place.

    Similar to get_or_create, but performed as a single statement. This
    relies on MySQL's ``INSERT ... ON DUPLICATE KEY UPDATE`` behavior.

    :param model: model class; its ``_meta`` supplies the table name and the
        field-to-column mapping.
    :param using: alias of the database connection the statement runs on.
    :param defaults: optional keyword holding a mapping of field name to
        value. These columns are inserted and are also the columns that get
        overwritten (``col=VALUES(col)``) when the key already exists.
        ``ExpressionNode`` values are evaluated in Python against a blank
        ``model()`` instance before being sent, so the expression sees the
        model's initial attribute values, not the stored row.
    :param kwargs: every other keyword is a field name/value pair that is
        inserted but never overwritten on a duplicate key. Names that do not
        match a model field are ignored.
    :raises KeyError: if ``defaults`` names something that is not a field.

    >>> create_or_update(MyModel, key='value', defaults={
    ...     'value': F('value') + 1,
    ... })
    """
    meta = model._meta
    connection = connections[using]
    qn = connection.ops.quote_name

    defaults = kwargs.pop('defaults', {})
    all_field_names = set(kwargs) | set(defaults)

    # (field name, quoted column) pairs kept as a list so the column list and
    # the parameter list are generated in the same order.
    field_items = [
        (f.name, qn(f.column))
        for f in meta.fields
        if f.name in all_field_names
    ]
    fields = dict(field_items)

    # we need a dummy instance to resolve expression nodes
    inst = model()
    resolved_defaults = {}
    for name, value in defaults.items():
        if isinstance(value, ExpressionNode):
            value = resolve_expression_node(inst, value)
        resolved_defaults[name] = value

    c_join = ', '.join
    update_columns = [fields[name] for name in resolved_defaults]
    if update_columns:
        updates = c_join(
            '{0}=VALUES({0})'.format(column) for column in update_columns)
    elif field_items:
        # Nothing to overwrite: assign a column to itself so the clause stays
        # syntactically valid and a duplicate key leaves the row untouched.
        updates = '{0}={0}'.format(field_items[0][1])
    else:
        updates = ''

    query = """
INSERT INTO {table} ({columns})
VALUES ({values})
ON DUPLICATE KEY UPDATE
{updates}
""".format(
        table=qn(meta.db_table),
        columns=c_join(column for _, column in field_items),
        values=c_join('%s' for _ in field_items),
        updates=updates,
    )

    # Send the resolved value for defaults (never the raw expression object);
    # defaults take precedence over a same-named keyword.
    params = [
        resolved_defaults[name] if name in resolved_defaults else kwargs[name]
        for name, _ in field_items
    ]

    cursor = connection.cursor()
    try:
        cursor.execute(query, params)
        # NOTE(review): commit_unless_managed is only available on older
        # Django releases -- confirm against the targeted version.
        transaction.commit_unless_managed(using=using)
    finally:
        cursor.close()
# Maps an ExpressionNode connector to the Python operator used to evaluate it
# in resolve_expression_node.
EXPRESSION_NODE_CALLBACKS = {
    ExpressionNode.ADD: operator.add,
    ExpressionNode.SUB: operator.sub,
    ExpressionNode.MUL: operator.mul,
    # operator.div exists only on Python 2; fall back to true division so the
    # module still imports on interpreters without it.
    ExpressionNode.DIV: getattr(operator, 'div', operator.truediv),
    ExpressionNode.MOD: operator.mod,
}
# Django 1.5 compatibility: the bitwise connectors are looked up under the
# AND/OR names first and under BITAND/BITOR when those are missing.
# pylint: disable=E1101
try:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.AND] = operator.and_
except AttributeError:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.BITAND] = operator.and_
try:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.OR] = operator.or_
except AttributeError:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.BITOR] = operator.or_
# pylint: enable=E1101
class CannotResolve(Exception):
    """Raised when an expression node cannot be evaluated in Python."""
# adapted from https://github.com/getsentry/sentry
def resolve_expression_node(instance, node):
    """
    Evaluate an expression tree in Python against ``instance``.

    ``F`` references are read from ``instance`` with ``getattr``; nested
    expression nodes are evaluated recursively; any other child is used as a
    literal value. The node's children are folded left to right with the
    operator registered for its connector in ``EXPRESSION_NODE_CALLBACKS``.

    :param instance: object whose attributes supply the values of ``F``
        references.
    :param node: an ``F`` reference or an ``ExpressionNode`` to evaluate.
    :returns: the computed Python value.
    :raises CannotResolve: if the node's connector has no registered callback.
    """
    # A lone F reference carries no operator to look up, so read the attribute
    # directly instead of falling through to the connector lookup below.
    if isinstance(node, F):
        return getattr(instance, node.name)

    def _resolve(child):
        # F references and nested expressions are evaluated recursively;
        # anything else is already a plain value.
        if isinstance(child, (F, ExpressionNode)):
            return resolve_expression_node(instance, child)
        return child

    op = EXPRESSION_NODE_CALLBACKS.get(node.connector)
    if op is None:
        raise CannotResolve(
            'unsupported connector: %r' % (node.connector,))

    result = _resolve(node.children[0])
    for child in node.children[1:]:
        result = op(result, _resolve(child))
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment