Created
March 13, 2013 20:25
-
-
Save dcramer/5155822 to your computer and use it in GitHub Desktop.
MySQL create_or_update for Django ORM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import unicode_literals, division | |
| import operator | |
| from django.db import connections, transaction, DEFAULT_DB_ALIAS | |
| from django.db.models.expressions import ExpressionNode, F | |
def _compile_update_expression(node, meta, qn):
    """
    Render an F()/ExpressionNode tree as SQL for ON DUPLICATE KEY UPDATE.

    F() references become the quoted column of the *existing* row, nested
    ExpressionNodes are parenthesised and joined with their connector, and
    any other value becomes a ``%s`` placeholder. For example
    ``F('value') + 1`` renders as ``(<quoted value column> + %s)`` with
    params ``[1]``.

    :param node: an F, an ExpressionNode, or a plain value.
    :param meta: the model's ``_meta``, used to map field names to columns.
    :param qn: the connection's ``quote_name`` function.
    :returns: a ``(sql, params)`` tuple.
    :raises CannotResolve: if a node's connector has no entry in
        EXPRESSION_NODE_CALLBACKS (the same connectors that
        resolve_expression_node accepts) or the node has no children.
    """
    if isinstance(node, F):
        return qn(meta.get_field(node.name).column), []
    if not isinstance(node, ExpressionNode):
        return '%s', [node]
    if node.connector not in EXPRESSION_NODE_CALLBACKS or not node.children:
        raise CannotResolve
    parts = []
    params = []
    for child in node.children:
        child_sql, child_params = _compile_update_expression(child, meta, qn)
        parts.append(child_sql)
        params.extend(child_params)
    # The connector is whitelisted above against the known Django operators,
    # so interpolating it into the SQL text cannot admit arbitrary SQL.
    joined = ' {0} '.format(node.connector).join(parts)
    return '({0})'.format(joined), params


def create_or_update(model, using=DEFAULT_DB_ALIAS, **kwargs):
    """
    Similar to get_or_create, either creates a row or updates it.

    This relies on MySQL's ON DUPLICATE KEY UPDATE behavior. The keyword
    arguments and ``defaults`` are inserted as one row. If that insert
    collides with an existing unique key, only the columns named in
    ``defaults`` are updated on the existing row.

    Expression values in ``defaults`` are evaluated twice:

    - in Python, against an unsaved ``model()`` instance, to produce the
      value for the INSERT;
    - as SQL, against the existing row, for the UPDATE.

    The example below therefore inserts the instance's initial ``value``
    plus one, or increments the stored ``value`` on a duplicate key::

        create_or_update(MyModel, key='value', defaults={
            'value': F('value') + 1,
        })

    :param model: the Django model class to write to.
    :param using: database alias to run the query on.
    :param kwargs: field values to insert. The optional ``defaults`` dict
        holds field values to insert *and* to apply on a duplicate key.
        In ``kwargs``, keys that do not name one of ``model._meta.fields``
        are ignored. In ``defaults``, such keys raise KeyError.
    :raises CannotResolve: if an expression uses an unsupported connector.
    :raises ValueError: if no field values are given at all.
    """
    meta = model._meta
    connection = connections[using]
    qn = connection.ops.quote_name

    defaults = kwargs.pop('defaults', None) or {}
    all_field_names = set(kwargs) | set(defaults)

    # A list (not a dict) so the column list, the placeholders and the
    # params below all share one ordering.
    field_items = [
        (f.name, qn(f.column))
        for f in meta.fields
        if f.name in all_field_names
    ]
    if not field_items:
        raise ValueError('create_or_update() needs at least one field value')
    columns = dict(field_items)

    # we need a dummy instance to resolve expression nodes for the INSERT
    inst = model()

    insert_values = {}
    update_clauses = []
    update_params = []
    for name, value in defaults.items():
        column = columns[name]
        if isinstance(value, ExpressionNode):
            # INSERT gets the Python-resolved value; UPDATE re-evaluates the
            # expression in SQL so F() refers to the existing row.
            insert_values[name] = resolve_expression_node(inst, value)
            expr_sql, expr_params = _compile_update_expression(value, meta, qn)
            update_clauses.append('{0}={1}'.format(column, expr_sql))
            update_params.extend(expr_params)
        else:
            insert_values[name] = value
            update_clauses.append('{0}=VALUES({0})'.format(column))

    if not update_clauses:
        # MySQL rejects an empty ON DUPLICATE KEY UPDATE list; assigning a
        # column to itself leaves an existing row unchanged.
        update_clauses.append('{0}={0}'.format(field_items[0][1]))

    query = """
        INSERT INTO {table} ({columns})
        VALUES ({values})
        ON DUPLICATE KEY UPDATE
        {updates}
    """.format(
        table=qn(meta.db_table),
        columns=', '.join(column for _, column in field_items),
        values=', '.join('%s' for _ in field_items),
        updates=', '.join(update_clauses),
    )

    # INSERT params come first (they fill VALUES), followed by the UPDATE
    # expression params, matching placeholder order in the query text.
    params = [
        insert_values[name] if name in insert_values else kwargs[name]
        for name, _ in field_items
    ]
    params.extend(update_params)

    cursor = connection.cursor()
    try:
        cursor.execute(query, params)
        transaction.commit_unless_managed(using=using)
    finally:
        cursor.close()
# Maps ExpressionNode connectors to the Python operator that
# resolve_expression_node() uses to combine a node's children.
EXPRESSION_NODE_CALLBACKS = {
    ExpressionNode.ADD: operator.add,
    ExpressionNode.SUB: operator.sub,
    ExpressionNode.MUL: operator.mul,
    # operator.div is missing on Python 3, and on Python 2 it is classic
    # (floor-for-ints) division. truediv matches this module's
    # ``from __future__ import division`` and MySQL's "/" operator.
    ExpressionNode.DIV: operator.truediv,
    ExpressionNode.MOD: operator.mod,
}

# Django 1.5 compatibility: depending on the Django version, the bitwise
# connectors are exposed as AND/OR or as BITAND/BITOR; register whichever
# this version defines.
# pylint: disable=E1101
try:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.AND] = operator.and_
except AttributeError:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.BITAND] = operator.and_
try:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.OR] = operator.or_
except AttributeError:
    EXPRESSION_NODE_CALLBACKS[ExpressionNode.BITOR] = operator.or_
# pylint: enable=E1101
class CannotResolve(Exception):
    """Raised by resolve_expression_node when a node's connector has no
    Python operator registered in EXPRESSION_NODE_CALLBACKS."""
# adapted from https://github.com/getsentry/sentry
def resolve_expression_node(instance, node):
    """
    Evaluate an F()/ExpressionNode tree in Python against ``instance``.

    Each child of ``node`` is resolved as follows:

    - ``F('name')`` becomes ``getattr(instance, 'name')``;
    - a nested ExpressionNode is evaluated recursively;
    - any other child is used as a literal.

    The resolved children are combined left to right with the operator
    registered for the node's connector in EXPRESSION_NODE_CALLBACKS.

    ``node`` may itself be a bare ``F()``, in which case the attribute value
    is returned directly.

    :raises CannotResolve: if the connector has no registered operator or
        the node has no children.
    """
    if isinstance(node, F):
        # A bare F carries no connector/children of its own, so it must be
        # answered before the operator lookup below.
        return getattr(instance, node.name)

    op = EXPRESSION_NODE_CALLBACKS.get(node.connector)
    if op is None or not node.children:
        raise CannotResolve

    def _resolve(child):
        if isinstance(child, F):
            return getattr(instance, child.name)
        if isinstance(child, ExpressionNode):
            return resolve_expression_node(instance, child)
        return child

    runner = _resolve(node.children[0])
    for child in node.children[1:]:
        runner = op(runner, _resolve(child))
    return runner
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment