Created
October 18, 2022 14:42
Django 4.1 serial to identity migration script, alternative version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Alternative version that updates PostgreSQL's system catalogs directly, using the technique from:
https://www.enterprisedb.com/blog/postgresql-10-identity-columns-explained
""" | |
from __future__ import annotations | |
import argparse | |
from typing import Any | |
from django.core.management.base import BaseCommand | |
from django.db import DEFAULT_DB_ALIAS, connections | |
from django.db.backends.utils import CursorWrapper | |
from django.db.transaction import atomic | |
class Command(BaseCommand):
    """Management command that converts 'serial' columns to identity columns.

    Serial columns are discovered with the ``find_serial_columns`` query and
    each one is handed to ``migrate_serial_to_identity``. The database is only
    modified when ``--write`` is passed; otherwise the command performs the
    lookups and reports what it found.
    """

    help = "Migrate all tables using 'serial' columns to use 'identity' instead."

    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Register the ``--database`` and ``--write`` command-line options."""
        parser.add_argument(
            "--database",
            default=DEFAULT_DB_ALIAS,
            help='Which database to update. Defaults to the "default" database.',
        )
        parser.add_argument(
            "--write",
            action="store_true",
            default=False,
            help="Actually edit the database",
        )

    def handle(self, *args: Any, database: str, write: bool, **kwargs: Any) -> None:
        """Find every serial column on ``database`` and migrate each in turn.

        Args:
            database: Alias of the Django database connection to operate on.
            write: When false, nothing is altered (dry run).
        """
        if not write:
            self.stdout.write("In dry run mode (--write not passed)")

        # The cursor must stay open for the whole loop: the same cursor is
        # reused by migrate_serial_to_identity() for every column.
        with connections[database].cursor() as cursor:
            cursor.execute(find_serial_columns)
            column_specs = cursor.fetchall()
            self.stdout.write(f"Found {len(column_specs)} columns to update")

            for table_name, column_name in column_specs:
                # Go through self.stdout (not print) so the progress line is
                # captured wherever the command's output has been redirected.
                self.stdout.write(f"{table_name} {column_name}")
                migrate_serial_to_identity(
                    cursor, database, table_name, column_name, write
                )
# Query returning one (table_name, column_name) row per serial-style column.
# Adapted from: https://dba.stackexchange.com/a/90567
#
# A column qualifies when it is a live (attnum > 0, not dropped) int2/int4/int8
# column whose default expression is exactly nextval('<seq>'::regclass), where
# <seq> is the sequence reported by pg_get_serial_sequence() for that column.
# table_name is the regclass text form, so it is already quoted (and
# schema-qualified when necessary) by the server.
# Rows are ordered by table, then column position, so runs are reproducible.
find_serial_columns = """\
SELECT
    a.attrelid::regclass::text AS table_name,
    a.attname AS column_name
FROM pg_attribute a
WHERE
    a.attnum > 0
    AND NOT a.attisdropped
    AND a.atttypid = ANY ('{int,int8,int2}'::regtype[])
    AND EXISTS (
        SELECT FROM pg_attrdef ad
        WHERE
            ad.adrelid = a.attrelid
            AND ad.adnum = a.attnum
            AND (
                pg_get_expr(ad.adbin, ad.adrelid)
                =
                'nextval('''
                || (
                    pg_get_serial_sequence(a.attrelid::regclass::text, a.attname)
                )::regclass
                || '''::regclass)'
            )
    )
ORDER BY table_name, a.attnum
"""
def migrate_serial_to_identity(
    cursor: CursorWrapper,
    database: str,
    table_name: str,
    column_name: str,
    write: bool,
) -> None:
    """Turn one serial column into a "generated by default" identity column.

    The column number and its single auto-dependent sequence are looked up in
    the system catalogs. When ``write`` is true, the column default is dropped,
    the sequence dependency is switched to internal, and the column is flagged
    as an identity column. Everything runs inside one transaction on
    ``database``.

    Args:
        cursor: Open cursor on the connection named by ``database``.
        database: Django database alias used for the transaction.
        table_name: Table reference accepted by a ``::regclass`` cast (for
            example the ``regclass::text`` value from ``find_serial_columns``).
        column_name: Unquoted column name as stored in ``pg_attribute``.
        write: When false, only the lookups and sanity checks are performed.

    Raises:
        SystemExit: With status 1 if the column cannot be found, or if it does
            not have exactly one linked sequence.
    """
    with atomic(using=database):
        # Adapted from upgrade_serial_to_identity() in:
        # https://www.enterprisedb.com/blog/postgresql-10-identity-columns-explained
        #
        # Besides the column number, fetch the server-rendered table identifier:
        # it is already quoted and schema-qualified as needed, so it is the
        # form that can be placed into the ALTER TABLE statement below.
        cursor.execute(
            """\
            SELECT attnum, attrelid::regclass::text
            FROM pg_attribute
            WHERE attrelid = %s::regclass
                AND attname = %s
            """,
            (table_name, column_name),
        )
        column_row = cursor.fetchone()
        if column_row is None:
            print("Failed to find column")
            raise SystemExit(1)
        column_number, table_identifier = column_row

        # Find sequences that auto-depend (deptype 'a') on this column.
        cursor.execute(
            """\
            SELECT objid
            FROM pg_depend
            WHERE (refclassid, refobjid, refobjsubid) = ('pg_class'::regclass, %s::regclass, %s)
                AND classid = 'pg_class'::regclass
                AND objsubid = 0
                AND deptype = 'a'
            """,
            (table_name, column_number),
        )
        results = cursor.fetchall()
        if not results:
            print("Failed to find linked sequence")
            raise SystemExit(1)
        if len(results) > 1:
            print("Found more than one linked sequence!")
            raise SystemExit(1)
        sequence_id = results[0][0]

        if not write:
            # Dry run: the lookups above are the only work performed.
            return

        # NOTE(review): the UPDATEs below modify system catalogs directly and
        # presumably need a suitably privileged role - confirm before running.

        # Drop the default. Only the column name is quoted here; the table
        # identifier is regclass text and must not be wrapped in quotes again,
        # or a schema-qualified name would be treated as one identifier.
        qn = cursor.db.ops.quote_name
        cursor.execute(
            f"""\
            ALTER TABLE {table_identifier}
            ALTER COLUMN {qn(column_name)} DROP DEFAULT;
            """
        )

        # Modify sequence to be an internal dependency
        cursor.execute(
            """\
            UPDATE pg_depend
            SET deptype = 'i'
            WHERE (classid, objid, objsubid) = ('pg_class'::regclass, %s, 0)
                AND deptype = 'a'
            """,
            (sequence_id,),
        )

        # Change to identity column, generated by default
        cursor.execute(
            """\
            UPDATE pg_attribute
            SET attidentity = 'd'
            WHERE attrelid = %s::regclass
                AND attname = %s
            """,
            (table_name, column_name),
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment