Created
April 5, 2020 11:40
-
-
Save jokull/64d31c0a5630e8a26416502b9816d81d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from enum import Enum as PyEnum | |
from collections import namedtuple | |
from fastapi import Depends, FastAPI, HTTPException | |
from starlette.requests import Request | |
from starlette.templating import Jinja2Templates | |
from sqlalchemy import ( | |
create_engine, | |
Column, | |
ForeignKey, | |
Integer, | |
String, | |
DateTime, | |
Enum, | |
func, | |
) | |
from sqlalchemy.orm import relationship, sessionmaker, Session | |
from sqlalchemy.ext.declarative import declarative_base | |
# Database configuration. The connection URL comes from the environment so the
# same code can target different databases without edits.
SQLALCHEMY_DATABASE_URL = os.environ.get("DATABASE_URL")
if not SQLALCHEMY_DATABASE_URL:
    # Fail fast with an actionable message; passing None to create_engine
    # fails inside SQLAlchemy with an error that never mentions DATABASE_URL.
    raise RuntimeError("The DATABASE_URL environment variable must be set")
engine = create_engine(SQLALCHEMY_DATABASE_URL)
# Declarative base class shared by every ORM model in this module.
Base = declarative_base()
class Sport(Base):
    """A sport; leagues reference it through ``League.sport_id``."""

    __tablename__ = "sports"

    # Integer primary key (additionally flagged for indexing).
    id = Column(Integer, index=True, primary_key=True)
    # Name of the sport.
    name = Column(String)
class League(Base):
    """A league, optionally attached to a ``Sport``."""

    __tablename__ = "leagues"

    id = Column(Integer, index=True, primary_key=True)
    name = Column(String)
    # Nullable: a league may exist without an associated sport.
    sport_id = Column(Integer, ForeignKey("sports.id"), nullable=True)
    # ORM-level access to the referenced Sport row.
    sport = relationship("Sport")
class Match(Base):
    """A single match belonging to a ``League``."""

    __tablename__ = "matches"

    id = Column(Integer, index=True, primary_key=True)
    name = Column(String)
    # Start time of the match; nullable (no nullable=False given).
    start = Column(DateTime)
    league_id = Column(Integer, ForeignKey("leagues.id"))
    # ORM-level access to the referenced League row.
    league = relationship("League")
# Immutable (slug, label) pair used as the value of the event-type enum members.
EnumValue = namedtuple("EnumValue", "slug label")
class EventTypeEnum(PyEnum):
    """Categories an ``Event`` can belong to.

    Each member's value is an ``EnumValue`` holding a short slug and a
    display label.
    """

    point = EnumValue(slug="point", label="Point Scored")
    penalty = EnumValue(slug="penalty", label="Penalty")
    foul = EnumValue(slug="foul", label="Foul")
class Event(Base):
    """Something that happened during a ``Match``, optionally categorised."""

    __tablename__ = "events"

    id = Column(Integer, index=True, primary_key=True)
    # Event category; NULL means "uncategorised" (counted via NoneEnum).
    type = Column(Enum(EventTypeEnum), nullable=True)
    match_id = Column(Integer, ForeignKey("matches.id"), nullable=True)
    # ORM-level access to the referenced Match row.
    match = relationship("Match")
    # Optional remarks string.
    remarks = Column(String)
class MatchWrap:
    """Thin wrapper around a ``Match`` that adds per-category event counts.

    Attribute lookups the wrapper cannot satisfy itself are forwarded to the
    wrapped ``_match`` object. The wrapper therefore behaves like the match
    (``name``, ``start``, ``league``, ...) and also offers ``counts``,
    ``year`` and ``event_count``.
    """

    def __init__(self, match, counts, enums=None):
        """
        Args:
            match: the wrapped match object.
            counts: per-category event counts, in the same order as ``enums``.
            enums: the categories ``counts`` refers to; defaults to the
                module-level ``ENUMS`` list.
        """
        if enums is None:
            enums = ENUMS
        self._match = match
        # Map each category (enum member or NoneEnum) to its count.
        self.counts = dict(zip(enums, counts))
        # ``start`` may be NULL; don't crash on matches without a start time.
        start = match.start
        self.year = start.year if start is not None else None

    def __getattr__(self, key):
        # Only called when normal lookup on the wrapper fails, so the
        # wrapper's own attributes never touch the wrapped object.
        if key == "_match":
            # ``_match`` is missing only when __init__ has not run (e.g.
            # during copy/pickle); raising here prevents infinite recursion.
            raise AttributeError(key)
        return getattr(self._match, key)

    @property
    def event_count(self):
        """Total number of events across all categories."""
        return sum(self.counts.values())
class NoneEnum:
    """Stand-in category for events whose ``type`` is NULL.

    Exposes ``name`` and ``value`` like an ``EventTypeEnum`` member, so an
    instance can sit alongside the real members in ``ENUMS``.
    """

    # NOTE(review): real members carry an ``EnumValue`` (attribute access),
    # but this is a dict; ``.value.label`` only works where item and attribute
    # access are interchangeable (e.g. Jinja). Consider EnumValue for parity.
    value = {"label": "No event category"}
    name = "other"
# Every category to count, in order: the real enum members followed by the
# catch-all for uncategorised (NULL-typed) events.
ENUMS = [*EventTypeEnum, NoneEnum()]
class MatchView:
    """Iterable view over a sport's matches, annotated with event counts.

    For every category in ``ENUMS``, a subquery counting that category's
    events per match is outer-joined to the match query. Each result row is
    therefore ``(match, count_1, count_2, ...)``. Rows are wrapped in
    ``MatchWrap`` so the counts can be read as ``match.counts[category]``.
    This is useful for displaying a stream of matches in a card or table
    layout with useful metadata.

    No categories are hard-coded: adding a member to ``EventTypeEnum`` adds
    a count column automatically.

    Usage::

        for match in MatchView(db, sport, offset=0, limit=100):
            print(match.name, match.counts[EventTypeEnum.foul])

    This prints each match's name along with its number of events whose type
    is ``foul``.
    """

    def __init__(self, db, sport, offset, limit):
        self.db = db
        self.sport = sport
        self.offset = offset
        self.limit = limit  # None means "no limit"
        self.query = None  # filled in (and reused) by get_query()

    def _count_subquery(self, enum):
        """Return a ``(count, match_id)`` subquery for events of ``enum``."""
        # NoneEnum stands for uncategorised events; SQLAlchemy renders
        # ``column == None`` as ``IS NULL``.
        event_type = None if isinstance(enum, NoneEnum) else enum
        return (
            self.db.query(
                func.count(Event.id).label("count"),
                Event.match_id.label("match_id"),
            )
            .filter(Event.type == event_type)
            .group_by(Event.match_id)
            .subquery(enum.name)
        )

    def get_query(self):
        """Build (once) and return the annotated, paginated match query."""
        if self.query is not None:
            return self.query
        query = self.db.query(Match).select_from(Match).join(League)
        for enum in ENUMS:
            sq = self._count_subquery(enum)
            # Outer join keeps matches with no events of this category;
            # coalesce turns their NULL count into 0.
            query = query.outerjoin(sq, sq.c.match_id == Match.id).add_columns(
                func.coalesce(sq.c.count, 0).label(enum.name)
            )
        query = (
            query.filter(League.sport_id == self.sport.id)
            .order_by(Match.start.desc())
            .offset(self.offset)
        )
        if self.limit is not None:
            query = query.limit(self.limit)
        # Cache the query so repeated iteration does not rebuild it (the
        # original checked this cache but never stored into it).
        self.query = query
        return query

    def __iter__(self):
        for match, *counts in self.get_query():
            yield MatchWrap(match, counts)
# Application instance; routes are registered on it with decorators.
app = FastAPI()
# Session factory bound to ``engine``: transactions are committed explicitly
# and pending changes are not flushed automatically before queries.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db():
    """Dependency that yields a fresh database session.

    The session is closed when the generator is finalised. FastAPI does this
    after the request, including when the endpoint raised.

    Yields:
        A new session created by ``SessionLocal``.
    """
    # Create the session outside ``try``: if SessionLocal() itself fails,
    # ``db`` would be unbound and ``db.close()`` in ``finally`` would raise
    # UnboundLocalError, masking the original error.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
# Template renderer: page templates (e.g. "sport.html") are loaded from the
# "templates" directory, resolved relative to the working directory.
templates = Jinja2Templates(directory="templates")
@app.get("/sports/{sport_id}")
async def get_sport(request: Request, sport_id: int, db: Session = Depends(get_db)):
    """Render the page for one sport together with its latest matches.

    Args:
        request: the incoming request (required by TemplateResponse).
        sport_id: primary key of the sport. It is declared ``int`` to match
            the Integer ``Sport.id`` column, so non-numeric ids are rejected
            by request validation instead of reaching the database.
        db: database session injected by ``get_db``.

    Raises:
        HTTPException: 404 when no sport has the given id.
    """
    # NOTE(review): the DB work below is synchronous and blocks the event
    # loop inside an ``async def`` endpoint; consider a plain ``def`` so
    # FastAPI runs it in its threadpool.
    sport = db.query(Sport).get(sport_id)
    if sport is None:
        raise HTTPException(status_code=404, detail="Sport not found")
    return templates.TemplateResponse(
        "sport.html",
        {
            "sport": sport,
            # Latest 100 matches; the query runs when the template iterates.
            "matches": MatchView(db, sport, offset=0, limit=100),
            "request": request,
        },
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment