Skip to content

Instantly share code, notes, and snippets.

@colonelpanic8
Created October 23, 2024 01:00
Show Gist options
  • Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop.
Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop.
from typing import TYPE_CHECKING, Optional
import sqlalchemy as sa
from numpy import pi
from sqlalchemy import type_coerce
from sqlalchemy.dialects.postgresql import ARRAY, ENUM
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapped, deferred, mapped_column, relationship
from sqlalchemy.types import String, TypeDecorator
from railbird import config
from railbird.datatypes import gql
from . import query_builder_types as qbt
from .base import Base, _qb
if TYPE_CHECKING:
from .shot import ShotModel
# Default threshold applied to cue_object_angle by the CueObjectFeatures
# direction predicates: at or below it a shot counts as straight; above it the
# stored shot_direction decides between Left and Right.
DEFAULT_DIRECTION_ANGLE_THRESHOLD = 10
# Default thresholds for CueObjectFeatures.spin_type_by, compared against
# cue_angle_after_object: at or above DEFAULT_DRAW_ANGLE is DRAW, at or below
# DEFAULT_FOLLOW_ANGLE is FOLLOW, anything else is CENTER.
# NOTE(review): the angle units are presumably degrees -- confirm.
DEFAULT_DRAW_ANGLE = 100
DEFAULT_FOLLOW_ANGLE = 70
def _railbird_enum(python_enum, type_name):
    """Build a PostgreSQL ENUM named *type_name* in the "railbird" schema."""
    return ENUM(python_enum, name=type_name, schema="railbird")


# PostgreSQL enum types shared by the feature tables in this module.  Each one
# wraps the matching ``gql`` enum class.
SpinTypeEnum = _railbird_enum(gql.SpinTypeEnum, "spin_type_enum")
PocketEnum = _railbird_enum(gql.PocketEnum, "pocket_enum")
WallTypeEnum = _railbird_enum(gql.WallTypeEnum, "wall_enum")
ShotDirectionEnum = _railbird_enum(gql.ShotDirectionEnum, "direction_enum")
class EnumType(TypeDecorator):
    """String-backed type decorator that doubles as a value factory.

    Calling an instance (or its ``make_value`` method) wraps the given value
    in ``type_coerce`` against ``enum_class``.  Non-NULL result values are
    passed through ``enum_class(...)``.
    """

    impl = String
    # Declares the type safe to take part in SQLAlchemy's statement caching.
    cache_ok = True

    def __init__(self, enum_class):
        """Initialise the base decorator, then remember *enum_class*."""
        super().__init__()
        self.enum_class = enum_class

    def make_value(self, name):
        """Return *name* as a SQL expression typed as ``enum_class``."""
        return type_coerce(name, self.enum_class)

    # ``instance("NAME")`` is shorthand for ``instance.make_value("NAME")``.
    __call__ = make_value

    def process_result_value(self, value, dialect):
        """Convert a non-NULL result via ``enum_class``; NULL stays ``None``."""
        # NOTE(review): the module-level instance of this class is built from
        # a SQLAlchemy ENUM type object rather than a Python enum class, so
        # calling it here looks like it would raise -- confirm whether this
        # hook is ever reached in practice.
        if value is None:
            return None
        return self.enum_class(value)
# Callable helper: DecoratedSpinType("DRAW") returns the string wrapped in
# type_coerce against SpinTypeEnum.  CueObjectFeatures.spin_type_by resolves
# this module-level name when it builds its CASE expression.
DecoratedSpinType = EnumType(SpinTypeEnum)
class CueObjectFeatures(Base):
    """Features that are defined when the cue ball collides with an object ball.

    Each row is keyed by ``shot_id``, which is both the primary key and a
    foreign key to ``shot.id``.  Besides the stored columns,
    ``__declare_last__`` attaches three deferred boolean expressions --
    ``is_straight``, ``is_left`` and ``is_right`` -- built from the ``*_by``
    classmethods with their default thresholds.

    Columns declared with ``nullable=True`` are annotated ``Optional`` so the
    Python-side type matches what the database may return.

    NOTE(review): the ``sa.DECIMAL`` columns are annotated ``Mapped[float]``;
    SQLAlchemy's Numeric types return ``decimal.Decimal`` unless
    ``asdecimal=False`` is set -- confirm the annotation matches what callers
    actually receive.
    """

    __tablename__ = "cue_object_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    cue_object_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    cue_object_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        index=True,
        info=_qb(),
    )
    # Nullable: annotated Optional to agree with nullable=True.
    cue_angle_after_object: Mapped[Optional[float]] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    # Stored spin classification.  The extra "others" entry registers a
    # per-enum-value counts selectable named "spin_type_counts".
    spin_type: Mapped[Optional[gql.SpinTypeEnum]] = mapped_column(
        SpinTypeEnum,
        nullable=True,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "spin_type_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.SpinTypeEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )
    # Nullable: annotated Optional to agree with nullable=True.
    cue_speed_after_object: Mapped[Optional[float]] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    cue_ball_speed: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    # The extra "others" entry registers a per-enum-value counts selectable
    # named "shot_direction_counts".
    shot_direction: Mapped[gql.ShotDirectionEnum] = mapped_column(
        ShotDirectionEnum,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "shot_direction_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.ShotDirectionEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="cue_object_features"
    )

    # NOTE(review): this class attribute repeats the module-level
    # DecoratedSpinType.  Methods below resolve the module-level name (class
    # body names are not visible inside method bodies), so this one appears
    # redundant; it is kept only to leave the class namespace unchanged.
    DecoratedSpinType = EnumType(SpinTypeEnum)

    @classmethod
    def spin_type_by(
        cls,
        follow_angle_threshold=DEFAULT_FOLLOW_ANGLE,
        draw_angle_threshold=DEFAULT_DRAW_ANGLE,
    ):
        """Return a SQL CASE classifying spin from ``cue_angle_after_object``.

        The draw test is evaluated first: an angle at or above
        *draw_angle_threshold* is "DRAW"; otherwise an angle at or below
        *follow_angle_threshold* is "FOLLOW"; everything else is "CENTER".
        Each label is coerced to the spin-type enum via ``DecoratedSpinType``.

        NOTE(review): a NULL ``cue_angle_after_object`` satisfies neither
        comparison, so the expression yields "CENTER" for it -- confirm that
        is the intended result.
        """
        return sa.case(
            (
                cls.cue_angle_after_object >= draw_angle_threshold,
                DecoratedSpinType("DRAW"),
            ),
            (
                cls.cue_angle_after_object <= follow_angle_threshold,
                DecoratedSpinType("FOLLOW"),
            ),
            else_=DecoratedSpinType("CENTER"),
        )

    @classmethod
    def is_straight_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """Return the predicate ``cue_object_angle <= angle_threshold``."""
        return cls.cue_object_angle <= angle_threshold

    @classmethod
    def is_left_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """Return the predicate: direction is LEFT and angle exceeds the threshold."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.LEFT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def is_right_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """Return the predicate: direction is RIGHT and angle exceeds the threshold."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.RIGHT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def __declare_last__(cls):
        """Attach the default-threshold direction predicates as deferred attributes.

        ``__declare_last__`` is the SQLAlchemy declarative hook that runs once
        mapper configuration has completed.
        """
        cls.is_straight = deferred(cls.is_straight_by(), info=_qb())
        cls.is_left = deferred(cls.is_left_by(), info=_qb())
        cls.is_right = deferred(cls.is_right_by(), info=_qb())
        # XXX: The computed spin type below is commented out so that
        # ``spin_type`` resolves to the stored column.  Someone can revert this
        # at some point, but I couldn't figure it out using qb.
        # cls.spin_type = deferred(cls.spin_type_by())
class PocketingIntentionFeatures(Base):
    """Per-shot pocketing-intention columns stored in ``pocketing_intention_features``.

    Each row is keyed by ``shot_id``, which is both the primary key and a
    foreign key to ``shot.id``.  Every column declared with ``nullable=True``
    is annotated ``Optional`` so the Python-side type matches what the
    database may return.
    """

    __tablename__ = "pocketing_intention_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    target_pocket_angle: Mapped[Optional[float]] = mapped_column(
        sa.FLOAT, nullable=True, index=True, info=_qb()
    )
    target_pocket_angle_direction: Mapped[Optional[gql.ShotDirectionEnum]] = (
        mapped_column(ShotDirectionEnum, nullable=True, index=True, info=_qb())
    )
    backcut: Mapped[Optional[bool]] = mapped_column(
        sa.BOOLEAN, nullable=True, index=True, info=_qb()
    )
    target_pocket_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    # Nullable boolean.  The query-builder metadata is spelled out literally
    # here and registers a "make_percentage" proportion selectable.
    make: Mapped[Optional[bool]] = mapped_column(
        sa.BOOLEAN,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "make_percentage",
                        "selectable_constructor": qbt.QueryBuilderBoolProportionSelectable,
                    }
                ]
            }
        },
        index=True,
        nullable=True,
    )
    intended_pocket_type: Mapped[gql.PocketEnum] = mapped_column(
        PocketEnum,
        index=True,
        info=_qb(),
    )
    # Nullable float.  Registers an "average_difficulty" averaging selectable.
    difficulty: Mapped[Optional[float]] = mapped_column(
        sa.FLOAT,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "average_difficulty",
                        "selectable_constructor": qbt.QueryBuilderAverageSelectable,
                    }
                ]
            }
        },
        index=True,
        nullable=True,
    )
    # Defaults to config.git_commit_hash on insert.
    # NOTE(review): presumably records which code revision computed
    # ``difficulty`` -- confirm against the writer of this column.
    difficulty_git_commit: Mapped[Optional[str]] = mapped_column(
        sa.VARCHAR(length=150),
        nullable=True,
        default=config.git_commit_hash,
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="pocketing_intention_features"
    )
class ErrorFeatures(Base):
    """Error and warning text recorded for one shot (``error_features`` table)."""

    __tablename__ = "error_features"

    # Annotation-only columns: type and nullability come from ``Mapped[str]``
    # (presumably via the Base's default type annotation map).
    errors: Mapped[str]
    warnings: Mapped[str]

    # One row per shot; rows follow updates and deletes of the parent shot.
    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )

    shot: Mapped["ShotModel"] = relationship("ShotModel", back_populates="error_features")
class BankFeatures(Base):
    """Per-shot bank measurements stored in the ``bank_features`` table."""

    __tablename__ = "bank_features"

    # One row per shot; rows follow updates and deletes of the parent shot.
    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )

    # Array of wall enum values; exposed to the query builder under the name
    # "bank_walls_hit" with a range filter.
    walls_hit = mapped_column(
        ARRAY(WallTypeEnum),
        info=_qb(
            dict(
                name="bank_walls_hit",
                filter_constructor=qbt.QueryBuilderRangeFilter,
            )
        ),
        index=True,
    )
    bank_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(scale=3, precision=7),
        info=_qb(),
        index=True,
    )
    # Query-builder name is "bank_distance" rather than the attribute name.
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(scale=3, precision=7),
        info=_qb({"name": "bank_distance"}),
        index=True,
    )

    shot: Mapped["ShotModel"] = relationship("ShotModel", back_populates="bank_features")
class KickFeatures(Base):
    """Per-shot kick measurements stored in the ``kick_features`` table."""

    __tablename__ = "kick_features"

    # One row per shot; rows follow updates and deletes of the parent shot.
    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )

    # Array of wall enum values; no query-builder info is attached here.
    walls_hit = mapped_column(ARRAY(WallTypeEnum), index=True)

    # Both numeric columns are renamed for the query builder
    # ("kick_angle" / "kick_distance").
    angle: Mapped[float] = mapped_column(
        sa.DECIMAL(scale=3, precision=7),
        info=_qb({"name": "kick_angle"}),
        index=True,
    )
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(scale=3, precision=7),
        info=_qb({"name": "kick_distance"}),
        index=True,
    )

    shot: Mapped["ShotModel"] = relationship("ShotModel", back_populates="kick_features")
# Default threshold for the MissFeatures overcut/undercut and directional-miss
# predicates.  It is compared against miss_angle_in_degrees, so it is
# expressed in degrees.
DEFAULT_OVER_UNDER_CUT_THRESHOLD = 5
class MissFeatures(Base):
    """Per-shot miss measurements stored in the ``miss_features`` table.

    In addition to the stored ``miss_angle``, the class exposes a deferred
    ``miss_angle_in_degrees`` expression.  ``__declare_last__`` then attaches
    four deferred boolean expressions -- ``is_overcut``, ``is_undercut``,
    ``is_left_miss`` and ``is_right_miss`` -- built from the classmethods below
    with their default thresholds.
    """

    __tablename__ = "miss_features"

    # One row per shot; rows follow updates and deletes of the parent shot.
    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    miss_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(scale=3, precision=7),
        info=_qb(),
        index=True,
    )

    @declared_attr
    def miss_angle_in_degrees(cls) -> Mapped[float]:
        """Deferred expression scaling ``miss_angle`` by 180/pi.

        The factor is the radians-to-degrees conversion, so ``miss_angle`` is
        presumably stored in radians.
        """
        degrees_per_radian = 180.0 / pi
        return deferred(cls.miss_angle * degrees_per_radian, info=_qb())  # type: ignore

    shot: Mapped["ShotModel"] = relationship("ShotModel", back_populates="miss_features")

    @classmethod
    def is_undercut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """Return the predicate: miss angle in degrees is at or below ``-angle_threshold``."""
        negative_limit = -1 * angle_threshold
        return cls.miss_angle_in_degrees <= negative_limit

    @classmethod
    def is_overcut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """Return the predicate: miss angle in degrees is at or above ``angle_threshold``."""
        return cls.miss_angle_in_degrees >= angle_threshold

    @classmethod
    def is_miss_in_direction_by(
        cls,
        angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD,
        direction=gql.ShotDirectionEnum.LEFT,
    ):
        """Return a scalar subquery telling whether the miss went toward *direction*.

        The subquery reads the ``CueObjectFeatures`` row with the same
        ``shot_id`` and is true when either:

        * the shot's direction equals *direction* and the miss angle in
          degrees is at or above *angle_threshold*, or
        * the shot's direction differs from *direction* and the miss angle in
          degrees is at or below ``-angle_threshold``.
        """
        with_direction = sa.and_(
            CueObjectFeatures.shot_direction == direction,
            cls.miss_angle_in_degrees >= angle_threshold,
        )
        against_direction = sa.and_(
            CueObjectFeatures.shot_direction != direction,
            cls.miss_angle_in_degrees <= -angle_threshold,
        )
        same_shot = CueObjectFeatures.shot_id == cls.shot_id
        statement = sa.select(sa.or_(with_direction, against_direction)).where(same_shot)
        return statement.scalar_subquery()

    @classmethod
    def __declare_last__(cls):
        """Attach the default-threshold miss predicates as deferred attributes.

        ``__declare_last__`` is the SQLAlchemy declarative hook that runs once
        mapper configuration has completed.
        """
        cls.is_overcut = deferred(cls.is_overcut_by(), info=_qb())
        cls.is_undercut = deferred(cls.is_undercut_by(), info=_qb())
        # Left first, then right, one attribute per direction.
        directional_attrs = (
            ("is_left_miss", gql.ShotDirectionEnum.LEFT),
            ("is_right_miss", gql.ShotDirectionEnum.RIGHT),
        )
        for attr_name, side in directional_attrs:
            predicate = cls.is_miss_in_direction_by(direction=side)
            setattr(cls, attr_name, deferred(predicate, info=_qb()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment