-
-
Save gsw945/3902a66161f26705bcaa5fbee2415cc6 to your computer and use it in GitHub Desktop.
Implementing all four SQL transaction isolation levels in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf8 -*- | |
from __future__ import print_function | |
class LockManager:
    """Registry of record locks, kept as [transaction, record_id] pairs.

    Lock holders are matched by object identity, record ids by equality.
    """

    def __init__(self):
        # Each entry is a two-element list: [transaction, record_id].
        self.locks = []

    def add(self, transaction, record_id):
        """Lock *record_id* for *transaction*; a duplicate request is a no-op."""
        if self.exists(transaction, record_id):
            return
        self.locks.append([transaction, record_id])

    def exists(self, transaction, record_id):
        """Return True if *transaction* itself holds a lock on *record_id*."""
        for holder, locked_id in self.locks:
            if holder is transaction and locked_id == record_id:
                return True
        return False
class Table:
    """In-memory table: record list plus transaction-id and lock bookkeeping."""

    def __init__(self):
        # Last transaction id handed out; the first new_transaction() gets 2.
        self.next_xid = 1
        # Ids of transactions that have started and not yet committed or
        # rolled back.
        self.active_xids = set()
        # Record dicts as built by Transaction.add_record.
        self.records = []
        self.locks = LockManager()

    def new_transaction(self, transaction_type):
        """Allocate the next xid, mark it active, and build *transaction_type*."""
        xid = self.next_xid + 1
        self.next_xid = xid
        self.active_xids.add(xid)
        return transaction_type(self, xid)
class RollbackException(Exception):
    """Raised when a write hits a row locked by another transaction."""
class Transaction:
    """Base MVCC transaction operating on a Table.

    Every record dict carries 'created_xid' (the inserting transaction) and
    'expired_xid' (the deleting transaction, or 0 while the row is live).
    Subclasses pick the isolation level by implementing record_is_visible()
    and record_is_locked().
    """

    def __init__(self, table, xid):
        self.table = table
        self.xid = xid
        # Undo log replayed newest-first by rollback():
        #   ["delete", index] undoes an insert, ["add", index] undoes a delete.
        self.rollback_actions = []

    def record_is_visible(self, record):
        """Return True if *record* is visible to this transaction."""
        raise NotImplementedError

    def record_is_locked(self, record):
        """Return True if this transaction must not modify *record*."""
        raise NotImplementedError

    def add_record(self, id, name):
        """Append a new live record created by this transaction."""
        record = {
            'id': id,
            'name': name,
            'created_xid': self.xid,
            'expired_xid': 0,
        }
        self.rollback_actions.append(["delete", len(self.table.records)])
        self.table.records.append(record)

    def delete_record(self, id):
        """Expire every visible record whose id equals *id*.

        Raises:
            RollbackException: a matching visible record is locked.
        """
        for i, record in enumerate(self.table.records):
            if self.record_is_visible(record) and record['id'] == id:
                if self.record_is_locked(record):
                    raise RollbackException("Row locked by another transaction.")
                record['expired_xid'] = self.xid
                self.rollback_actions.append(["add", i])

    def update_record(self, id, name):
        """Replace the record *id* with a new version carrying *name*."""
        self.delete_record(id)
        self.add_record(id, name)

    def fetch_record(self, id):
        """Return the first visible record whose id equals *id*, or None."""
        for record in self.table.records:
            # Value comparison: the previous `is` only matched ids that
            # happened to be the same (small, cached) int object.
            if self.record_is_visible(record) and record['id'] == id:
                return record
        return None

    def count_records(self, min_id, max_id):
        """Count visible records with min_id <= id <= max_id."""
        return sum(1 for record in self.table.records
                   if self.record_is_visible(record)
                   and min_id <= record['id'] <= max_id)

    def fetch_all_records(self):
        """Return every visible record, in table order."""
        return [record for record in self.table.records
                if self.record_is_visible(record)]

    def fetch(self, expr):
        """Return visible records for which the predicate *expr* is truthy."""
        return [record for record in self.table.records
                if self.record_is_visible(record) and expr(record)]

    def commit(self):
        """Leave the table's active set and release this transaction's locks."""
        self.table.active_xids.discard(self.xid)
        self._release_locks()

    def rollback(self):
        """Undo this transaction's changes newest-first, then finish it."""
        for action, index in reversed(self.rollback_actions):
            if action == 'add':
                # Undo a delete: the row becomes live again.
                self.table.records[index]['expired_xid'] = 0
            elif action == 'delete':
                # Undo an insert: expire the row under our own xid.
                self.table.records[index]['expired_xid'] = self.xid
        self.table.active_xids.discard(self.xid)
        self._release_locks()

    def _release_locks(self):
        """Drop every lock this transaction holds in the table's LockManager."""
        # LockManager exposes no removal API, so filter its list of
        # [transaction, record_id] pairs in place, matching holders by
        # identity as LockManager.exists does.
        held = self.table.locks.locks
        held[:] = [lock for lock in held if lock[0] is not self]
class ReadUncommittedTransaction(Transaction):
    """READ UNCOMMITTED: any non-expired row is visible, committed or not."""

    def record_is_visible(self, record):
        """Visible while no transaction has expired the row."""
        expired_by = record['expired_xid']
        return expired_by == 0

    def record_is_locked(self, record):
        """Locked once any transaction has expired the row."""
        expired_by = record['expired_xid']
        return expired_by != 0
class ReadCommittedTransaction(Transaction):
    """READ COMMITTED: sees committed rows plus this transaction's own writes."""

    def record_is_locked(self, record):
        """True if *record* was expired by a transaction that is still active."""
        # This previously read `row['expired_xid']`, an undefined name, so it
        # raised NameError whenever a visible row had been expired by another
        # active transaction.
        return record['expired_xid'] != 0 and \
            record['expired_xid'] in self.table.active_xids

    def record_is_visible(self, record):
        """Apply read-committed visibility to *record*."""
        created = record['created_xid']
        expired = record['expired_xid']
        # Created by another transaction that is still active (uncommitted).
        if created in self.table.active_xids and created != self.xid:
            return False
        # Expired, and the expiring transaction has either finished or is us.
        if expired != 0 and \
                (expired not in self.table.active_xids or expired == self.xid):
            return False
        return True
class RepeatableReadTransaction(ReadCommittedTransaction):
    """REPEATABLE READ: every row this transaction reads gets a lock for it."""

    def record_is_locked(self, record):
        """True if the row is expired by an active transaction or locked by another."""
        if ReadCommittedTransaction.record_is_locked(self, record):
            return True
        # Only locks held by *other* transactions block us. Asking
        # LockManager.exists(self, ...) always succeeded, because
        # record_is_visible() has just locked the row for self, so every
        # update or delete rolled back. LockManager has no "held by someone
        # else" query, so scan its [transaction, record_id] pairs directly.
        record_id = record['id']
        return any(holder is not self and locked_id == record_id
                   for holder, locked_id in self.table.locks.locks)

    def record_is_visible(self, record):
        """Read-committed visibility; lock each visible row for this transaction."""
        is_visible = ReadCommittedTransaction.record_is_visible(self, record)
        if is_visible:
            self.table.locks.add(self, record['id'])
        return is_visible
class SerializableTransaction(RepeatableReadTransaction):
    """SERIALIZABLE: repeatable read restricted to a start-time snapshot."""

    def __init__(self, table, xid):
        Transaction.__init__(self, table, xid)
        # Transactions active when this one began. Table.new_transaction adds
        # our own xid before constructing us, so it is included. Rows they
        # create stay hidden even after they commit.
        self.existing_xids = self.table.active_xids.copy()

    def record_is_visible(self, record):
        """Read-committed visibility, limited to rows in our snapshot; locks visible rows."""
        created = record['created_xid']
        # In the snapshot: our own rows, or rows from an older transaction
        # that had already finished when we began. The original test
        # (`created in existing_xids`) was inverted: it hid every row
        # committed before this transaction started.
        in_snapshot = created == self.xid or \
            (created < self.xid and created not in self.existing_xids)
        # NOTE(review): expiry still uses read-committed rules, so a delete
        # committed after our start hides the row; true snapshot isolation
        # would keep it visible.
        is_visible = ReadCommittedTransaction.record_is_visible(self, record) \
            and in_snapshot
        if is_visible:
            self.table.locks.add(self, record['id'])
        return is_visible
class TransactionTest:
    """Scenario base: a table seeded with rows 1 and 3, plus two fresh clients.

    Subclasses implement run(); a truthy result means the anomaly was seen.
    """

    def __init__(self, transaction_type):
        self.table = Table()
        seed = self.table.new_transaction(ReadCommittedTransaction)
        for record_id, name in ((1, "Joe"), (3, "Jill")):
            seed.add_record(id=record_id, name=name)
        seed.commit()
        self.client1 = self.table.new_transaction(transaction_type)
        self.client2 = self.table.new_transaction(transaction_type)

    def run_test(self):
        """Run the scenario; a RollbackException counts as a False outcome."""
        try:
            outcome = self.run()
        except RollbackException:
            outcome = False
        return outcome

    def result(self):
        """Return a check mark for a truthy outcome, a cross otherwise."""
        return u'✔' if self.run_test() else u'✘'
class DirtyRead(TransactionTest):
    """Does client1 observe client2's update before client2 commits?"""

    def run(self):
        before = self.client1.fetch_record(id=1)
        self.client2.update_record(id=1, name="Joe 2")
        after = self.client1.fetch_record(id=1)
        changed = before != after
        return changed
class NonRepeatableRead(TransactionTest):
    """Does re-reading a row in client1 reflect client2's committed update?"""

    def run(self):
        first_read = self.client1.fetch_record(id=1)
        writer = self.client2
        writer.update_record(id=1, name="Joe 2")
        writer.commit()
        second_read = self.client1.fetch_record(id=1)
        return first_read != second_read
class PhantomRead(TransactionTest):
    """Does a range count in client1 change after client2 commits an insert?"""

    def run(self):
        reader, writer = self.client1, self.client2
        # Both reads use count_records; it applies the same visibility check,
        # then the same 1..3 range test, as the fetch/len form.
        count_before = reader.count_records(min_id=1, max_id=3)
        writer.add_record(id=2, name="John")
        writer.commit()
        count_after = reader.count_records(min_id=1, max_id=3)
        return count_before != count_after
# Each isolation level paired with its transaction class; the labels are
# printed verbatim as the first column of the result table.
isolation_modes = [
    ['read uncommitted', ReadUncommittedTransaction],
    ['read committed ', ReadCommittedTransaction],
    ['repeatable read ', RepeatableReadTransaction],
    ['serializable ', SerializableTransaction],
]
possible_errors = [DirtyRead, NonRepeatableRead, PhantomRead]

print(' Dirty Repeat Phantom')
for label, transaction_type in isolation_modes:
    marks = [scenario(transaction_type).result() for scenario in possible_errors]
    print(" ".join([label] + marks))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment