Last active
March 27, 2021 22:05
-
-
Save huangsam/910381a34419d5c1f57b9e50e1851422 to your computer and use it in GitHub Desktop.
Play around with SQLAlchemy ORM and Core
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
from enum import Enum as RegularEnum | |
from enum import auto | |
from sqlalchemy import Column, DateTime | |
from sqlalchemy import Enum as DatabaseEnum | |
from sqlalchemy import ForeignKey, Integer, String, event | |
from sqlalchemy.engine import create_engine | |
from sqlalchemy.orm import declarative_base, relationship, sessionmaker | |
# Initialize DB settings
# Declarative base class that every mapped model in this script subclasses.
Base = declarative_base()
# "sqlite://" names no database file, so the engine targets in-memory SQLite.
engine = create_engine("sqlite://")
# One module-level session, built from a factory bound to that engine.
_session_factory = sessionmaker(bind=engine)
session = _session_factory()
class AuditAction(RegularEnum):
    """Kind of change recorded for a city in the audit log."""

    # auto() numbers the members 1, 2, 3 in declaration order.
    INSERT = auto()
    UPDATE = auto()
    DELETE = auto()
class Country(Base):
    """Mapped model for the ``country`` table; parent of ``State`` rows."""

    __tablename__ = "country"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Required and unique, at most 64 characters.
    name = Column(String(64), unique=True, nullable=False)

    # 1:N relationships
    # "all, delete-orphan" propagates session operations (including delete)
    # to the states and deletes a state that is removed from this collection.
    states = relationship(
        "State",
        back_populates="country",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        """Return a short debug string showing the id and name."""
        return f"<Country id={self.id}, name={self.name}>"
class State(Base):
    """Mapped model for the ``state`` table; child of ``Country``, parent of ``City``."""

    __tablename__ = "state"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Required reference to the owning row in ``country``.
    country_id = Column(ForeignKey("country.id"), nullable=False)
    # Required and unique, at most 64 characters.
    name = Column(String(64), unique=True, nullable=False)

    # 1:N relationships
    # Same cascade setting as Country.states: deletes flow down to the cities.
    cities = relationship(
        "City",
        back_populates="state",
        cascade="all, delete-orphan",
    )

    # N:1 relationships
    country = relationship("Country", back_populates="states")

    def __repr__(self):
        """Return a short debug string showing the id and name."""
        return f"<State id={self.id}, name={self.name}>"
class City(Base):
    """Mapped model for the ``city`` table; the leaf of the Country/State/City chain."""

    __tablename__ = "city"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Required reference to the owning row in ``state``.
    state_id = Column(ForeignKey("state.id"), nullable=False)
    # Required and unique, at most 64 characters.
    name = Column(String(64), unique=True, nullable=False)

    # N:1 relationships
    state = relationship("State", back_populates="cities")

    def __repr__(self):
        """Return a short debug string showing the id and name."""
        return f"<City id={self.id}, name={self.name}>"
class CityAuditLog(Base):
    """Mapped model for the ``city_audit_log`` table: one row per logged city action."""

    __tablename__ = "city_audit_log"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # City name stored as plain text (indexed); there is no foreign key to ``city``.
    name = Column(String(64), index=True, nullable=False)
    # Which AuditAction member this row records (indexed, required).
    action = Column(DatabaseEnum(AuditAction), index=True, nullable=False)
    # Filled by calling datetime.utcnow when no value is supplied.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12 — consider
    # a timezone-aware default if this script is run on newer interpreters.
    logged_time = Column(DateTime, default=datetime.utcnow)

    def __repr__(self):
        """Return a short debug string showing the id, name and action."""
        return f"<CityAuditLog id={self.id}, name={self.name}, action={self.action}>"
@event.listens_for(session, "before_flush")
def log_city_actions(db_session, _flush_context, _instances):
    """Log city actions on session events.

    Runs before every flush of the module-level session and adds one
    ``CityAuditLog`` row for each ``City`` that is about to be inserted,
    updated or deleted.

    Args:
        db_session: The session that is about to flush.
        _flush_context: Unused; part of the ``before_flush`` signature.
        _instances: Unused; part of the ``before_flush`` signature.
    """

    def _audit(objects, action):
        # Only City instances are audited; everything else is ignored.
        for obj in objects:
            if isinstance(obj, City):
                db_session.add(CityAuditLog(name=obj.name, action=action))

    _audit(db_session.new, AuditAction.INSERT)

    # Session.dirty is optimistic: it also lists objects whose attributes were
    # set without a net change. is_modified() keeps only the cities that
    # really changed, so no UPDATE row is logged for a no-op assignment.
    changed_cities = [
        obj
        for obj in db_session.dirty
        if isinstance(obj, City) and db_session.is_modified(obj)
    ]
    _audit(changed_cities, AuditAction.UPDATE)

    _audit(db_session.deleted, AuditAction.DELETE)
@event.listens_for(City.name, "set")
def print_city_set(target, value, oldvalue, _initiator):
    """Print a line each time ``City.name`` is assigned.

    Args:
        target: The City instance whose name is being set.
        value: The incoming name.
        oldvalue: The value being replaced.
        _initiator: Unused; part of the attribute "set" event signature.
    """
    message = f"[set] City {target.id} changed from {oldvalue} to {value}"
    print(message)
def main():
    """Run the cascade and event demo against the module-level session.

    Creates one country, two states and seven cities, then asserts that the
    audit log holds seven rows each for the insert, update and delete steps,
    and that the final rollback leaves only the seven committed insert rows.
    """
    ####
    # ORM cascades
    # https://docs.sqlalchemy.org/en/14/orm/cascades.html
    ####

    # Start from freshly created tables
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    # One country; the flush makes its id available for the states
    usa = Country(name="USA")
    session.add(usa)
    session.flush()

    # Two states that point at the country by id
    ca = State(name="California", country_id=usa.id)
    wa = State(name="Washington", country_id=usa.id)
    session.add_all([ca, wa])
    session.flush()

    # Four Californian cities followed by three in Washington
    ca_city_names = ("Sacramento", "San Francisco", "Los Angeles", "Napa")
    wa_city_names = ("Seattle", "Spokane", "Tacoma")
    cities = [City(name=city_name, state_id=ca.id) for city_name in ca_city_names]
    cities.extend(City(name=city_name, state_id=wa.id) for city_name in wa_city_names)
    session.add_all(cities)
    session.flush()

    # Reusable base query plus two small helpers for the checks below
    audit_logs = session.query(CityAuditLog)

    def row_count(query):
        return len(query.all())

    def logged(action):
        return audit_logs.filter(CityAuditLog.action == action)

    ####
    # ORM events
    # https://docs.sqlalchemy.org/en/14/orm/events.html
    ####

    # Commit the inserts: one INSERT audit row per city
    session.commit()
    assert row_count(session.query(Country)) == 1
    assert row_count(session.query(State)) == 2
    assert row_count(session.query(City)) == 7
    assert row_count(audit_logs) == 7
    assert row_count(logged(AuditAction.INSERT)) == 7

    # Rename every city. There is no explicit flush here, so the counts
    # below depend on the session flushing pending changes before it queries.
    for us_state in usa.states:
        for us_city in us_state.cities:
            us_city.name = us_city.name + " - Update"
    assert row_count(audit_logs) == 14
    assert row_count(logged(AuditAction.UPDATE)) == 7

    # Delete the country; the asserts show states and cities go with it
    session.delete(usa)
    session.flush()
    assert row_count(session.query(Country)) == 0
    assert row_count(session.query(State)) == 0
    assert row_count(session.query(City)) == 0
    assert row_count(audit_logs) == 21
    assert row_count(logged(AuditAction.DELETE)) == 7

    # Show every audit row collected so far
    for entry in audit_logs.all():
        print(entry)

    # Undo the uncommitted updates and deletes; only the inserts remain
    session.rollback()
    assert row_count(audit_logs) == 7
    assert row_count(logged(AuditAction.INSERT)) == 7
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum as RegularEnum | |
from enum import auto | |
from sqlalchemy import Column | |
from sqlalchemy import Enum as DatabaseEnum | |
from sqlalchemy import ForeignKey, Integer, String | |
from sqlalchemy.engine import create_engine | |
from sqlalchemy.orm import declarative_base, relationship, sessionmaker | |
from sqlalchemy.sql.expression import and_, select, text | |
# Initialize DB settings
# Declarative base class that every mapped model in this script subclasses.
Base = declarative_base()
# "sqlite://" names no database file, so the engine targets in-memory SQLite.
engine = create_engine("sqlite://")
# One module-level session, built from a factory bound to that engine.
_session_factory = sessionmaker(bind=engine)
session = _session_factory()
class Gender(RegularEnum):
    """Closed set of gender values stored on a person."""

    # auto() numbers the members 1 and 2 in declaration order.
    MALE = auto()
    FEMALE = auto()
class Person(Base):
    """Mapped model for the ``person`` table; parent of ``Address`` rows."""

    __tablename__ = "person"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Required and unique, at most 64 characters.
    name = Column(String(64), unique=True, nullable=False)
    # Required and indexed.
    age = Column(Integer, index=True, nullable=False)
    # Optional; uses a database enum type named "gender_types".
    gender = Column(DatabaseEnum(Gender, name="gender_types"))

    # 1:N relationships
    # "all, delete-orphan" propagates session operations (including delete)
    # to the addresses and deletes one that is removed from this collection.
    addresses = relationship(
        "Address",
        back_populates="person",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        """Return a short debug string showing the id, name and age."""
        return f"<Person id={self.id} name={self.name} age={self.age}>"
class Address(Base):
    """Mapped model for the ``address`` table; each row belongs to one ``Person``."""

    __tablename__ = "address"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Required reference to the owning row in ``person``.
    person_id = Column(ForeignKey("person.id"), nullable=False)
    # Required, at most 256 characters.
    street = Column(String(256), nullable=False)
    # Required and indexed, at most 8 characters.
    state = Column(String(8), index=True, nullable=False)
    # Required and indexed; kept as text, at most 16 characters.
    zip_code = Column(String(16), index=True, nullable=False)

    # N:1 relationships
    person = relationship("Person", back_populates="addresses")

    def __repr__(self):
        """Return a short debug string showing the id, street, state and zip code."""
        return f"<Address id={self.id} street={self.street} state={self.state} zip_code={self.zip_code}>"
def _orm_demo():
    """Exercise flush, rollback, commit and cascade delete through the ORM session."""
    ####
    # ORM demo
    # https://docs.sqlalchemy.org/en/14/orm/tutorial.html
    ####

    # Configure clean tables with metadata
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    # Create John and Jane from scratch
    person_john = Person(name="John", age=30, gender=Gender.MALE)
    person_jane = Person(name="Jane", age=30, gender=Gender.FEMALE)
    session.add_all([person_john, person_jane])
    assert person_john.id is None
    assert person_jane.id is None

    # Flush John and Jane to get real IDs for them
    session.flush()
    assert person_john.id is not None
    assert person_jane.id is not None
    assert len(session.query(Person).all()) == 2

    # Rollback results in cleanup
    session.rollback()
    assert len(session.query(Person).all()) == 0

    # NOTE: OLD John and Jane still have OLD IDs
    assert person_john.id is not None
    assert person_jane.id is not None

    # Commit John and Jane
    session.add_all([person_jane, person_john])
    session.commit()

    # Rollback after a commit results in nothing
    session.rollback()
    assert len(session.query(Person).all()) == 2

    # Flush Bob and Mary
    person_bob = Person(name="Bob", age=25, gender=Gender.MALE)
    person_mary = Person(name="Mary", age=25, gender=Gender.FEMALE)
    session.add_all([person_bob, person_mary])
    session.flush()

    # Flush addresses
    address_bob = Address(
        person_id=person_bob.id,
        street="1234 Fiction Avenue",
        state="CA",
        zip_code="12345",
    )
    address_mary = Address(
        person_id=person_mary.id,
        street="1235 Fiction Avenue",
        state="CA",
        zip_code="12345",
    )
    session.add_all([address_bob, address_mary])
    session.flush()

    # First address rollback results in cleanup
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2
    session.rollback()
    assert len(session.query(Person).all()) == 2
    assert len(session.query(Address).all()) == 0

    # Second address rollback results in nothing
    session.rollback()
    assert len(session.query(Person).all()) == 2
    assert len(session.query(Address).all()) == 0

    # Add data back in again
    session.add_all([person_bob, person_mary])
    session.flush()
    session.add_all([address_bob, address_mary])
    session.commit()

    # Rollback after a commit results in nothing
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2
    session.rollback()
    assert len(session.query(Person).all()) == 4
    assert len(session.query(Address).all()) == 2

    # Add a couple more addresses to Bob and Mary
    extra_addresses = (
        (person_bob, "1236 Fiction Ave"),
        (person_bob, "1237 Fiction Ave"),
        (person_mary, "1238 Fiction Ave"),
        (person_mary, "1239 Fiction Ave"),
    )
    session.add_all(
        [
            Address(person_id=owner.id, street=street, state="CA", zip_code="12345")
            for owner, street in extra_addresses
        ]
    )
    session.commit()

    # Bob and Mary do have places
    assert len(person_bob.addresses) == 3
    assert len(person_mary.addresses) == 3

    # John and Jane do NOT have places
    assert len(person_john.addresses) == 0
    assert len(person_jane.addresses) == 0

    # Delete Bob and his addresses with cascade functionality
    session.delete(person_bob)
    session.commit()
    bob_or_none = session.query(Person).filter(Person.name == "Bob").one_or_none()
    all_addresses = session.query(Address).all()
    assert bob_or_none is None
    assert len(all_addresses) == 3
    assert not any(address.person.name == "Bob" for address in all_addresses)

    # NOTE: OLD Bob still refers to the OLD addresses
    assert len(person_bob.addresses) == 3

    # Delete one address from Mary
    mary_address_first = person_mary.addresses[0]
    session.delete(mary_address_first)
    session.commit()
    assert mary_address_first not in person_mary.addresses

    # NOTE: OLD Mary address still has an ID
    assert mary_address_first.id is not None


def _core_demo():
    """Print the rows left by the ORM demo using Core ``select()`` and ``text()`` queries."""
    ####
    # Core demo
    # https://docs.sqlalchemy.org/en/14/core/tutorial.html
    ####

    # Select tables for upcoming queries
    person_table = Person.__table__
    address_table = Address.__table__

    # The context manager closes the connection even if a query raises;
    # the previous bare engine.connect() was never closed.
    with engine.connect() as conn:
        assert conn.closed is False

        # Select all person columns
        s = select(person_table)
        for row in conn.execute(s):
            print(f"[all] {row}")

        # Select three address columns
        s = select(address_table.c.street, address_table.c.state, address_table.c.zip_code)
        for row in conn.execute(s):
            print(f"[pick] {row}")

        # Select inner join two tables with multiple constraints
        c = and_(
            person_table.c.id == address_table.c.person_id,
            person_table.c.age > 20,
            person_table.c.gender == Gender.FEMALE,
            person_table.c.name == "Mary",
        )
        s = select(person_table.c.name, address_table.c.street, address_table.c.zip_code).where(c)
        for row in conn.execute(s):
            print(f"[and] {row}")

        # Select with text; the gender is bound by enum member name
        s = text(
            """
            SELECT person.name FROM person
            WHERE person.age > :person_age
            AND person.gender = :person_gender
            """
        ).bindparams(person_age=25, person_gender=Gender.MALE.name)
        for row in conn.execute(s):
            print(f"[text] {row}")


def main():
    """Run the ORM demo, then the Core demo over the rows it leaves behind."""
    _orm_demo()
    _core_demo()
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Helpful links:
https://docs.sqlalchemy.org/en/13/orm/session_basics.html
https://docs.sqlalchemy.org/en/13/orm/session_state_management.html
https://docs.sqlalchemy.org/en/13/orm/events.html