Created
June 23, 2024 20:13
-
-
Save TobeTek/74b8eb75900c261466ed30eaeb7b5070 to your computer and use it in GitHub Desktop.
A Django management command that automatically creates migrations for all models with a PostgreSQL SearchVectorField.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import string | |
from collections import defaultdict | |
from django.core.management.base import BaseCommand, CommandError | |
from django.db import migrations | |
from django.db.migrations.writer import MigrationWriter | |
from django.db.models import Model | |
# Marker embedded in the filename of every generated search vector migration;
# the command also looks for it to detect migrations it already generated.
MIGRATION_FILE_NAME: str = "searchvectortrigger"
class Command(BaseCommand):
    """
    Create migrations that rebuild search vector columns as generated columns.

    ``handle`` declares, per model, which ``tsvector`` columns to build and
    from which source columns. One migration file is written per model into
    that model's app migrations folder. It holds one operation per search
    vector, built by ``generate_search_vector_sql``. Models that already have
    such a migration are skipped, so the command can be re-run safely.

    Should be invoked after ``./manage.py makemigrations`` so each new
    migration depends on the app's latest migration.
    """

    help = "Creates new migration(s) to create triggers for search vectors in models."
    include_header = True

    @property
    def log_output(self):
        """Stream that :meth:`log` writes to (the command's stdout)."""
        return self.stdout

    def log(self, msg):
        """Write ``msg`` to :attr:`log_output`."""
        self.log_output.write(msg)

    def handle(self, *app_labels, **options):
        """
        Build the search vector operations for every configured model and write them.

        ``app_labels`` and ``options`` are accepted for ``BaseCommand``
        compatibility but are not used: the model list below is fixed.
        """
        from articles.models.articles import Article
        from articles.models.categories import Category
        from community.models import Post
        from tools_and_settings.models import AppTool

        # model -> list of {"vector_column": ..., "trigger_columns": [...]}.
        # Trigger columns are listed most important first (weight A, B, ...).
        SEARCH_VECTOR_FIELDS = {
            Article: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Category: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            Post: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
            ],
            AppTool: [
                {
                    "vector_column": "english_fts_vector",
                    "trigger_columns": ["escaped_content_html", "topic", "title"],
                },
                {
                    "vector_column": "description_fts_vector",
                    "trigger_columns": ["description"],
                },
            ],
        }

        # app_label -> list of [model_name, operations]
        search_trigger_migrations = defaultdict(list)
        for model, search_vectors in SEARCH_VECTOR_FIELDS.items():
            operations = [
                generate_search_vector_sql(
                    model=model,
                    vector_column=search_vector["vector_column"],
                    trigger_columns=search_vector["trigger_columns"],
                )
                for search_vector in search_vectors
            ]
            search_trigger_migrations[model._meta.app_label].append(
                [model._meta.model_name, operations]
            )

        self.write_migration_files(search_trigger_migrations)

    @staticmethod
    def _build_migration(app_label, name, dependencies, operations):
        """Instantiate an ad-hoc ``Migration`` subclass with the given contents."""
        migration_class = type(
            "Migration",
            (migrations.Migration,),
            {"dependencies": dependencies, "operations": operations},
        )
        return migration_class(name=name, app_label=app_label)

    def write_migration_files(self, changes):
        """
        Write one migration file per ``[model_name, operations]`` entry.

        :param changes: Maps an app label to a list of
            ``[model_name, operations]`` pairs, as built by :meth:`handle`.

        The new migration is numbered one past the app's most recent migration
        and depends on it. If the app has no migrations yet, it is numbered
        ``0001`` and has no dependencies.
        """
        for app_label, model_migrations in changes.items():
            for model_name, operations in model_migrations:
                suffix = f"{MIGRATION_FILE_NAME}_{model_name}"

                # A provisional writer is needed to resolve the app's
                # migrations folder (``writer.basedir``).
                migration = self._build_migration(
                    app_label, f"0001_{suffix}", [], operations
                )
                writer = MigrationWriter(migration, self.include_header)
                basedir = writer.basedir

                # Re-running the command must not duplicate a migration.
                if self.has_search_vector_migration(basedir, model_name):
                    continue

                if dependency := self.get_most_recent_migration(basedir):
                    dependency_number = int(dependency.split("_", 1)[0])
                    migration = self._build_migration(
                        app_label,
                        f"{dependency_number + 1:0>4}_{suffix}",
                        [(app_label, dependency)],
                        operations,
                    )
                    writer = MigrationWriter(migration, self.include_header)

                migrations_directory = os.path.dirname(writer.path)
                os.makedirs(migrations_directory, exist_ok=True)
                init_path = os.path.join(migrations_directory, "__init__.py")
                if not os.path.isfile(init_path):
                    with open(init_path, "w", encoding="utf-8"):
                        pass

                with open(writer.path, "w", encoding="utf-8") as fh:
                    fh.write(writer.as_string())
                self.log(f"Created search vector migration {writer.path}")

    def has_search_vector_migration(self, app_migrations_folder: str, model_name: str) -> bool:
        """
        Return True if this command already generated a migration for ``model_name``.

        Matches the exact ``_<MIGRATION_FILE_NAME>_<model_name>.py`` filename
        suffix, so e.g. ``post`` does not match a migration for ``postcomment``.
        A folder that does not exist yet contains no migrations.
        """
        if not os.path.isdir(app_migrations_folder):
            return False
        suffix = f"_{MIGRATION_FILE_NAME}_{model_name}.py"
        return any(
            filename.endswith(suffix)
            for filename in os.listdir(app_migrations_folder)
        )

    def get_most_recent_migration(self, app_migrations_folder: str):
        """
        Return the name (without ``.py``) of the highest-numbered migration.

        Only ``NNNN_name.py`` files are considered. Returns None when the
        folder is missing or holds no numbered migrations.
        """
        if not os.path.isdir(app_migrations_folder):
            return None
        numbered = []
        for filename in os.listdir(app_migrations_folder):
            stem, ext = os.path.splitext(filename)
            number, sep, _ = stem.partition("_")
            # Skip ``__init__.py``, ``__pycache__`` and any other stray entry:
            # "__pycache__" would otherwise sort after "0001_..." and be
            # mistaken for the latest migration.
            if ext == ".py" and sep and number.isascii() and number.isdigit():
                numbered.append((int(number), stem))
        if not numbered:
            return None
        # Compare numerically so numbering beyond 9999 still orders correctly.
        return max(numbered)[1]
def generate_search_vector_sql(
    model: type[Model],
    vector_column: str,
    trigger_columns: list[str],
    config: str = "english",
):
    """
    Build a ``RunSQL`` operation that recreates ``vector_column`` as a generated column.

    The existing column (if any) is dropped and re-added as a PostgreSQL
    ``tsvector GENERATED ALWAYS AS (...) STORED`` column computed from
    ``trigger_columns``. Columns are weighted by position: the first gets
    weight ``A``, the second ``B``, and so on. NULL values count as empty text.

    :param model: Django model whose ``_meta.db_table`` is altered.
    :param vector_column: Name of the ``tsvector`` column to (re)create.
    :param trigger_columns: Source columns, most important first. PostgreSQL's
        ``setweight`` only accepts weights A-D, so at most four are allowed.
        NOTE(review): these are used verbatim as database column names. Confirm
        no listed field uses a custom ``db_column``.
    :param config: Text search configuration passed to ``to_tsvector``.
    :returns: A ``migrations.RunSQL`` whose reverse drops the column.
    :raises CommandError: If ``trigger_columns`` is empty or has more than four
        entries.
    """
    weights = "ABCD"  # the only weights setweight() accepts
    if not trigger_columns:
        raise CommandError("At least one trigger column is required for a search vector")
    if len(trigger_columns) > len(weights):
        raise CommandError("Maximum number of trigger columns exceeded for search vector")

    def quote_name(name: str) -> str:
        # Same rule as Django's PostgreSQL backend: wrap in double quotes
        # unless the name is already quoted.
        if name.startswith('"') and name.endswith('"'):
            return name
        return f'"{name}"'

    table = quote_name(model._meta.db_table)
    vector = quote_name(vector_column)
    regconfig = "'{}'".format(config.replace("'", "''"))

    # Columns must be referenced as identifiers. Quoting them as string literals
    # would index the column *name* rather than the row's value.
    setweight_stmts = " || ".join(
        f"setweight(to_tsvector({regconfig}, coalesce({quote_name(column)}, '')), '{weight}')"
        for weight, column in zip(weights, trigger_columns)
    )

    create_sql = (
        f"ALTER TABLE {table} DROP COLUMN IF EXISTS {vector}; "
        f"ALTER TABLE {table} ADD COLUMN {vector} tsvector "
        f"GENERATED ALWAYS AS ({setweight_stmts}) STORED;"
    )
    # NOTE(review): the reverse drops the column entirely, whereas the forward
    # SQL replaced a column that an earlier migration may have created.
    # Confirm this reverse is intended.
    reverse_sql = f"ALTER TABLE {table} DROP COLUMN {vector};"

    return migrations.RunSQL(sql=create_sql, reverse_sql=reverse_sql)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is no longer necessary with the `GeneratedField` introduced in Django 5.0.
This is how I would approach it now: