Last active
February 15, 2024 21:44
-
-
Save paulwinex/6e9c53750774a233701adef4edd1d6dd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table | |
from sqlalchemy.orm import sessionmaker, relationship, backref, declared_attr, Mapped | |
from sqlalchemy.orm import declarative_base, Mapped, mapped_column | |
from sqlalchemy.orm.exc import NoResultFound | |
from sqlalchemy.sql.expression import and_ | |
from sqlalchemy import select, func, union_all, text | |
import asyncio | |
from sqlalchemy.ext.asyncio import ( | |
AsyncSession, | |
create_async_engine, | |
async_sessionmaker, | |
async_scoped_session, | |
) | |
# Declarative base shared by every ORM model in this module; its ``metadata``
# is what ``create_tables`` materialises.
BaseModel = declarative_base()

# In-memory SQLite database reached through the async ``aiosqlite`` driver.
DATABASE_URL: str = "sqlite+aiosqlite:///:memory:"
class TaggableMixin:
    """Mixin that gives a model a many-to-many ``tags`` collection.

    Links are stored in the shared ``tag_links`` table (``TagLinks``).  Rows for
    different models are told apart by ``TagLinks.model_type``, which must hold
    the owning model's ``__tablename__``.  The mixing class must define ``id``.
    """

    @declared_attr
    def tags(cls) -> Mapped[list["TagModel"]]:
        def _owner_to_links():
            # Link rows that point at this class's row.
            return and_(
                cls.id == TagLinks.model_id,
                TagLinks.model_type == cls.__tablename__,
            )

        def _links_to_tags():
            # From those link rows on to the referenced tag.
            return and_(
                TagModel.id == TagLinks.tag_id,
                TagLinks.model_type == cls.__tablename__,
            )

        # The join builders stay callables so TagLinks/TagModel are resolved
        # lazily, at mapper configuration time.
        return relationship(
            "TagModel",
            secondary=TagLinks.__table__,
            primaryjoin=_owner_to_links,
            secondaryjoin=_links_to_tags,
            # Every taggable model maps a "tags" relationship over the same
            # link table; declare that overlap as intentional.
            overlaps="tags",
            lazy="selectin",
            uselist=True,
        )
class TagModel(BaseModel):
    """A node in the tag hierarchy.

    ``path`` is a materialised path of ancestor ids ending in this tag's own id,
    e.g. ``"/1/4/9/"`` for tag 9 whose parent is 4 whose parent is 1.  It stays
    ``None`` until ``update_path`` has been called.
    """

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(unique=True, nullable=False)
    # NULL for root tags.
    parent_id: Mapped[int | None] = mapped_column(ForeignKey("tags.id"), nullable=True)
    path: Mapped[str | None] = mapped_column(nullable=True)

    # Direct children; the ``parent`` backref is the opposite side
    # (``remote_side=[id]`` marks ``id`` as the referenced column).
    children: Mapped[list["TagModel"]] = relationship(
        "TagModel",
        backref=backref("parent", remote_side=[id]),
        lazy="selectin",
        uselist=True,
    )

    async def update_path(self, session):
        """Compute and assign ``self.path`` from the parent's stored path.

        The tag must already be flushed so ``self.id`` is set.  For a non-root
        tag, the parent must exist and already have a path.  Only this tag is
        updated; descendants' paths are not recomputed.

        :raises ValueError: if ``self.id`` is unset, the parent row is missing,
            or the parent has no path yet.
        """
        if self.id is None:
            raise ValueError(f"Tag {self.name!r} has no id yet; flush before update_path()")
        if self.parent_id is None:
            self.path = f"/{self.id}/"
            return
        parent = await session.get(TagModel, self.parent_id)
        if parent is None:
            raise ValueError(f"Parent tag {self.parent_id} of {self.name!r} does not exist")
        if parent.path is None:
            # Building on a missing path would silently yield "None<id>/".
            raise ValueError(f"Parent tag {parent.name!r} has no path yet")
        self.path = f"{parent.path}{self.id}/"

    def __str__(self):
        return f"{self.name} [{self.id}]"

    def __repr__(self):
        return f"<{self.name} {self.path}>"
class TagLinks(BaseModel):
    """Association row linking one tag to one row of some taggable model.

    ``model_id`` has no foreign key because it can point into different
    tables.  ``model_type`` (the target model's ``__tablename__``) says which one.
    """

    __tablename__ = "tag_links"

    id: Mapped[int] = mapped_column(primary_key=True)
    tag_id: Mapped[int] = mapped_column(ForeignKey("tags.id"))
    model_id: Mapped[int]
    model_type: Mapped[str] = mapped_column(String(50))

    def __str__(self):
        # Identify the tagged object as "<table>.<row id>", i.e. model_id, not
        # this link row's own primary key.
        return f"{self.model_type}.{self.model_id} > TagModel.{self.tag_id}"

    def __repr__(self):
        return f"<TagLinks {self.__str__()}>"
class AssetModel(TaggableMixin, BaseModel):
    """Taggable asset with a unique, indexed ``name`` and an indexed ``label``."""

    __tablename__ = "asset"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
    label: Mapped[str] = mapped_column(nullable=False, index=True)

    def __str__(self):
        return "{} [{}]".format(self.name, self.id)

    def __repr__(self):
        return f"<{self}>"
class ProjectModel(TaggableMixin, BaseModel):
    """Taggable project; both ``name`` (indexed) and ``label`` are unique."""

    __tablename__ = "projects"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(index=True, unique=True, nullable=False)
    label: Mapped[str] = mapped_column(unique=True, nullable=False)

    def __str__(self):
        return "{} [{}]".format(self.name, self.id)

    def __repr__(self):
        return f"<{self}>"
# One async engine for the whole module; set echo=True to log emitted SQL.
# (The 1.x transition flags future=True / autocommit=False are omitted: under
# SQLAlchemy 2.0, which this module requires, they are fixed defaults.)
engine = create_async_engine(url=DATABASE_URL, echo=False)

# Factory for AsyncSession objects bound to ``engine``.
# expire_on_commit=False keeps attribute values readable after commit(); the
# helpers below return objects that were just committed and callers read them.
session_factory = async_sessionmaker(
    bind=engine,
    autoflush=False,
    expire_on_commit=False,
)
async def create_tables() -> None:
    """Create every table registered on ``BaseModel.metadata``."""
    create_all = BaseModel.metadata.create_all
    async with engine.begin() as connection:
        # create_all is a synchronous API; run_sync executes it on the async connection.
        await connection.run_sync(create_all)
async def add_tag(session, name, parent_id=None):
    """Create a tag under ``parent_id`` (or as a root), set its path and commit.

    Returns the new ``TagModel``.
    """
    new_tag = TagModel(name=name, parent_id=parent_id)
    session.add(new_tag)
    # The path embeds the tag's own id, so the row must be flushed first.
    await session.flush()
    await new_tag.update_path(session)
    await session.commit()
    return new_tag
async def add_object_with_tags(session, model_class, name, label, *tags):
    """Create a taggable object, link it to ``tags`` and commit.

    :param model_class: a ``TaggableMixin`` model whose constructor accepts
        ``name`` and ``label`` (e.g. ``AssetModel`` or ``ProjectModel``).
    :param tags: ``TagModel`` instances that already have ids.  A tag passed
        more than once is linked only once.
    :returns: the new instance, refreshed so its ``tags`` collection is loaded.
    """
    obj = model_class(name=name, label=label)
    session.add(obj)
    await session.flush()  # assigns obj.id, which the link rows reference
    model_type = model_class.__tablename__
    # dict.fromkeys de-duplicates while keeping the caller's order.
    for tag_id in dict.fromkeys(tag.id for tag in tags):
        session.add(TagLinks(tag_id=tag_id, model_id=obj.id, model_type=model_type))
    await session.commit()
    # Load ``tags`` now (as add_tag_to_object does).  Touching an unloaded
    # relationship later would need implicit I/O, which AsyncSession disallows.
    await session.refresh(obj)
    return obj
async def get_objects_by_tag(session, model_class, tag_name):
    """Return ``model_class`` rows linked directly to the tag named ``tag_name``.

    Descendant tags are not considered; ``find_objects_by_tag_and_descendants``
    covers those.  Returns an empty list when no tag has that name.  Each
    object appears once, even if it is linked to the tag more than once.
    """
    # Tag names are unique, so filtering the joined tag by name is equivalent
    # to looking the tag up first and filtering by its id.
    objects_query = (
        select(model_class)
        .join(model_class.tags)
        .where(TagModel.name == tag_name)
        .distinct()
    )
    objects_result = await session.execute(objects_query)
    return list(objects_result.scalars().all())
async def find_objects_by_tag_and_descendants(session, model_class, tag_name):
    """Return ``model_class`` rows tagged with ``tag_name`` or any tag below it.

    Descendants are found through ``TagModel.path``: a descendant's path starts
    with its ancestor's path.  The trailing ``/`` keeps ``"/1/"`` from matching
    ``"/10/"``.  Paths contain only ids and ``/`` (see ``TagModel.update_path``),
    so no LIKE escaping is needed.

    Each object is returned once, even when it carries several matching tags.
    Returns an empty list when no tag has that name.
    """
    tag_result = await session.execute(select(TagModel).where(TagModel.name == tag_name))
    tag = tag_result.scalar()
    if tag is None:
        return []
    if tag.path is None:
        # Path never computed: the tag can still match itself.
        subtree_filter = TagModel.id == tag.id
    else:
        subtree_filter = TagModel.path.like(f"{tag.path}%")
    objects_query = (
        select(model_class).join(model_class.tags).where(subtree_filter).distinct()
    )
    objects_result = await session.execute(objects_query)
    return list(objects_result.scalars().all())
async def get_all_tags(session):
    """Return every ``TagModel`` row, in no particular order."""
    return (await session.execute(select(TagModel))).scalars().all()
async def show_tags(session, model_class=AssetModel):
    """Print the tag tree with the objects found under each tag.

    Tags are printed in materialised-path order, so each subtree appears
    contiguously below its root.  Each tag gets one ``-`` per level; a root
    tag gets one dash.  The objects listed for a tag are the ``model_class``
    rows tagged with it or with any of its descendants.

    :param model_class: taggable model to list; defaults to ``AssetModel``.
    """
    tags = await get_all_tags(session)
    # A tag without a path sorts first instead of making the sort raise TypeError.
    rows = sorted(((tag.path or "", tag.name) for tag in tags), key=lambda row: row[0])
    print("Paths:")
    for path, name in rows:
        objects = await find_objects_by_tag_and_descendants(session, model_class, name)
        depth = max(path.count("/") - 1, 0)
        print("-" * depth, name, path, ">", objects)
class TagNotFoundError(LookupError):
    """Raised when a tag looked up by name does not exist."""


async def add_tag_to_object(session, tag_name: str | TagModel, instance):
    """Link a tag to an already-persisted taggable ``instance`` and commit.

    :param tag_name: a ``TagModel`` or the name of an existing tag.
    :param instance: a ``TaggableMixin`` model instance that already has an ``id``.
    :returns: ``instance``, refreshed so ``instance.tags`` includes the tag.
    :raises TagNotFoundError: if ``tag_name`` is a name that matches no tag.
    """
    if isinstance(tag_name, TagModel):
        tag = tag_name
    else:
        tag_result = await session.execute(select(TagModel).where(TagModel.name == tag_name))
        tag = tag_result.scalar()
        if tag is None:
            raise TagNotFoundError(tag_name)
    model_type = type(instance).__tablename__
    # Skip the insert when the link already exists; otherwise the tag would
    # be attached twice and the object would show up twice in tag lookups.
    existing = await session.execute(
        select(TagLinks.id).where(
            TagLinks.tag_id == tag.id,
            TagLinks.model_id == instance.id,
            TagLinks.model_type == model_type,
        )
    )
    if existing.first() is None:
        session.add(TagLinks(tag_id=tag.id, model_id=instance.id, model_type=model_type))
    await session.commit()
    await session.refresh(instance)
    return instance
async def testing1():
    """Demo: build a small tag tree, tag a project and some assets, print lookups."""
    await create_tables()
    async with session_factory() as session:
        root = await add_tag(session, "root")
        t1 = await add_tag(session, "tag1", parent_id=root.id)
        await add_tag(session, "tag1-1", parent_id=t1.id)
        t1_2 = await add_tag(session, "tag1-2", parent_id=t1.id)
        t2 = await add_tag(session, "tag2", parent_id=root.id)
        # NOTE(review): "tag2-1" is placed under "tag1-2", not "tag2" — confirm intent.
        t2_1 = await add_tag(session, "tag2-1", parent_id=t1_2.id)

        await add_object_with_tags(session, ProjectModel, "project1", "Project 1", root)
        for obj_name, obj_label, obj_tag in (
            ("asset1", "Asset 1", t1),
            ("asset2", "Asset 2", t2),
            ("asset21", "Asset 21", t2_1),
            ("asset12", "Asset 12", t1_2),
        ):
            await add_object_with_tags(session, AssetModel, obj_name, obj_label, obj_tag)

        # Fetch objects by tag (exact tag only, descendants not included).
        # Each row: (model, label printed, tag queried, tag name printed).
        direct_lookups = (
            (ProjectModel, "ProjectModels", "root", "root"),
            (AssetModel, "AssetModels", "tag1", "tag1"),
            (AssetModel, "AssetModels", "tag2", "tag2"),
            (AssetModel, "AssetModels", "tag2-1", "tag2_1"),
            (AssetModel, "AssetModels", "tag2", "tag2"),
            (AssetModel, "AssetModels", "tag1-2", "tag1-2"),
        )
        for model_class, plural, queried, shown in direct_lookups:
            found = await get_objects_by_tag(session, model_class, queried)
            print(f"{plural} with tag '{shown}': {found}")

        subtree = await find_objects_by_tag_and_descendants(session, AssetModel, "tag1")
        print("Hi objects for tag1", subtree)
        await show_tags(session)
async def testing2():
    """Demo: a four-level tag chain with one asset on the deepest tag."""
    await create_tables()
    async with session_factory() as session:
        # Build the chain root -> tag1 -> tag2 -> tag3.
        root = await add_tag(session, "root")
        tag1 = await add_tag(session, "tag1", parent_id=root.id)
        tag2 = await add_tag(session, "tag2", parent_id=tag1.id)
        tag3 = await add_tag(session, "tag3", parent_id=tag2.id)
        await add_object_with_tags(session, AssetModel, "asset1", "Asset 1", tag3)
        objects_tag3 = await get_objects_by_tag(session, AssetModel, "tag3")
        # The lookup is over AssetModel, so the label says AssetModels.
        print(f"AssetModels with tag 'tag3': {objects_tag3}")
async def testing3():
    """Demo: direct vs. subtree tag lookups for assets and projects, then re-tagging."""
    async with session_factory() as session:
        await create_tables()

        async def report(model_class, plural):
            # Objects with exactly "Tag1", then those under "Tag1" or any descendant.
            direct = await get_objects_by_tag(session, model_class, "Tag1")
            print(f"{plural} by Tag1:", direct)
            subtree = await find_objects_by_tag_and_descendants(session, model_class, "Tag1")
            print(f"{plural} by Tag1 and its descendants:", subtree)

        # Create the chain Tag1 -> Tag2 -> Tag3.
        first = await add_tag(session, "Tag1")
        second = await add_tag(session, "Tag2", first.id)
        third = await add_tag(session, "Tag3", second.id)

        # Create assets with tags.
        await add_object_with_tags(session, AssetModel, "Asset1", "Label1", first, second)
        deep_asset = await add_object_with_tags(session, AssetModel, "Asset2", "Label2", third)
        await report(AssetModel, "Assets")

        # Create a project with tags.
        await add_object_with_tags(
            session, ProjectModel, "Project1", "Project Label1", first, third
        )
        await report(ProjectModel, "Projects")

        print("-" * 10)
        reloaded = (
            await session.execute(select(AssetModel).where(AssetModel.id == deep_asset.id))
        ).scalar()
        print("Asset Tags 1", reloaded.tags)
        reloaded = await add_tag_to_object(session, first, reloaded)
        print("Asset Tags 2", reloaded.tags)
        await show_tags(session)
if __name__ == "__main__":
    # Runs the first demo scenario; testing2() and testing3() are alternative
    # scenarios that can be run here instead.
    asyncio.run(testing1())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment