Created
November 10, 2012 21:43
-
-
Save cliffxuan/4052632 to your computer and use it in GitHub Desktop.
Introspect the constraints that involve a table's primary key across the entire database
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import itertools | |
import sqlalchemy as sa | |
from sqlalchemy import Table, Column, String, ForeignKey | |
from newman.schema.branch import tables | |
class Introspector(object):
    """Inspects how the tables of a sqlalchemy MetaData reference each other.

    Finds the foreign keys pointing at a table, the primary key / unique
    constraints those foreign key columns take part in, and generates the
    select / update statements needed to look up or re-point the
    referencing rows.
    """

    def __init__(self, metadata):
        """params:
            metadata -- sqlalchemy MetaData object whose tables are inspected
        """
        self.metadata = metadata

    @property
    def all_tables(self):
        """all table objects registered on the metadata"""
        return self.metadata.tables.values()

    @staticmethod
    def get_primary_key(table):
        """takes a sqlalchemy table object and
        returns the primary key string names.
        params:
            table -- sqlalchemy table object
        return:
            a list in this format [col_1, col_2]
            e.g. for CustomerMembership table ['id']
        """
        return table.primary_key.columns.keys()

    @staticmethod
    def construct_sql(action, table, col):
        """builds a single-column select or update statement.
        params:
            action -- 'update' or 'select'
            table -- table name string
            col -- column name string
        returns:
            for 'update' a statement with :new_id / :old_id bind parameters,
            for 'select' a statement with an :id bind parameter
        raises:
            ValueError if action is neither 'update' nor 'select'
        """
        names = dict(table=table, col=col)
        if action == 'update':
            return ('update %(table)s '
                    'set %(col)s = :new_id '
                    'where %(col)s = :old_id' % names)
        if action == 'select':
            return ('select %(col)s from %(table)s '
                    'where %(col)s = :id' % names)
        # Fail loudly instead of handing None back to the statement list.
        raise ValueError('unsupported action: %r' % (action,))

    @staticmethod
    def _unique_in_order(items):
        """drops repeated (hashable) items while keeping first-seen order"""
        seen = set()
        unique = []
        for item in items:
            if item not in seen:
                seen.add(item)
                unique.append(item)
        return unique

    @staticmethod
    def _columns_referencing(table, constraint_list):
        """columns of the given constraints that hold a foreign key to table"""
        col_list = []
        for constraint in constraint_list:
            for col in constraint.columns:
                if any(fk.column.table == table for fk in col.foreign_keys):
                    col_list.append(col)
        return col_list

    def find_foreign_keys_to_table(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys referencing this table
        params:
            table -- sqlalchemy table object
        returns:
            a list of sqlalchemy ForeignKey objects
        """
        foreign_key_list = []
        for candidate in self.all_tables:
            for fk in candidate.foreign_keys:
                if fk.column.table == table:
                    foreign_key_list.append(fk)
        return foreign_key_list

    def find_related_primary_key(self, table):
        """takes a sqlalchemy table object and returns a list of
        primary keys in the db where constituent column(s) also
        refers to the table
        params:
            table -- sqlalchemy table object
        returns:
            a list of sqlalchemy PrimaryKeyConstraint objects, each listed
            once, in the order they were first encountered
        """
        related_primary_key_list = []
        for foreign_key in self.find_foreign_keys_to_table(table):
            pk = foreign_key.parent.table.primary_key
            if pk.columns.contains_column(foreign_key.parent):
                related_primary_key_list.append(pk)
        # A composite key with several referencing columns is hit once per
        # column; report it once.
        return self._unique_in_order(related_primary_key_list)

    def find_related_unique_constraints(self, table):
        """takes a sqlalchemy table object and returns a list of
        unique constraints in the db where constituent column(s) also
        refers to the table
        params:
            table -- sqlalchemy table object
        returns:
            a list of sqlalchemy UniqueConstraint objects, each listed
            once, in the order they were first encountered
        """
        related_unique_constraints = []
        for foreign_key in self.find_foreign_keys_to_table(table):
            for con in foreign_key.parent.table.constraints:
                if not isinstance(con, sa.schema.UniqueConstraint):
                    continue
                if con.columns.contains_column(foreign_key.parent):
                    related_unique_constraints.append(con)
        # Same reasoning as for primary keys: one entry per constraint, not
        # one per referencing column.
        return self._unique_in_order(related_unique_constraints)

    def find_foreign_keys_with_other_constraint(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys in the db where constituent column(s) is
        a constituency of either a primary key constraint or a
        unique constraint
        params:
            table -- sqlalchemy table object
        returns:
            a list of sqlalchemy ForeignKey objects
        """
        fk_list = self.find_foreign_keys_to_table(table)
        free_fk_list = self.find_foreign_keys_without_other_constraint(table)
        return [fk for fk in fk_list if fk not in free_fk_list]

    def find_foreign_keys_without_other_constraint(self, table):
        """takes a sqlalchemy table object and returns a list of
        foreign keys in the db where constituent column(s) is not
        a constituency of either a primary key constraint or a
        unique constraint
        params:
            table -- sqlalchemy table object
        returns:
            a list of sqlalchemy ForeignKey objects
        """
        constrained_cols = set()
        related_constraints = itertools.chain(
            self.find_related_primary_key(table),
            self.find_related_unique_constraints(table),
        )
        for constraint in related_constraints:
            constrained_cols.update(constraint.columns)
        return [fk for fk in self.find_foreign_keys_to_table(table)
                if fk.parent not in constrained_cols]

    def _generate_sql(self, action, table, with_comment):
        """builds the statement list shared by generate_select_sql and
        generate_update_sql: unconstrained foreign key columns first, then
        the ones inside a primary key, then the ones inside a unique
        constraint; each group is preceded by a sql comment if requested
        """
        free_cols = [
            fk.parent
            for fk in self.find_foreign_keys_without_other_constraint(table)
        ]
        pk_cols = self._columns_referencing(
            table, self.find_related_primary_key(table))
        uc_cols = self._columns_referencing(
            table, self.find_related_unique_constraints(table))
        sections = (
            ('the following queries are safe to run without violating any constraint', free_cols),
            ('the following queries may violat primary key constraint', pk_cols),
            ('the following queries may violat unique constraint', uc_cols),
        )
        output = []
        for comment, col_list in sections:
            if with_comment:
                output.append('\n/*\n%s\n*/' % comment)
            for col in col_list:
                output.append(
                    self.construct_sql(action, col.table.name, col.name))
        return output

    def generate_select_sql(self, table, with_comment=False):
        """takes a sqlalchemy table object and returns the sql statements
        for selecting related data from related table to the object
        params:
            table -- sqlalchemy table object
            with_comment -- boolean, prepend a sql comment to each group
        returns:
            a list of sql strings, each taking an :id bind parameter
        """
        return self._generate_sql('select', table, with_comment)

    def generate_update_sql(self, table, with_comment=False):
        """takes a sqlalchemy table object and returns the sql statements
        for re-pointing the rows that reference the object
        params:
            table -- sqlalchemy table object
            with_comment -- boolean, prepend a sql comment to each group
        returns:
            a list of sql strings, each taking :old_id / :new_id bind
            parameters
        """
        return self._generate_sql('update', table, with_comment)
class MergeError(Exception):
    """Raised when two data objects cannot be merged."""

    pass
class Merger(object):
    """Merging two objects

    Looks up everything that stands in the way of merging two rows of
    `table` (identified by their primary key values).
    """

    def __init__(self, conn, introspector, table):
        """params:
            conn -- sqlalchemy connection used to run the lookup queries
            introspector -- instance of Introspector
            table -- sqlalchemy table
        """
        self.conn = conn
        self.table = table
        self.introspector = introspector

    @property
    def pk_column(self):
        """sole primary key column. if multiple primary key cols are found
        a MergeError is raised, as merge only happens on top level objects.
        a MergeError is also raised when the table has no primary key
        column at all."""
        key_list = list(self.table.primary_key.columns.keys())
        if len(key_list) > 1:
            raise MergeError('Table "%s" has multi primary columns.' % self.table.name)
        if not key_list:
            raise MergeError('Table "%s" has no primary key column.' % self.table.name)
        return self.table.primary_key.columns[key_list[0]]

    @staticmethod
    def _conflicts_in_lists(one_list, other_list, compare_cols, id_cols):
        """pairs up records of both lists that agree on every compare column.
        returns:
            dict of {compare values tuple: (id tuple of the record from
            one_list, id tuple of the record from other_list)}
        """
        found = {}
        for one, other in itertools.product(one_list, other_list):
            if all(one[col] == other[col] for col in compare_cols):
                shared = tuple(one[col] for col in compare_cols)
                found[shared] = (tuple(one[col] for col in id_cols),
                                 tuple(other[col] for col in id_cols))
        return found

    def _find_constraint_conflict(self, constraint, id_1, id_2):
        """rows referencing id_1 and id_2 that would become duplicates under
        `constraint` (a primary key or unique constraint) after a merge
        """
        # The owning table is taken from the first column; the original code
        # notes that constraint.table is not usable on sqlalchemy 0.6.5.
        related_table = list(constraint.columns)[0].table
        # The referencing column must come from the constraint itself: the
        # related table may hold other foreign keys to self.table that are
        # not part of this constraint.
        where_col = next(
            col.name for col in constraint.columns
            if any(fk.column.table == self.table for fk in col.foreign_keys)
        )
        compare_cols = [key for key in constraint.columns.keys()
                        if key != where_col]
        id_cols = list(related_table.primary_key.columns.keys())
        select_cols = id_cols + [col for col in compare_cols
                                 if col not in id_cols]
        sql = sa.text('select %s from %s where %s = :id' % (
            ', '.join(select_cols), related_table.name, where_col))
        records_1 = self.conn.execute(sql, id=id_1).fetchall()
        records_2 = self.conn.execute(sql, id=id_2).fetchall()
        return self._conflicts_in_lists(records_1, records_2,
                                        compare_cols, id_cols)

    def find_conflict(self, id_1, id_2):
        """finds what needs a decision before the rows with primary key
        id_1 and id_2 can be merged.
        params:
            id_1 -- primary key value of the first row
            id_2 -- primary key value of the second row
        returns:
            a dict with three kinds of keys:
            UniqueConstraint / PrimaryKeyConstraint objects of referencing
            tables, mapped to {shared values tuple: (ids of the row that
            references id_1, ids of the row that references id_2)};
            Column objects of the table itself, mapped to the differing
            (value of id_1, value of id_2)
        raises:
            MergeError if either id has no row in the table
        """
        conflict = {}
        related_constraints = itertools.chain(
            self.introspector.find_related_unique_constraints(self.table),
            self.introspector.find_related_primary_key(self.table),
        )
        for constraint in related_constraints:
            found = self._find_constraint_conflict(constraint, id_1, id_2)
            if found:
                conflict[constraint] = found

        pk_column = self.pk_column
        sql = sa.text('select * from %s where %s = :id' % (self.table.name, pk_column.name))
        rec_1 = self.conn.execute(sql, id=id_1).fetchone()
        rec_2 = self.conn.execute(sql, id=id_2).fetchone()
        for row_id, rec in ((id_1, rec_1), (id_2, rec_2)):
            if rec is None:
                raise MergeError('No row in table "%s" with %s = %r.' % (
                    self.table.name, pk_column.name, row_id))
        for col in self.table.columns:
            if col is pk_column:
                # the two ids differ by definition
                continue
            v_1 = rec_1[col.name]
            v_2 = rec_2[col.name]
            if v_1 != v_2:
                conflict[col] = (v_1, v_2)
        return conflict
class DBIndependent(unittest.TestCase):
    """Introspector tests that only use the metadata of the branch tables."""

    def setUp(self):
        """builds an Introspector over the metadata of the branch tables"""
        self.introspector = Introspector(tables.metadata)

    def test_find_foreign_keys_to_table(self):
        """tables known to reference customer show up as foreign key parents"""
        t = tables.customer_table
        foreign_key_list = self.introspector.find_foreign_keys_to_table(t)
        dep_tables = [fk.parent.table for fk in foreign_key_list]
        print([fk.parent.name for fk in foreign_key_list])
        self.assertTrue(tables.customer_membership_table in dep_tables)
        self.assertTrue(tables.direct_debit_table in dep_tables)
        self.assertTrue(tables.customer_reservation_table in dep_tables)

    def test_multi_primary_keys(self):
        """prints the foreign keys pointing at tables with a composite
        primary key (informational only, nothing is asserted)"""
        for t in self.introspector.all_tables:
            if len(t.primary_key.columns) > 1:
                deps = self.introspector.find_foreign_keys_to_table(t)
                if deps:
                    print(deps)

    def test_find_related_primary_key(self):
        """customer_reservation's primary key is reported for customer"""
        t = tables.customer_table
        pk_list = self.introspector.find_related_primary_key(t)
        try:
            table_list = [pk.table for pk in pk_list]
        except Exception:
            #for sqlalchemy 0.6.5
            #this error occurs:
            #InvalidRequestError: This constraint is not bound to a table.
            #Did you mean to call table.add_constraint(constraint)?
            table_list = [pk.columns[pk.columns.keys()[0]].table for pk in pk_list]
        print([tb.name for tb in table_list])
        self.assertTrue(tables.customer_reservation_table in table_list)

    def test_find_related_unique_constraints(self):
        """only tables with a unique constraint over the customer foreign
        key column are reported"""
        t = tables.customer_table
        related_uc_list = self.introspector.find_related_unique_constraints(t)
        related_tables = [uc.table for uc in related_uc_list]
        print([tb.name for tb in related_tables])
        self.assertTrue(tables.customer_membership_table in related_tables)
        self.assertTrue(tables.direct_debit_table in related_tables)
        self.assertTrue(tables.customer_product_credit_table not in related_tables)

    def test_find_foreign_keys_with_other_constraint(self):
        """customer_membership.customer_id is reported as constrained"""
        t = tables.customer_table
        non_free_fk_list = self.introspector.find_foreign_keys_with_other_constraint(t)
        customer_id_col = tables.customer_membership_table.columns['customer_id']
        fk = next(k for k in customer_id_col.foreign_keys if isinstance(k, ForeignKey))
        self.assertTrue(fk in non_free_fk_list)

    def test_find_foreign_keys_without_other_constraint(self):
        """customer_product_credit's foreign key is reported as unconstrained"""
        t = tables.customer_table
        free_fk_list = self.introspector.find_foreign_keys_without_other_constraint(t)
        dep_tables = [fk.parent.table for fk in free_fk_list]
        print(['.'.join([fk.parent.table.name, fk.parent.name]) for fk in free_fk_list])
        self.assertTrue(tables.customer_product_credit_table in dep_tables)

    def test_generate_update_sql(self):
        """prints the generated update statements (nothing is asserted)"""
        t = tables.customer_table
        print('\n'.join(self.introspector.generate_update_sql(t, with_comment=True)))

    def test_generate_select_sql(self):
        """prints the generated select statements (nothing is asserted)"""
        t = tables.customer_table
        print('\n'.join(self.introspector.generate_select_sql(t, with_comment=True)))
class SQLite(unittest.TestCase):
    """Merger test against a small in-memory sqlite schema."""

    def setUp(self):
        """creates the Customer / Membership / Subscription schema with its
        two link tables in an in-memory sqlite database and fills it with
        two customers that share membership 'm1' and subscription 's1'"""
        metadata = sa.MetaData()
        self.c_table = Table('Customer',
                metadata,
                Column("id", String(3), primary_key=True),
                Column("name", String(4), nullable=False),
                )
        self.m_table = Table('Membership',
                metadata,
                Column("id", String(2), primary_key=True),
                Column("name", String(4), nullable=False),
                )
        # link table guarded by a unique constraint
        self.cm_table = Table('CustomerMembership',
                metadata,
                Column("id", String(4), primary_key=True),
                Column("customer_id", String(3), ForeignKey('Customer.id', use_alter=True, name='FK_CustomerMembership_Customer'), nullable=False),
                Column("membership_id", String(2), ForeignKey('Membership.id', use_alter=True, name='FK_CustomerMembership_Membership'), nullable=False),
                sa.schema.UniqueConstraint('customer_id', 'membership_id', name='uCustomerMembership_customer_id-CustomerMembership_membership_id')
                )
        self.s_table = Table('Subscription',
                metadata,
                Column("id", String(2), primary_key=True),
                Column("name", String(4), nullable=False),
                )
        # link table guarded by a composite primary key
        self.cs_table = Table('CustomerSubscription',
                metadata,
                Column("subscription_id", String(3), ForeignKey('Subscription.id', use_alter=True, name='FK_CustomerSubscription_Subscription'), nullable=False, primary_key=True),
                Column("customer_id", String(3), ForeignKey('Customer.id', use_alter=True, name='FK_CustomerSubscription_Customer'), nullable=False, primary_key=True),
                )
        engine = sa.create_engine('sqlite:///:memory:')
        for tb in metadata.tables.values():
            tb.create(engine)
        self.conn = engine.connect()
        fixture_statements = [
            "insert into Customer(id, name) values('c1', 'foo')",
            "insert into Customer(id, name) values('c2', 'bar')",
            "insert into Membership(id, name) values('m1', 'mem1')",
            "insert into Membership(id, name) values('m2', 'mem2')",
            "insert into Membership(id, name) values('m3', 'mem3')",
            "insert into Subscription(id, name) values('s1', 'sub1')",
            "insert into Subscription(id, name) values('s2', 'sub2')",
            "insert into Subscription(id, name) values('s3', 'sub3')",
            "insert into CustomerMembership(id, customer_id, membership_id) values('cm11', 'c1', 'm1')",
            "insert into CustomerMembership(id, customer_id, membership_id) values('cm12', 'c1', 'm2')",
            "insert into CustomerMembership(id, customer_id, membership_id) values('cm21', 'c2', 'm1')",
            "insert into CustomerMembership(id, customer_id, membership_id) values('cm23', 'c2', 'm3')",
            "insert into CustomerSubscription(customer_id, subscription_id) values('c1', 's1')",
            "insert into CustomerSubscription(customer_id, subscription_id) values('c1', 's2')",
            "insert into CustomerSubscription(customer_id, subscription_id) values('c2', 's1')",
            "insert into CustomerSubscription(customer_id, subscription_id) values('c2', 's3')",
        ]
        for statement in fixture_statements:
            self.conn.execute(statement)
        self.introspector = Introspector(metadata)

    def tearDown(self):
        """releases the connection opened in setUp"""
        self.conn.close()

    def test_find_conflict(self):
        """merging c1 and c2 conflicts on the shared membership, the shared
        subscription and the differing name"""
        merger = Merger(self.conn, self.introspector, self.c_table)
        uc = next(c for c in self.cm_table.constraints
                  if isinstance(c, sa.schema.UniqueConstraint))
        pk = self.cs_table.primary_key
        name_col = self.c_table.columns['name']
        conflict = merger.find_conflict('c1', 'c2')
        self.assertTrue(uc in conflict)
        self.assertTrue(pk in conflict)
        self.assertTrue(name_col in conflict)
        self.assertTrue(('m1',) in conflict[uc])
        self.assertTrue(('s1',) in conflict[pk])
        self.assertEqual(('foo', 'bar'), conflict[name_col])
class QuickRun(unittest.TestCase):
    """Runs the merge helpers against the mssql database opened in setUp."""

    def setUp(self):
        """connects to the database and builds the Introspector / Merger
        for the customer table"""
        # NOTE(review): the connection URL, including credentials, is
        # hard-coded; consider moving it to configuration.
        self.conn = sa.create_engine('mssql://sa:Password1@localhost/branch_dedup_v3dot4').connect()
        self.introspector = Introspector(tables.metadata)
        self.merger = Merger(self.conn, self.introspector, tables.customer_table)

    def tearDown(self):
        """releases the connection opened in setUp"""
        self.conn.close()

    def test_run_merge(self):
        """merges the first two customers inside one transaction: reports
        the conflicts if there are any, otherwise re-points every
        referencing row to the second customer and deletes the first"""
        t = tables.customer_table
        trans = self.conn.begin()
        try:
            top2 = self.conn.execute('select top 2 id from customer')
            old_id = top2.fetchone()[0]
            new_id = top2.fetchone()[0]
            conflict = self.merger.find_conflict(old_id, new_id)
            if conflict:
                col = []
                pk = []
                uc = []
                for k in conflict:
                    if isinstance(k, sa.Column):
                        col.append(k.name)
                    elif isinstance(k, sa.schema.UniqueConstraint):
                        uc.append('%s=>%s%s' % (
                            k.table.name, ', '.join(k.columns.keys()), conflict[k]))
                    elif isinstance(k, sa.schema.PrimaryKeyConstraint):
                        # the column pointing at customer is the merge key
                        # itself; only the remaining key columns are reported
                        where_col = next(
                            c.name for c in k.columns
                            if any(fk.column.table == t for fk in c.foreign_keys))
                        compare_cols = [c for c in k.columns.keys() if c != where_col]
                        pk.append('%s=>%s%s' % (
                            next(c for c in k.columns).table.name,
                            ', '.join(compare_cols), conflict[k]))
                print('needs human input for:\n')
                if col:
                    print('Columns of Customer table: %s' % ', '.join(col))
                if uc:
                    print('Unique Constraint: %s' % ', '.join(uc))
                if pk:
                    print('Primary Key Constraint: %s' % ', '.join(pk))
            else:
                print('give all that belongs to "%s" to "%s"' % (old_id, new_id))
                for u_sql in self.introspector.generate_update_sql(t):
                    self.conn.execute(sa.text(u_sql), old_id=old_id, new_id=new_id)
                # the old customer may only go once nothing references it
                nothing_left = True
                for sql in self.introspector.generate_select_sql(t):
                    result = self.conn.execute(sa.text(sql), id=old_id).fetchall()
                    if result:
                        print(sql.replace(':id', "'%s'" % (old_id,)))
                        print(result)
                        nothing_left = False
                if nothing_left:
                    self.conn.execute('delete from customer where id = ?', old_id)
                    print('removed old customer')
                else:
                    print('could not remove old customer')
            trans.commit()
        except:  # noqa: E722 -- roll back on any failure, then re-raise
            print('merge failed')
            trans.rollback()
            raise

    def test_raise_unique_constraint(self):
        """giving a second row of the most used membership the customer of
        the first row must raise"""
        m_id, count = self.conn.execute("select membership_id, count(membership_id) from customermembership group by membership_id order by count(membership_id) desc").fetchone()
        top2 = self.conn.execute('select top 2 id, customer_id from customermembership where membership_id = ?', m_id)
        first = top2.fetchone()
        second = top2.fetchone()
        sql = 'update customermembership set customer_id = :c_id where id = :id'
        self.assertRaises(Exception, self.conn.execute, sa.text(sql), c_id=first[1], id=second[0])

    def test_raise_unique_constraint_no_raised_in_multi_update_statement(self):
        """sends several update statements as one text block; nothing is
        asserted, the call is only expected to complete"""
        m_id, count = self.conn.execute("select membership_id, count(membership_id) from customermembership group by membership_id order by count(membership_id) desc").fetchone()
        top2 = self.conn.execute('select top 2 id, customer_id from customermembership where membership_id = ?', m_id)
        first = top2.fetchone()
        second = top2.fetchone()
        sql = [
            "update customer set firstname = 'foo3' where id = :first_c_id",
            'update customermembership set customer_id = :first_c_id where id = :first_id',
            #this line violate unique constraint,
            #but is ignored by sqlalchemy
            'update customermembership set customer_id = :second_c_id where id = :first_id',
            "update customer set middlename = 'mid2' where id = :first_c_id",
            "update customer set lastname = 'bar2' where id = :first_c_id",
        ]
        self.conn.execute(sa.text('\n'.join(sql)), first_c_id=first[1], first_id=first[0], second_c_id=second[1])
if __name__ == "__main__":
    # run the test cases when this file is executed as a script
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment