-
-
Save willvousden/c670d0a76a8fafad4df8431ef33c0fb1 to your computer and use it in GitHub Desktop.
Serializable transactions and retry in SQLAlchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3
import sys
import threading
from random import random

import psycopg2
from sqlalchemy import (create_engine, event, text,
                        Column, Integer,
                        ForeignKey)
from sqlalchemy import event
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import relationship
from sqlalchemy.orm import sessionmaker, scoped_session
# Which backend to exercise: 'sqlite' or 'postgresql'.
ENGINE = 'sqlite'

if ENGINE == 'sqlite':
    engine = create_engine('sqlite:///serialize.db',
                           isolation_level='SERIALIZABLE')

    # Refer to the SQLAlchemy SQLite dialect docs ("Serializable isolation /
    # Savepoints / Transactional DDL") for why both listeners are needed.
    @event.listens_for(engine, 'connect')
    def do_connect(dbapi_connection, connection_record):
        # Stop pysqlite from emitting its own lazy BEGIN (and its implicit
        # COMMIT before DDL) so that the 'begin' listener below is the only
        # thing that opens a transaction.
        dbapi_connection.isolation_level = None
        # SQLite leaves FOREIGN KEY enforcement off unless it is enabled on
        # each new connection.
        cursor = dbapi_connection.cursor()
        cursor.execute('PRAGMA foreign_keys=ON;')
        cursor.close()

    @event.listens_for(engine, 'begin')
    def do_begin(conn):
        # Emit BEGIN eagerly so that SELECTs also run inside the transaction;
        # left to itself pysqlite defers BEGIN until the first DML statement,
        # which breaks serializable isolation. text() keeps this working on
        # SQLAlchemy versions that no longer accept plain-string SQL.
        conn.execute(text('BEGIN;'))
elif ENGINE == 'postgresql':
    engine = create_engine('postgresql+psycopg2://user:pass@localhost/serialize',
                           isolation_level='SERIALIZABLE')
else:
    raise ValueError('Unknown engine')
# Factory for Sessions bound to the engine chosen above.
session_factory = sessionmaker(bind=engine)
# scoped_session keeps a registry of Sessions (thread-local by default), so
# every thread that touches `session` below works with its own Session.
session = scoped_session(session_factory=session_factory)
class _TableNameMixin(object):
    """Mixin that names each mapped table after its class."""

    @declared_attr
    def __tablename__(cls):
        # e.g. class Task -> table "task", class Schedule -> table "schedule".
        name = cls.__name__
        return name.lower()


# Declarative base shared by every model; inherits the table-naming rule.
Base = declarative_base(cls=_TableNameMixin)
class Task(Base):
    """A task row, identified only by an auto-incrementing integer key."""

    id = Column('id', Integer, autoincrement=True, primary_key=True)
class Node(Base):
    """A node row, identified only by an auto-incrementing integer key."""

    id = Column('id', Integer, autoincrement=True, primary_key=True)
class Schedule(Base):
    """Association row assigning one Task to one Node.

    (task_id, node_id) together form the primary key, so a given task can be
    scheduled on a given node at most once.
    """

    task_id = Column(Integer, ForeignKey(Task.id),
                     primary_key=True, nullable=False)
    task = relationship(Task)

    node_id = Column(Integer, ForeignKey(Node.id),
                     primary_key=True, nullable=False)
    node = relationship(Node)

    @classmethod
    def update(cls):
        """Throw away the current schedule and draw a new random one.

        Loads every Task and Node, bulk-deletes all existing Schedule rows,
        then stages a new row for each (task, node) pair independently with
        probability 0.1. The work is left on the session; this method does
        not commit.
        """
        all_tasks = session.query(Task).all()
        all_nodes = session.query(Node).all()
        session.query(cls).delete()
        fresh_rows = [cls(task=t, node=n)
                      for t in all_tasks
                      for n in all_nodes
                      if random() < 0.1]
        session.add_all(fresh_rows)
def init_db():
    """Rebuild the schema from scratch: drop every mapped table, then recreate it."""
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)
def is_retryable(orig):
    """Return True if DB-API exception ``orig`` signals a transient conflict.

    * PostgreSQL (psycopg2): the SQLSTATE is exposed as ``pgcode``; ``40001``
      is serialization_failure, raised when a SERIALIZABLE transaction
      cannot be committed consistently.
    * SQLite: a ``sqlite3.DatabaseError`` whose message is exactly
      ``'database is locked'``; another connection holds a conflicting lock.
    """
    if getattr(orig, 'pgcode', None) == '40001':
        return True
    return (isinstance(orig, sqlite3.DatabaseError) and
            orig.args == ('database is locked',))


def retry(fn, max_attempts=None):
    """Run ``fn`` and commit the session, re-running it on transient conflicts.

    ``fn`` is called with no arguments and is expected to stage its changes
    on the module-level ``session``. If ``fn()`` or the commit raises a
    ``DBAPIError`` whose underlying driver error satisfies
    :func:`is_retryable`, the session is rolled back and the whole unit of
    work is re-run from the start.

    :param fn: zero-argument callable performing one unit of work.
    :param max_attempts: maximum number of attempts in total; ``None`` (the
        default) retries indefinitely.
    :raises DBAPIError: immediately for a non-retryable database error, or
        the last retryable error once ``max_attempts`` is exhausted (after
        rolling the session back).
    """
    attempt = 0
    while True:
        attempt += 1
        try:
            fn()
            session.commit()
            return
        except DBAPIError as e:
            if not is_retryable(e.orig):
                raise
            if max_attempts is not None and attempt >= max_attempts:
                # Leave the session usable for the caller before giving up.
                session.rollback()
                raise
            sys.stderr.write('Retry...\n')
            session.rollback()
class TaskInserter(threading.Thread):
    """Worker thread that inserts ``iterations`` tasks, one unit of work each.

    Every unit of work adds a Task and regenerates the Schedule, and is run
    through ``retry`` so that a conflict with another thread re-runs it.
    """

    # Number of tasks to insert; override on a subclass or instance.
    iterations = 100

    def add_task(self):
        """Stage a new Task and a freshly randomised Schedule on the session."""
        new_task = Task()
        session.add(new_task)
        Schedule.update()

    def run(self):
        try:
            # range (not xrange) so this runs on both Python 2 and 3.
            for _ in range(self.iterations):
                retry(self.add_task)
        finally:
            # `session` is thread-scoped; discard this thread's Session so its
            # resources are released when the thread exits.
            session.remove()
class NodeInserter(threading.Thread):
    """Worker thread that inserts ``iterations`` nodes, one unit of work each.

    Every unit of work adds a Node and regenerates the Schedule, and is run
    through ``retry`` so that a conflict with another thread re-runs it.
    """

    # Number of nodes to insert; override on a subclass or instance.
    iterations = 100

    def add_node(self):
        """Stage a new Node and a freshly randomised Schedule on the session."""
        new_node = Node()
        session.add(new_node)
        Schedule.update()

    def run(self):
        try:
            # range (not xrange) so this runs on both Python 2 and 3.
            for _ in range(self.iterations):
                retry(self.add_node)
        finally:
            # `session` is thread-scoped; discard this thread's Session so its
            # resources are released when the thread exits.
            session.remove()
def main():
    """Reset the database, then run a TaskInserter and a NodeInserter concurrently."""
    init_db()
    workers = [TaskInserter(), NodeInserter()]
    for worker in workers:
        worker.start()
    # Block until both inserters have finished.
    for worker in workers:
        worker.join()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment