Skip to content

Instantly share code, notes, and snippets.

@huangsam
Last active March 27, 2021 22:05
Show Gist options
  • Save huangsam/910381a34419d5c1f57b9e50e1851422 to your computer and use it in GitHub Desktop.
Save huangsam/910381a34419d5c1f57b9e50e1851422 to your computer and use it in GitHub Desktop.
Play around with SQLAlchemy ORM and Core
from datetime import datetime, timezone
from enum import Enum as RegularEnum
from enum import auto

from sqlalchemy import Column, DateTime
from sqlalchemy import Enum as DatabaseEnum
from sqlalchemy import ForeignKey, Integer, String, event
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
# Initialize DB settings: the declarative base every model below inherits
# from, an in-memory SQLite engine ("sqlite://" has no database path), and a
# single module-level session bound to that engine and shared by the code
# in this script.
Base = declarative_base()
engine = create_engine("sqlite://")
session = sessionmaker(engine)()
class AuditAction(RegularEnum):
    """Kind of change recorded in a ``CityAuditLog`` row."""

    INSERT = auto()  # a City was added
    UPDATE = auto()  # a City was modified
    DELETE = auto()  # a City was removed
class Country(Base):
    """A country, identified by a unique name, that owns its states."""

    __tablename__ = "country"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(64), unique=True, nullable=False)

    # 1:N -- the "all, delete-orphan" cascade means deleting a Country also
    # deletes its States, and a State removed from this collection is deleted.
    states = relationship(
        "State",
        back_populates="country",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return "<Country id={}, name={}>".format(self.id, self.name)
class State(Base):
    """A state that belongs to one country and owns its cities."""

    __tablename__ = "state"

    id = Column(Integer, primary_key=True, autoincrement=True)
    country_id = Column(ForeignKey("country.id"), nullable=False)
    name = Column(String(64), unique=True, nullable=False)

    # 1:N -- cities are deleted together with their state, or when they are
    # removed from this collection ("all, delete-orphan").
    cities = relationship(
        "City",
        back_populates="state",
        cascade="all, delete-orphan",
    )
    # N:1 -- inverse side of Country.states.
    country = relationship("Country", back_populates="states")

    def __repr__(self):
        return "<State id={}, name={}>".format(self.id, self.name)
class City(Base):
    """A city that belongs to exactly one state."""

    __tablename__ = "city"

    id = Column(Integer, primary_key=True, autoincrement=True)
    state_id = Column(ForeignKey("state.id"), nullable=False)
    name = Column(String(64), unique=True, nullable=False)

    # N:1 -- inverse side of State.cities.
    state = relationship("State", back_populates="cities")

    def __repr__(self):
        return "<City id={}, name={}>".format(self.id, self.name)
class CityAuditLog(Base):
    """One audit entry describing an insert, update or delete of a City.

    Rows are queued by this module's ``before_flush`` session listener.
    """

    __tablename__ = "city_audit_log"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # City name at the time the action was logged. It is a plain string with no
    # foreign key, so audit entries survive deletion of the City they describe.
    name = Column(String(64), index=True, nullable=False)
    action = Column(DatabaseEnum(AuditAction), index=True, nullable=False)
    # Naive UTC timestamp, the same value datetime.utcnow() produced.
    # utcnow() is deprecated since Python 3.12, so strip tzinfo from an
    # aware "now" instead.
    logged_time = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    def __repr__(self):
        return f"<CityAuditLog id={self.id}, name={self.name}, action={self.action}>"
@event.listens_for(session, "before_flush")
def log_city_actions(db_session, _flush_context, _instances):
    """Log city actions on session events.

    Runs before each flush of the module-level ``session``. It queues one
    ``CityAuditLog`` row for every ``City`` about to be inserted, updated
    (with net changes) or deleted. The log rows are added to the same session,
    so they are written by the flush that triggered this hook and are undone
    by the same rollback.

    Args:
        db_session: The session being flushed.
        _flush_context: Unused.
        _instances: Unused.
    """
    for obj in db_session.new:
        if isinstance(obj, City):
            db_session.add(CityAuditLog(name=obj.name, action=AuditAction.INSERT))
    for obj in db_session.dirty:
        # Session.dirty is optimistic: any attribute assignment marks an object
        # dirty, even one that sets the same value. Only log an UPDATE when
        # is_modified() reports net changes, so no-op assignments produce no
        # audit row.
        if isinstance(obj, City) and db_session.is_modified(obj):
            db_session.add(CityAuditLog(name=obj.name, action=AuditAction.UPDATE))
    for obj in db_session.deleted:
        if isinstance(obj, City):
            db_session.add(CityAuditLog(name=obj.name, action=AuditAction.DELETE))
@event.listens_for(City.name, "set")
def print_city_set(target, value, oldvalue, _initiator):
    """Echo every assignment to ``City.name`` on stdout.

    ``target`` is the City being changed, ``value`` the new name and
    ``oldvalue`` the name being replaced.

    NOTE(review): on a City's first assignment (e.g. in the constructor),
    ``oldvalue`` is presumably a SQLAlchemy placeholder symbol rather than a
    name, and ``target.id`` is not yet assigned. Confirm this against the
    AttributeEvents.set documentation.
    """
    city_id = target.id
    print(f"[set] City {city_id} changed from {oldvalue} to {value}")
def main():
    """Exercise ORM cascades and ORM events against the in-memory database.

    Builds one Country -> two States -> seven Cities. It then commits, renames
    every city, deletes the country via cascade, and finally rolls back.
    After each step it asserts how many rows exist and how many
    ``CityAuditLog`` entries the ``before_flush`` listener has produced.
    Raises ``AssertionError`` if any expectation fails.
    """
    ####
    # ORM cascades
    # https://docs.sqlalchemy.org/en/14/orm/cascades.html
    ####
    # Configure clean tables with metadata
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    # Create one country
    united_states = Country(name="USA")
    session.add(united_states)
    # Flush (not commit) so the INSERT runs and united_states.id is populated
    # for the foreign keys below, while all work stays in one transaction.
    session.flush()
    # Create states
    california = State(name="California", country_id=united_states.id)
    washington = State(name="Washington", country_id=united_states.id)
    session.add_all([california, washington])
    session.flush()
    # Create cities
    california_cities = [
        City(name="Sacramento", state_id=california.id),
        City(name="San Francisco", state_id=california.id),
        City(name="Los Angeles", state_id=california.id),
        City(name="Napa", state_id=california.id),
    ]
    washington_cities = [
        City(name="Seattle", state_id=washington.id),
        City(name="Spokane", state_id=washington.id),
        City(name="Tacoma", state_id=washington.id),
    ]
    session.add_all([*california_cities, *washington_cities])
    # This flush fires before_flush, which queues one INSERT audit row per city.
    session.flush()
    # Save query instance for developer convenience
    audit_log_query = session.query(CityAuditLog)
    ####
    # ORM events
    # https://docs.sqlalchemy.org/en/14/orm/events.html
    ####
    # Commit everything
    session.commit()
    assert len(session.query(Country).all()) == 1
    assert len(session.query(State).all()) == 2
    assert len(session.query(City).all()) == 7
    assert len(audit_log_query.all()) == 7
    assert len(audit_log_query.filter(CityAuditLog.action == AuditAction.INSERT).all()) == 7
    # Update everything
    for state in united_states.states:
        for city in state.cities:
            city.name += " - Update"
    # No explicit flush here: the query relies on the session autoflushing the
    # renames first (sessionmaker's default), which fires before_flush and
    # adds one UPDATE entry per city (7 + 7 = 14).
    assert len(audit_log_query.all()) == 14
    assert len(audit_log_query.filter(CityAuditLog.action == AuditAction.UPDATE).all()) == 7
    # Delete everything via cascade
    # Deleting the Country cascades to its States and their Cities through the
    # "all, delete-orphan" relationships, so every City gets a DELETE entry.
    session.delete(united_states)
    session.flush()
    assert len(session.query(Country).all()) == 0
    assert len(session.query(State).all()) == 0
    assert len(session.query(City).all()) == 0
    assert len(audit_log_query.all()) == 21
    assert len(audit_log_query.filter(CityAuditLog.action == AuditAction.DELETE).all()) == 7
    # Print all audit log entries
    for log in audit_log_query.all():
        print(log)
    # Rollback the updates and deletes
    session.rollback()
    # Only the committed INSERT entries survive; the UPDATE/DELETE entries were
    # flushed in the rolled-back transaction and are gone with it.
    assert len(audit_log_query.all()) == 7
    assert len(audit_log_query.filter(CityAuditLog.action == AuditAction.INSERT).all()) == 7


if __name__ == "__main__":
    main()
from enum import Enum as RegularEnum
from enum import auto
from sqlalchemy import Column
from sqlalchemy import Enum as DatabaseEnum
from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.engine import create_engine
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.sql.expression import and_, select, text
# Initialize DB settings: the declarative base every model below inherits
# from, an in-memory SQLite engine ("sqlite://" has no database path), and a
# single module-level session bound to that engine and shared by main().
Base = declarative_base()
engine = create_engine("sqlite://")
session = sessionmaker(engine)()
class Gender(RegularEnum):
    """Values allowed in the ``Person.gender`` column."""

    MALE = auto()  # first member, value 1
    FEMALE = auto()  # second member, value 2
class Person(Base):
    """A person with a unique name, an age and an optional gender."""

    __tablename__ = "person"

    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(String(64), unique=True, nullable=False)
    age = Column(Integer, index=True, nullable=False)
    # SQL-level enum type named "gender_types". No nullable= is given, so the
    # column accepts NULL.
    gender = Column(DatabaseEnum(Gender, name="gender_types"))

    # 1:N -- addresses are deleted together with their person, or when they
    # are removed from this collection ("all, delete-orphan").
    addresses = relationship(
        "Address",
        back_populates="person",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return "<Person id={} name={} age={}>".format(self.id, self.name, self.age)
class Address(Base):
    """A street address that belongs to exactly one person."""

    __tablename__ = "address"

    id = Column(Integer, primary_key=True, autoincrement=True)
    person_id = Column(ForeignKey("person.id"), nullable=False)
    street = Column(String(256), nullable=False)
    state = Column(String(8), index=True, nullable=False)
    zip_code = Column(String(16), index=True, nullable=False)

    # N:1 -- inverse side of Person.addresses.
    person = relationship("Person", back_populates="addresses")

    def __repr__(self):
        return "<Address id={} street={} state={} zip_code={}>".format(
            self.id, self.street, self.state, self.zip_code
        )
def _run_orm_demo():
    """Walk through flush/commit/rollback and cascade behaviour with the ORM.

    Leaves John, Jane and Mary (with two of her addresses) committed.
    Raises ``AssertionError`` if any expectation fails.
    """
    ####
    # ORM demo
    # https://docs.sqlalchemy.org/en/14/orm/tutorial.html
    ####
    # Configure clean tables with metadata
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    # Create John and Jane from scratch
    person_john = Person(name="John", age=30, gender=Gender.MALE)
    person_jane = Person(name="Jane", age=30, gender=Gender.FEMALE)
    session.add_all([person_john, person_jane])
    assert person_john.id is None
    assert person_jane.id is None
    # Flush John and Jane to get real IDs for them
    session.flush()
    assert person_john.id is not None
    assert person_jane.id is not None
    assert len(session.query(Person).all()) == 2
    # Rollback results in cleanup
    session.rollback()
    assert len(session.query(Person).all()) == 0
    # NOTE: OLD John and Jane still have OLD IDs
    assert person_john.id is not None
    assert person_jane.id is not None
    # Commit John and Jane
    session.add_all([person_jane, person_john])
    session.commit()
    # Rollback results in nothing
    session.rollback()
    assert len(session.query(Person).all()) == 2
    # Flush Bob and Mary
    person_bob = Person(name="Bob", age=25, gender=Gender.MALE)
    person_mary = Person(name="Mary", age=25, gender=Gender.FEMALE)
    session.add_all([person_bob, person_mary])
    session.flush()
    # Flush addresses
    address_bob = Address(
        person_id=person_bob.id,
        street="1234 Fiction Avenue",
        state="CA",
        zip_code="12345",
    )
    address_mary = Address(
        person_id=person_mary.id,
        street="1235 Fiction Avenue",
        state="CA",
        zip_code="12345",
    )
    session.add_all([address_bob, address_mary])
    session.flush()
    # First address rollback results in cleanup
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2
    session.rollback()
    assert len(session.query(Person).all()) == 2
    assert len(session.query(Address).all()) == 0
    # Second address rollback results in nothing
    session.rollback()
    assert len(session.query(Person).all()) == 2
    assert len(session.query(Address).all()) == 0
    # Add data back in again
    session.add_all([person_bob, person_mary])
    session.flush()
    session.add_all([address_bob, address_mary])
    session.commit()
    # Rollback results in nothing
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2
    session.rollback()
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2
    # Add a couple more addresses to Bob and Mary
    session.add_all(
        [
            Address(
                person_id=person_bob.id,
                street="1236 Fiction Ave",
                state="CA",
                zip_code="12345",
            ),
            Address(
                person_id=person_bob.id,
                street="1237 Fiction Ave",
                state="CA",
                zip_code="12345",
            ),
            Address(
                person_id=person_mary.id,
                street="1238 Fiction Ave",
                state="CA",
                zip_code="12345",
            ),
            Address(
                person_id=person_mary.id,
                street="1239 Fiction Ave",
                state="CA",
                zip_code="12345",
            ),
        ]
    )
    session.commit()
    # Bob and Mary do have places
    assert len(person_bob.addresses) == 3
    assert len(person_mary.addresses) == 3
    # John and Jane do NOT have places
    assert len(person_john.addresses) == 0
    assert len(person_jane.addresses) == 0
    # Delete Bob and his addresses with cascade functionality
    session.delete(person_bob)
    session.commit()
    bob_or_none = session.query(Person).filter(Person.name == "Bob").one_or_none()
    all_addresses = session.query(Address).all()
    assert bob_or_none is None
    assert len(all_addresses) == 3
    assert not any(address.person.name == "Bob" for address in all_addresses)
    # NOTE: OLD Bob still refers to the OLD addresses
    assert len(person_bob.addresses) == 3
    # Delete one address from Mary
    mary_address_first = person_mary.addresses[0]
    session.delete(mary_address_first)
    session.commit()
    assert mary_address_first not in person_mary.addresses
    # NOTE: OLD Mary address still has an ID
    assert mary_address_first.id is not None


def _run_core_demo():
    """Run Core SELECTs (whole table, picked columns, join, raw text) and print rows.

    The connection is opened in a ``with`` block, so it is closed and handed
    back to the engine when the demo ends. The original code never closed it.
    """
    ####
    # Core demo
    # https://docs.sqlalchemy.org/en/14/core/tutorial.html
    ####
    with engine.connect() as conn:
        assert conn.closed is False
        # Select tables for upcoming queries
        person_table = Person.__table__
        address_table = Address.__table__
        # Select all person columns
        s = select(person_table)
        for row in conn.execute(s):
            print(f"[all] {row}")
        # Select three address columns
        s = select(address_table.c.street, address_table.c.state, address_table.c.zip_code)
        for row in conn.execute(s):
            print(f"[pick] {row}")
        # Select inner join two tables with multiple constraints
        c = and_(
            person_table.c.id == address_table.c.person_id,
            person_table.c.age > 20,
            person_table.c.gender == Gender.FEMALE,
            person_table.c.name == "Mary",
        )
        s = select(person_table.c.name, address_table.c.street, address_table.c.zip_code).where(c)
        for row in conn.execute(s):
            print(f"[and] {row}")
        # Select with text. The raw SQL compares against the stored enum
        # *name*, hence Gender.MALE.name rather than the enum member.
        s = text(
            """
            SELECT person.name FROM person
            WHERE person.age > :person_age
            AND person.gender = :person_gender
            """
        ).bindparams(person_age=25, person_gender=Gender.MALE.name)
        for row in conn.execute(s):
            print(f"[text] {row}")
    # Leaving the with-block closed the connection.
    assert conn.closed is True


def main():
    """Run the ORM demo, then the Core demo, against the in-memory database.

    The Core demo reads the rows the ORM demo leaves committed.
    Raises ``AssertionError`` if any expectation in either demo fails.
    """
    _run_orm_demo()
    _run_core_demo()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment