Skip to content

Instantly share code, notes, and snippets.

@khayrov
Last active January 26, 2022 02:17
Show Gist options
  • Save khayrov/6291557 to your computer and use it in GitHub Desktop.
Serializable transactions and retry in SQLAlchemy
import sqlite3
import sys
import threading
from random import random

import psycopg2
from sqlalchemy import (create_engine, event,
                        Column, Integer,
                        ForeignKey, text)
from sqlalchemy import event
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.exc import DBAPIError
# Which backend the demo runs against: 'sqlite' or 'postgresql'.
ENGINE = 'sqlite'

if ENGINE == 'sqlite':
    engine = create_engine('sqlite:///serialize.db',
                           isolation_level='SERIALIZABLE')

    # Enable SQLite foreign-key enforcement (PRAGMA foreign_keys) on every
    # new DBAPI connection, so the ForeignKey columns declared below are
    # checked.
    event.listen(engine, 'connect',
                 lambda dbapi_conn, conn_record:
                 dbapi_conn.execute('PRAGMA foreign_keys=ON;'))

    # Refer to SQLAlchemy docs for explanation:
    # http://docs.sqlalchemy.org/en/rel_0_8/dialects/sqlite.html#serializable-transaction-isolation
    # Per those docs, pysqlite defers BEGIN until the first DML statement, so
    # the SELECTs at the start of a unit of work would run outside the
    # transaction. Emitting BEGIN ourselves at the start of every SQLAlchemy
    # transaction puts the whole unit of work in one transaction.
    @event.listens_for(engine, 'begin')
    def do_begin(conn):
        # Wrapped in text(): SQLAlchemy 2.0 no longer accepts plain SQL
        # strings in Connection.execute(). text() also works on older
        # versions.
        conn.execute(text('BEGIN'))
elif ENGINE == 'postgresql':
    engine = create_engine('postgresql+psycopg2://user:pass@localhost/serialize',
                           isolation_level='SERIALIZABLE')
else:
    raise ValueError('Unknown engine')
# Factory for Sessions bound to the engine configured above. By default
# scoped_session keeps one Session per thread, so the two worker threads
# that use the module-level ``session`` proxy do not share a Session.
session_factory = sessionmaker(engine)
session = scoped_session(session_factory=session_factory)
class _TablenameFromClassName(object):
    """Declarative mixin that names each mapped table after its class.

    The class name is lower-cased, e.g. ``Task`` maps to table ``task``
    and ``Schedule`` to ``schedule``.
    """

    @declared_attr
    def __tablename__(cls):
        class_name = cls.__name__
        return class_name.lower()


# Declarative base for all models in this script; it inherits the
# table-naming rule from the mixin above.
Base = declarative_base(cls=_TablenameFromClassName)
class Task(Base):
    """Mapped class for table ``task``; it has only an auto-incrementing key."""

    id = Column(Integer, autoincrement=True, primary_key=True)
class Node(Base):
    """Mapped class for table ``node``; it has only an auto-incrementing key."""

    id = Column(Integer, autoincrement=True, primary_key=True)
class Schedule(Base):
    """Association row linking one ``Task`` to one ``Node``.

    ``(task_id, node_id)`` is a composite primary key, so each pair
    appears at most once.
    """

    task_id = Column(Integer, ForeignKey(Task.id), primary_key=True,
                     nullable=False)
    task = relationship(Task)
    node_id = Column(Integer, ForeignKey(Node.id), primary_key=True,
                     nullable=False)
    node = relationship(Node)

    @classmethod
    def update(cls):
        """Replace the whole schedule with a fresh random assignment.

        Loads every task and node, bulk-deletes all existing schedule rows,
        then links each (task, node) pair independently with probability
        0.1. Changes are only staged on the module-level ``session``;
        committing is up to the caller.
        """
        all_tasks = session.query(Task).all()
        all_nodes = session.query(Node).all()
        session.query(cls).delete()
        # Task-major, node-minor order, one random() draw per pair.
        pairs = ((t, n) for t in all_tasks for n in all_nodes)
        for t, n in pairs:
            if random() < 0.1:
                session.add(cls(task=t, node=n))
def init_db():
    """Recreate every table registered on ``Base``: drop all, then create all."""
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)
def retry(fn, max_attempts=None):
    """Run ``fn()`` and commit the scoped ``session``, retrying on conflicts.

    ``fn`` is a zero-argument callable that stages changes on the
    module-level ``session``; this function commits them. If the call or
    the commit fails with a transient concurrency error, the session is
    rolled back and the whole unit of work (``fn()`` then commit) runs
    again. Errors treated as transient:

    * PostgreSQL: the driver exception's ``pgcode`` equals ``'40001'``.
    * SQLite: a ``sqlite3.DatabaseError`` whose args are exactly
      ``('database is locked',)``.

    :param fn: callable performing one unit of work.
    :param max_attempts: optional cap on the total number of attempts. The
        default ``None`` retries indefinitely, as before.
    :raises DBAPIError: any non-transient database error, or the last
        transient one once ``max_attempts`` is used up. The session is
        rolled back before the error propagates, so it stays usable.
    """
    attempt = 0
    while True:
        attempt += 1
        try:
            fn()
            session.commit()
        except DBAPIError as e:
            out_of_attempts = (max_attempts is not None
                               and attempt >= max_attempts)
            if out_of_attempts or not _is_transient_conflict(e.orig):
                # Don't leave the session stuck in a failed transaction;
                # then re-raise the original error.
                session.rollback()
                raise
            sys.stderr.write('Retry...\n')
            session.rollback()
        else:
            return


def _is_transient_conflict(orig):
    """Return True if the raw DBAPI exception ``orig`` should be retried."""
    # psycopg2 exceptions expose the SQLSTATE as ``pgcode``; 40001 is
    # PostgreSQL's serialization_failure.
    if getattr(orig, 'pgcode', None) == '40001':
        return True
    return (isinstance(orig, sqlite3.DatabaseError)
            and orig.args == ('database is locked',))
class TaskInserter(threading.Thread):
    """Worker thread that attempts to insert 100 tasks, one transaction each.

    Each transaction adds a ``Task`` and regenerates the schedule via
    ``Schedule.update``. ``retry`` commits it and re-runs it on conflicts.
    """

    def add_task(self):
        """Stage a new ``Task`` and rebuild the schedule; ``retry`` commits."""
        new_task = Task()
        session.add(new_task)
        Schedule.update()

    def run(self):
        try:
            # range() instead of xrange(): works on Python 2 and 3.
            for _ in range(100):
                retry(self.add_task)
        finally:
            # Dispose of this thread's scoped Session (closing it and
            # releasing its connection), even if retry() gave up with an error.
            session.remove()
class NodeInserter(threading.Thread):
    """Worker thread that attempts to insert 100 nodes, one transaction each.

    Each transaction adds a ``Node`` and regenerates the schedule via
    ``Schedule.update``. ``retry`` commits it and re-runs it on conflicts.
    """

    def add_node(self):
        """Stage a new ``Node`` and rebuild the schedule; ``retry`` commits."""
        new_node = Node()
        session.add(new_node)
        Schedule.update()

    def run(self):
        try:
            # range() instead of xrange(): works on Python 2 and 3.
            for _ in range(100):
                retry(self.add_node)
        finally:
            # Dispose of this thread's scoped Session (closing it and
            # releasing its connection), even if retry() gave up with an error.
            session.remove()
def main():
    """Reset the schema, then run a task inserter and a node inserter concurrently."""
    init_db()
    workers = [TaskInserter(), NodeInserter()]
    for worker in workers:
        worker.start()
    # Block until both workers have finished.
    for worker in workers:
        worker.join()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment