Skip to content

Instantly share code, notes, and snippets.

@huangsam
Last active March 1, 2023 06:36
Show Gist options
  • Save huangsam/00c0deb23a755f320ef4b84c38caeedc to your computer and use it in GitHub Desktop.
Save huangsam/00c0deb23a755f320ef4b84c38caeedc to your computer and use it in GitHub Desktop.
Testing out SQLAlchemy 2.0
from datetime import datetime, timezone
from typing import Any, Optional, Sequence, Type
from sqlalchemy import (
DateTime,
Dialect,
ForeignKey,
Result,
Row,
String,
TypeDecorator,
UniqueConstraint,
create_engine,
event,
insert,
select,
)
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
from sqlalchemy.sql.compiler import SQLCompiler
# In-memory SQLite database reached through the pysqlite driver.
engine = create_engine("sqlite+pysqlite:///:memory:")
# Session factory whose sessions are bound to the engine above.
Session = sessionmaker(bind=engine)
def utc_now():
    """Return the current moment as a timezone-aware datetime in UTC."""
    current = datetime.now(timezone.utc)
    return current
class UtcDateTime(TypeDecorator):
    """DateTime column type that only accepts and returns UTC-aware datetimes.

    Bound and literal values must be timezone-aware ``datetime`` objects;
    they are converted to UTC before being handed to the underlying
    ``DateTime`` type. Values read back are returned with UTC attached.
    """

    # https://github.com/spoqa/sqlalchemy-utc/issues/16
    # https://docs.sqlalchemy.org/en/20/core/custom_types.html
    impl = DateTime
    # The type holds no per-instance state, so statements using it can be
    # cached safely.
    cache_ok = True

    @property
    def python_type(self) -> Type[Any]:
        """Python type object expected to be returned."""
        return datetime

    @staticmethod
    def _coerce_utc(value: Any) -> datetime:
        """Validate *value* and return it converted to UTC.

        Raises:
            ValueError: if *value* is not a ``datetime`` or is timezone-naive.
        """
        if not isinstance(value, datetime):
            raise ValueError(f"{value} is not a datetime instance")
        # utcoffset() is None both when tzinfo is missing and when a tzinfo
        # object declines to report an offset; either way the value is naive.
        if value.utcoffset() is None:
            raise ValueError(f"{value} is timezone-naive")
        return value.astimezone(tz=timezone.utc)

    def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[datetime]:
        """Process input data to be a datetime with UTC timezone.

        ``None`` is passed through unchanged so that NULLs can be stored.

        Raises:
            ValueError: if *value* is not a ``datetime`` or is timezone-naive.
        """
        if value is None:
            return None
        return self._coerce_utc(value)

    def process_result_value(self, value: Optional[datetime], dialect: Dialect) -> Optional[datetime]:
        """Process output data to be a datetime with UTC timezone."""
        if value is None:
            return None
        if value.utcoffset() is None:
            # Naive values are taken to be UTC, which is how they were bound.
            return value.replace(tzinfo=timezone.utc)
        # An already-aware value must be converted, not relabelled, so the
        # instant it represents is preserved.
        return value.astimezone(tz=timezone.utc)

    def process_literal_param(self, value: Optional[datetime], dialect: Dialect) -> datetime:
        """Process literal data during SQL compilation of a statement.

        Unlike bound parameters, ``None`` is rejected here.

        Raises:
            ValueError: if *value* is ``None``, not a ``datetime``, or is
                timezone-naive.
        """
        # None fails the isinstance check in the helper and raises the same
        # "is not a datetime instance" error as any other non-datetime.
        return self._coerce_utc(value)
class Base(DeclarativeBase):
    """Declarative base class that the ORM models in this module inherit from."""

    # https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#orm-declarative-mapping
class ToyProduct(Base):
    """ORM model for a product, unique per (product_name, platform_name)."""

    __tablename__ = "toy_product"

    id: Mapped[int] = mapped_column(primary_key=True)
    # https://stackoverflow.com/questions/8295131/best-practices-for-sql-varchar-column-length
    product_name: Mapped[str] = mapped_column(String(255))
    # Nullable column, so the Python-side type is Optional.
    team_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    platform_name: Mapped[str] = mapped_column(String(255))
    # Pass the callable itself: it is then invoked for each INSERT. Calling
    # utc_now() here would freeze a single timestamp at import time.
    created_time: Mapped[datetime] = mapped_column(UtcDateTime, default=utc_now)
    # Same reasoning for onupdate: evaluated for each UPDATE, not once.
    updated_time: Mapped[Optional[datetime]] = mapped_column(UtcDateTime, nullable=True, onupdate=utc_now)

    __table_args__ = (UniqueConstraint("product_name", "platform_name"),)

    def __str__(self) -> str:
        """Return a one-line, human-readable summary of the product."""
        return f"'{self.product_name}' was created on {self.platform_name} at {self.created_time}"
class ToyProductEvent(Base):
    """ORM model for an event row that references a ``toy_product`` row."""

    __tablename__ = "toy_product_change"

    id: Mapped[int] = mapped_column(primary_key=True)
    product_id: Mapped[int] = mapped_column(ForeignKey("toy_product.id"))
    # Pass the callable itself so the timestamp is computed for each INSERT;
    # calling utc_now() here would freeze one value at import time.
    created_time: Mapped[datetime] = mapped_column(UtcDateTime, default=utc_now)
def register_orm_listeners(session):
    """Register demo listeners for ORM session events on *session*.

    Three listeners are attached:

    * ``after_attach`` prints the attached instance.
    * ``after_flush`` adds a ``ToyProductEvent`` for each ``ToyProduct``
      found in the session's ``new`` collection, and prints a message for
      each ``ToyProductEvent`` found there.
    * ``before_commit`` prints a marker line.

    Args:
        session: The event target passed to ``event.listens_for``.
    """
    # https://docs.sqlalchemy.org/en/20/orm/events.html

    @event.listens_for(session, "after_attach")
    def process_after_attach(_session, instance):
        print(f"we ARE at after_attach with {instance!r}")

    @event.listens_for(session, "after_flush")
    def process_after_flush(active_session, flush_context):
        print("we ARE at after_flush...")
        # Work on the session the event hands in rather than the outer
        # registration target, so the listener acts on the session that
        # actually flushed.
        # https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session.new
        for instance in active_session.new:
            if isinstance(instance, ToyProduct):
                active_session.add(ToyProductEvent(product_id=instance.id))
            elif isinstance(instance, ToyProductEvent):
                print("we GOT a new product event persisted")

    @event.listens_for(session, "before_commit")
    def process_before_commit(_session):
        print("we ARE at before_commit...")
def main(*, use_scalars: bool = True):
    """Run the demo: rebuild the schema, insert products, query and print them.

    Everything runs inside one ``Session.begin()`` block.

    Args:
        use_scalars: When true (the default), read the query result through
            ``result.scalars()``; otherwise iterate the ``Row`` objects and
            pull the ``ToyProduct`` entity out of each row.
    """
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    # https://docs.sqlalchemy.org/en/20/orm/session_basics.html
    with Session.begin() as session:
        # Dynamically add event handlers for ORM events
        register_orm_listeners(session)

        # Execute a bulk insert operation for table entries - note
        # that these inserts do not trigger ORM session events like the
        # other DB model operations do
        # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html
        session.execute(
            insert(ToyProduct),
            [
                {
                    "product_name": "Tetris",
                    "team_name": "TTC",
                    "platform_name": "WEB",
                },
                {
                    "product_name": "Monopoly",
                    "team_name": "Hasbro",
                    "platform_name": "ANDROID",
                },
                {
                    "product_name": "Candy Land",
                    "team_name": "Hasbro",
                    "platform_name": "IOS",
                },
            ],
        )

        # Triggers attach events
        session.add(ToyProduct(product_name="Chess", team_name="JC", platform_name="IOS"))
        session.add(ToyProduct(product_name="Checkers", team_name="JC", platform_name="IOS"))

        # Triggers flush event
        session.flush()

        # Define a SELECT query with filters and sorting criteria
        stmt = (
            select(ToyProduct)
            .where(ToyProduct.platform_name.in_(["WEB", "IOS"]))
            .where(ToyProduct.created_time < utc_now())
            .order_by(ToyProduct.platform_name)
        )

        # Get the SQL query as a string
        # https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html#rendering-postcompile-parameters-as-bound-parameters
        compiled_query: SQLCompiler = stmt.compile(engine, compile_kwargs={"literal_binds": True})
        print(f"---\nSQL SELECT query to run:\n\n{compiled_query}\n---")

        # Execute a sample SELECT query
        result: Result = session.execute(stmt)

        if use_scalars:
            # Note that serialized objects can be queried via ScalarResult
            product_seq: Sequence[ToyProduct] = result.scalars().all()
            for product in product_seq:
                print(product)
        else:
            # Note that row objects can be queried via Result
            product_rows: Sequence[Row] = result.all()
            for row in product_rows:
                # Row data is accessed with index or attribute name
                assert row[0] == row.ToyProduct, "Something is wrong with SQLAlchemy!"
                product: ToyProduct = row.ToyProduct
                print(product)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()