Skip to content

Instantly share code, notes, and snippets.

@clee704
Created December 31, 2011 17:05
Show Gist options
  • Save clee704/1544586 to your computer and use it in GitHub Desktop.
Save clee704/1544586 to your computer and use it in GitHub Desktop.
Rename Drupal 7 fields
#! /usr/bin/env python
from datetime import datetime
from StringIO import StringIO
import getpass, os, re, subprocess, sys
import MySQLdb
def main():
    """Command-line entry point.

    Expects exactly two arguments, FIELD_NAME and NEW_FIELD_NAME. With any
    other argument count the usage text is passed to ``sys.exit``. Otherwise
    the MySQL password and database name are read interactively and the
    actual work is delegated to ``rename_field``.
    """
    usage = """Usage: {0} FIELD_NAME NEW_FIELD_NAME
* MySQL root account is needed to perform renaming.
* There must be no tables or columns whose names contain either FIELD_NAME or
NEW_FIELD_NAME, except for the table of the original field FIELD_NAME.
* Be sure there is no one updating the database. Turning off the web server
could help.
* Renaming is applied only to the database. Some files such as theme files may
need to be updated manually.
* Caches may have to be cleared after the renaming.
* It might fail if FIELD_NAME is short. Use the original dump to recover from
any failure.
"""
    arguments = sys.argv
    if len(arguments) != 3:
        sys.exit(usage.format(arguments[0]))
    old_name, new_name = arguments[1], arguments[2]

    # TODO get these arguments from the command line; consider using argparse
    user = 'root'
    backup_dir = '/www/backups'
    try:
        password = getpass.getpass('Password: ')
        db = raw_input('Database: ')
    except KeyboardInterrupt:
        # Ctrl-C at a prompt: leave without a traceback.
        sys.exit('')

    rename_field(old_name, new_name, user, password, db, backup_dir)
def rename_field(field_name, new_field_name, user, password, db, backup_dir):
    """Rename a field by rewriting a dump of the whole database.

    The function validates both names, checks the schema for name collisions,
    dumps the database with ``mysqldump`` into ``backup_dir``, replaces every
    occurrence of the old name (in both its underscore and hyphen spellings)
    in the dump text, repairs serialized string lengths through
    ``fix_serialization``, loads the modified dump with ``mysql`` and finally
    drops the two tables that belonged to the old field.

    Args:
        field_name: Current field name; must match ``field(_[a-z]+)+``.
        new_field_name: Desired field name; same format as ``field_name``.
        user: MySQL user name.
        password: MySQL password.
        db: Name of the database to operate on.
        backup_dir: Directory receiving the original and the modified dump.

    Raises:
        SystemExit: If a name is malformed, the field's tables are missing,
            another table or column name contains either name, or one of the
            external commands fails (raised by ``call``).
        RuntimeError: Propagated from ``fix_serialization``.
    """
    # Validate before touching the database: the names are later spliced
    # into SQL and replaced throughout the dump, so nothing else may run on
    # malformed input.
    for name in (field_name, new_field_name):
        _check_field_name(name)

    # NOTE(review): the password travels on the child process command line.
    # NOTE(review): the same option list (including -BN) is handed to both
    # mysqldump and mysql -- confirm it means what is intended for each.
    mysql_options = [db, '--user=' + user, '--password=' + password, '-BN']
    field_tables = ('field_data_' + field_name, 'field_revision_' + field_name)

    connections = []
    try:
        for schema in ('information_schema', db):
            connections.append(
                MySQLdb.connect(db=schema, user=user, passwd=password))
        c1 = connections[0].cursor()
        c2 = connections[1].cursor()

        # Check if the renaming can be done
        c1.execute('select column_name, table_name from columns where table_schema = %s', (db,))
        columns = sorted(c1.fetchall())
        _check_name_collisions(columns, field_name, new_field_name, field_tables)

        # Dump the current database
        timestamp = datetime.now().strftime('%Y%m%dT%H%M%S')
        dump_path = os.path.join(backup_dir, '{0}.{1}.dump'.format(db, timestamp))
        with open(dump_path, 'w') as f:
            call(['mysqldump'] + mysql_options, stdout=f)
        with open(dump_path, 'r') as f:
            text = f.read()

        # Modify the dump. The names consist only of [a-z_-], so a literal
        # replacement is all that is needed.
        field_name_d = field_name.replace('_', '-')
        new_field_name_d = new_field_name.replace('_', '-')
        new_field_name_all = new_field_name + '|' + new_field_name_d
        length_diff = len(new_field_name) - len(field_name)
        mod_text = text.replace(field_name, new_field_name)
        mod_text = mod_text.replace(field_name_d, new_field_name_d)
        mod_text = fix_serialization(mod_text, new_field_name_all, length_diff)

        # Save the modified dump and load it into the database
        mod_dump_path = '{0}.mod-{1}-{2}'.format(dump_path, field_name, new_field_name)
        with open(mod_dump_path, 'w') as f:
            f.write(mod_text)
        with open(mod_dump_path, 'r') as f:
            call(['mysql'] + mysql_options, stdin=f)

        # Remove old field tables
        for name in field_tables:
            c2.execute('drop table {0}'.format(name))
    finally:
        # Runs on sys.exit() as well, so connections never outlive the call.
        for connection in connections:
            connection.close()


def _check_field_name(name):
    """Exit unless ``name`` is ``field`` followed by ``_lowercase`` words."""
    # \Z instead of $: "$" would also accept a name ending in a newline.
    if not re.match(r'^field(_[a-z]+)+\Z', name):
        sys.exit('`{0}` is not a valid field name.'.format(name))


def _check_name_collisions(columns, field_name, new_field_name, field_tables):
    """Exit if the field is missing or a table/column name would be mangled.

    ``columns`` is a sorted sequence of ``(column_name, table_name)`` pairs.
    """
    tables = sorted(set(row[1] for row in columns))
    if not all(name in tables for name in field_tables):
        sys.exit('Field `{0}` does not exist.'.format(field_name))
    for table_name in tables:
        if field_name in table_name and (table_name not in field_tables or table_name.startswith(field_name)):
            sys.exit('Existing table name `{0}` contains field name `{1}`.'.format(table_name, field_name))
        if new_field_name in table_name:
            sys.exit('Existing table name `{0}` contains field name `{1}`.'.format(table_name, new_field_name))
    for column_name, table_name in columns:
        if field_name in column_name and (table_name not in field_tables):
            sys.exit('Existing column name `{0}` contains field name `{1}`.'.format(column_name, field_name))
        if new_field_name in column_name:
            sys.exit('Existing column name `{0}` contains field name `{1}`.'.format(column_name, new_field_name))
def call(args, stdin=None, stdout=None):
    """Run an external command and terminate the script if it fails.

    Args:
        args: Argument list handed to ``subprocess.Popen``.
        stdin: Optional stdin for the child process.
        stdout: Optional stdout for the child process.

    Returns:
        The child's captured output when ``stdout`` is ``subprocess.PIPE``,
        otherwise ``None``.

    Raises:
        SystemExit: With status 1 when the child exits with a non-zero code.
    """
    process = subprocess.Popen(args, stdin=stdin, stdout=stdout)
    if stdout is subprocess.PIPE:
        captured, _ = process.communicate()
    else:
        captured = None
        process.wait()
    if process.returncode:
        sys.exit(1)
    return captured
def fix_serialization(text, pattern, diff):
    """Repair the length prefixes of serialized strings in a dump.

    The text is scanned for string headers (``s:N:`` followed by an escaped
    double quote). From the end of each header, N characters are stepped
    over, counting a backslash plus the character after it as one. If an
    escaped double quote does not follow, the prefix is rewritten as
    ``N + x * diff``, where ``x`` is the number of ``pattern`` matches in a
    window starting at the string body; the window end depends on ``x``, so
    the count is repeated until it stops changing.

    Args:
        text: Dump text in which names have already been replaced.
        pattern: Regular expression matching the new name(s).
        diff: Length of the new name minus length of the old name.

    Returns:
        The text with the adjusted length prefixes.

    Raises:
        RuntimeError: If ``check_serialization_integrity`` rejects the result.
    """
    header_re = re.compile(r's:([0-9]+):\\"')
    closing_quote = '\\"'
    slack = abs(diff)
    pieces = []
    copied_upto = 0
    for header in header_re.finditer(text):
        body_start = header.end()
        declared = int(header.group(1))

        # Step over `declared` characters; an escape sequence counts as one.
        position = body_start
        remaining = declared
        while remaining:
            position += 2 if text[position] == '\\' else 1
            remaining -= 1

        if text[position:position + 2] == closing_quote:
            # The prefix already matches the string; leave it alone.
            continue

        pieces.append(text[copied_upto:header.start(1)])
        hits = len(re.findall(pattern, text[body_start:position + slack]))
        while True:
            window = text[body_start:position + hits * diff + slack]
            recount = len(re.findall(pattern, window))
            if recount == hits:
                break
            hits = recount
        pieces.append(str(declared + hits * diff))
        copied_upto = header.end(1)

    pieces.append(text[copied_upto:])
    fixed_text = ''.join(pieces)
    if not check_serialization_integrity(fixed_text):
        raise RuntimeError("Something went wrong.")
    return fixed_text
def check_serialization_integrity(text):
    """Check that every serialized string in ``text`` has a correct length.

    A string header is ``s:N:`` followed by an escaped double quote. From
    the end of each header, N characters are stepped over (a backslash plus
    the character after it counts as one), and an escaped double quote must
    follow. On the first mismatch a diagnostic line is printed (body start,
    stop position, raw length, body, surrounding context).

    Args:
        text: Dump text to verify.

    Returns:
        True if every header is followed by a string of the declared length,
        False otherwise -- including when a declared length runs past the end
        of ``text``.
    """
    header_re = re.compile(r's:([0-9]+):\\"')
    text_length = len(text)
    for m in header_re.finditer(text):
        end = m.end()
        remaining = int(m.group(1))
        cursor = end
        # Stop at the end of the text: a truncated string is a mismatch to
        # report, not a reason to crash with IndexError.
        while remaining and cursor < text_length:
            cursor += 2 if text[cursor] == '\\' else 1
            remaining -= 1
        if text[cursor:cursor + 2] != '\\"':
            # Clamp the context start so a negative index cannot wrap around
            # to the end of the text.
            context = text[max(end - 10, 0):cursor + 10]
            # Single-argument print: same output on Python 2 and 3.
            print('{0} {1} {2} {3} {4}'.format(
                end, cursor, cursor - end, text[end:cursor], context))
            return False
    return True
# Run the rename only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment