Last active
March 1, 2023 06:36
-
-
Save huangsam/00c0deb23a755f320ef4b84c38caeedc to your computer and use it in GitHub Desktop.
Testing out SQLAlchemy 2.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, timezone | |
from typing import Any, Optional, Sequence, Type | |
from sqlalchemy import ( | |
DateTime, | |
Dialect, | |
ForeignKey, | |
Result, | |
Row, | |
String, | |
TypeDecorator, | |
UniqueConstraint, | |
create_engine, | |
event, | |
insert, | |
select, | |
) | |
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker | |
from sqlalchemy.sql.compiler import SQLCompiler | |
# One in-memory SQLite database is used by the whole script; the
# sessionmaker factory below hands out Session objects bound to it.
engine = create_engine("sqlite+pysqlite:///:memory:")
Session = sessionmaker(bind=engine)
def utc_now() -> datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    moment = datetime.now(timezone.utc)
    return moment
class UtcDateTime(TypeDecorator):
    """Timezone-aware datetime column stored as naive UTC.

    Values written to the database must be timezone-aware ``datetime``
    objects. They are converted to UTC and stripped of ``tzinfo`` before
    being handed to the underlying ``DateTime`` impl, which is declared
    without timezone support. Values read back are labelled as UTC again.
    """

    # https://github.com/spoqa/sqlalchemy-utc/issues/16
    # https://docs.sqlalchemy.org/en/20/core/custom_types.html
    impl = DateTime
    # This type carries no per-instance configuration that affects the
    # generated SQL, so statements using it are safe to cache.
    cache_ok = True

    @property
    def python_type(self) -> Type[Any]:
        """Python type object expected to be returned."""
        return datetime

    @staticmethod
    def _to_naive_utc(value: Any) -> datetime:
        """Validate *value* and convert it to a naive datetime expressed in UTC.

        Raises:
            ValueError: if *value* is not a ``datetime`` or is timezone-naive.
        """
        if not isinstance(value, datetime):
            raise ValueError(f"{value} is not a datetime instance")
        # A datetime is naive if it has no tzinfo OR its tzinfo yields no offset.
        if value.tzinfo is None or value.utcoffset() is None:
            raise ValueError(f"{value} is timezone-naive")
        # Normalise to UTC, then drop tzinfo: the impl column is timezone-naive,
        # and the UTC convention is restored in process_result_value.
        return value.astimezone(timezone.utc).replace(tzinfo=None)

    def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[datetime]:
        """Convert an aware datetime into the naive-UTC value stored in the column.

        ``None`` passes through unchanged (SQL NULL).

        Raises:
            ValueError: if *value* is not a datetime or is timezone-naive.
        """
        if value is None:
            return None
        return self._to_naive_utc(value)

    def process_result_value(self, value: Optional[datetime], dialect: Dialect) -> Optional[datetime]:
        """Return stored values as timezone-aware datetimes in UTC.

        Naive values are assumed to already be UTC and are labelled as such.
        Values that arrive already aware are converted to UTC rather than
        relabelled, so their instant is preserved.
        """
        if value is None:
            return None
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    def process_literal_param(self, value: Optional[datetime], dialect: Dialect) -> datetime:
        """Convert a value rendered inline during SQL compilation to naive UTC.

        Applies the same validation as ``process_bind_param``, so literal and
        bound values are rendered identically.

        Raises:
            ValueError: if *value* is None, not a datetime, or timezone-naive.
        """
        return self._to_naive_utc(value)
class Base(DeclarativeBase):
    """Declarative base class that every ORM model in this module derives from.

    See https://docs.sqlalchemy.org/en/20/orm/mapping_styles.html#orm-declarative-mapping
    """
class ToyProduct(Base):
    """A toy product published on a platform.

    The combination (product_name, platform_name) is unique.
    """

    __tablename__ = "toy_product"

    id: Mapped[int] = mapped_column(primary_key=True)
    # https://stackoverflow.com/questions/8295131/best-practices-for-sql-varchar-column-length
    product_name: Mapped[str] = mapped_column(String(255))
    team_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    platform_name: Mapped[str] = mapped_column(String(255))
    # Pass the function itself (not utc_now()) so a fresh timestamp is taken
    # for every INSERT/UPDATE instead of once, when this class is defined.
    created_time: Mapped[datetime] = mapped_column(UtcDateTime, default=utc_now)
    updated_time: Mapped[Optional[datetime]] = mapped_column(UtcDateTime, nullable=True, onupdate=utc_now)

    __table_args__ = (UniqueConstraint("product_name", "platform_name"),)

    def __str__(self):
        return f"'{self.product_name}' was created on {self.platform_name} at {self.created_time}"
class ToyProductEvent(Base):
    """An event row recorded against a ToyProduct (see product_id)."""

    # NOTE(review): the table name says "change" while the class says "Event";
    # kept as-is to avoid altering the existing schema.
    __tablename__ = "toy_product_change"

    id: Mapped[int] = mapped_column(primary_key=True)
    product_id: Mapped[int] = mapped_column(ForeignKey("toy_product.id"))
    # Pass the callable (not its result) so each row gets its own timestamp.
    created_time: Mapped[datetime] = mapped_column(UtcDateTime, default=utc_now)
def register_orm_listeners(session):
    """Attach demo ORM event handlers to *session*.

    - after_attach: prints each instance as it is attached to the session.
    - after_flush: for every ToyProduct in ``session.new`` adds a
      ToyProductEvent referencing it. ToyProductEvents seen in ``session.new``
      are reported as persisted.
    - before_commit: prints a marker line.
    """
    # https://docs.sqlalchemy.org/en/20/orm/events.html

    @event.listens_for(session, "after_attach")
    def process_after_attach(_session, instance):
        print(f"we ARE at after_attach with {instance!r}")

    @event.listens_for(session, "after_flush")
    def process_after_flush(_session, flush_context):
        print("we ARE at after_flush...")
        # Relies on ``new`` still listing the just-flushed objects inside
        # after_flush (pre-flush state), with their primary keys assigned;
        # see the SessionEvents.after_flush docs.
        # https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.Session.new
        for instance in _session.new:
            if isinstance(instance, ToyProduct):
                # Presumably written by a later flush (e.g. the one at commit).
                _session.add(ToyProductEvent(product_id=instance.id))
            elif isinstance(instance, ToyProductEvent):
                print("we GOT a new product event persisted")

    @event.listens_for(session, "before_commit")
    def process_before_commit(_session):
        print("we ARE at before_commit...")
def main(*, use_scalars: bool = True) -> None:
    """Run the SQLAlchemy 2.0 demo against the module-level engine.

    Recreates the schema, then inserts sample products twice: through a bulk
    ``insert(ToyProduct)`` and through ORM ``Session.add``. It then flushes,
    prints the compiled SELECT statement and prints the matching products.

    Args:
        use_scalars: When True (default), read results via ``Result.scalars()``;
            otherwise iterate the raw ``Row`` objects.
    """
    # Start every run from a freshly created schema.
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    # https://docs.sqlalchemy.org/en/20/orm/session_basics.html
    with Session.begin() as session:
        # Dynamically add event handlers for ORM events
        register_orm_listeners(session)

        # Execute a bulk insert operation for table entries - note
        # that these inserts do not trigger ORM session events like the
        # other DB model operations do
        # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html
        session.execute(
            insert(ToyProduct),
            [
                {
                    "product_name": "Tetris",
                    "team_name": "TTC",
                    "platform_name": "WEB",
                },
                {
                    "product_name": "Monopoly",
                    "team_name": "Hasbro",
                    "platform_name": "ANDROID",
                },
                {
                    "product_name": "Candy Land",
                    "team_name": "Hasbro",
                    "platform_name": "IOS",
                },
            ],
        )

        # Triggers attach events
        session.add(ToyProduct(product_name="Chess", team_name="JC", platform_name="IOS"))
        session.add(ToyProduct(product_name="Checkers", team_name="JC", platform_name="IOS"))

        # Triggers flush event
        session.flush()

        # Define a SELECT query with filters and sorting criteria
        stmt = (
            select(ToyProduct)
            .where(ToyProduct.platform_name.in_(["WEB", "IOS"]))
            .where(ToyProduct.created_time < utc_now())
            .order_by(ToyProduct.platform_name)
        )

        # Get the SQL query as a string
        # https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html#rendering-postcompile-parameters-as-bound-parameters
        compiled_query: SQLCompiler = stmt.compile(engine, compile_kwargs={"literal_binds": True})
        print(f"---\nSQL SELECT query to run:\n\n{compiled_query}\n---")

        # Execute the SELECT; Session.execute() hands back a Result object.
        result: Result = session.execute(stmt)

        if use_scalars:
            # ScalarResult yields the first column of each row, i.e. the
            # ToyProduct ORM objects themselves.
            product_seq: Sequence[ToyProduct] = result.scalars().all()
            for product in product_seq:
                print(product)
        else:
            # Row objects hold one entry per selected entity.
            product_rows: Sequence[Row] = result.all()
            for row in product_rows:
                # Row data is accessed with index or attribute name
                assert row[0] == row.ToyProduct, "Something is wrong with SQLAlchemy!"
                print(row.ToyProduct)
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Other useful resources: