Created
October 23, 2024 01:00
-
-
Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import TYPE_CHECKING, Optional | |
import sqlalchemy as sa | |
from numpy import pi | |
from sqlalchemy import type_coerce | |
from sqlalchemy.dialects.postgresql import ARRAY, ENUM | |
from sqlalchemy.ext.declarative import declared_attr | |
from sqlalchemy.orm import Mapped, deferred, mapped_column, relationship | |
from sqlalchemy.types import String, TypeDecorator | |
from railbird import config | |
from railbird.datatypes import gql | |
from . import query_builder_types as qbt | |
from .base import Base, _qb | |
if TYPE_CHECKING: | |
from .shot import ShotModel | |
# Threshold compared against ``CueObjectFeatures.cue_object_angle`` when
# classifying a shot as Left, Right, or Straight (presumably degrees — confirm).
DEFAULT_DIRECTION_ANGLE_THRESHOLD = 10
# ``cue_angle_after_object`` at or above this value is classified as DRAW.
DEFAULT_DRAW_ANGLE = 100
# ``cue_angle_after_object`` at or below this value is classified as FOLLOW.
DEFAULT_FOLLOW_ANGLE = 70

# PostgreSQL enum types in the "railbird" schema, each built from the
# corresponding ``gql`` enum.
SpinTypeEnum = ENUM(gql.SpinTypeEnum, name="spin_type_enum", schema="railbird")
PocketEnum = ENUM(gql.PocketEnum, name="pocket_enum", schema="railbird")
WallTypeEnum = ENUM(gql.WallTypeEnum, name="wall_enum", schema="railbird")
ShotDirectionEnum = ENUM(
    gql.ShotDirectionEnum, name="direction_enum", schema="railbird"
)
class EnumType(TypeDecorator):
    """String-backed ``TypeDecorator`` that wraps another SQL enum type.

    Instances are callable: ``decorated("NAME")`` is shorthand for
    ``decorated.make_value("NAME")``. It returns the literal coerced to the
    wrapped type, which ``CueObjectFeatures.spin_type_by`` uses to build
    enum-typed ``CASE`` branches.
    """

    impl = String
    cache_ok = True

    def __init__(self, enum_class):
        super().__init__()
        # NOTE(review): the only visible caller passes a SQLAlchemy ``ENUM``
        # type instance here, not a Python ``enum.Enum`` class.
        self.enum_class = enum_class

    def make_value(self, name):
        """Return ``name`` as a SQL expression typed as ``self.enum_class``."""
        coerced = type_coerce(name, self.enum_class)
        return coerced

    # Calling an instance dispatches straight to ``make_value``.
    __call__ = make_value

    def process_result_value(self, value, dialect):
        """Pass NULL through; otherwise call ``self.enum_class`` on the value.

        NOTE(review): this calls ``self.enum_class`` as a constructor. If it
        holds an ``ENUM`` type instance (as with ``EnumType(SpinTypeEnum)``
        below) rather than a Python enum class, this likely raises. Confirm
        that this path is never reached, or pass the Python enum class instead.
        """
        if value is None:
            return None
        return self.enum_class(value)


# Shared instance used to build spin-type literals in SQL expressions.
DecoratedSpinType = EnumType(enum_class=SpinTypeEnum)
class CueObjectFeatures(Base):
    """Features that are defined when the cue ball collides with an object ball.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.

    Besides the stored columns, ``__declare_last__`` attaches three deferred,
    computed boolean attributes: ``is_straight``, ``is_left`` and ``is_right``.
    They are built from the matching ``*_by`` classmethods using the default
    thresholds.
    """

    __tablename__ = "cue_object_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    cue_object_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    cue_object_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        index=True,
        info=_qb(),
    )
    # Optional[...] matches the explicit nullable=True. The explicit flag is
    # what determines the schema, so this annotation change does not alter DDL.
    cue_angle_after_object: Mapped[Optional[float]] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    spin_type: Mapped[Optional[gql.SpinTypeEnum]] = mapped_column(
        SpinTypeEnum,
        nullable=True,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "spin_type_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.SpinTypeEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )
    cue_speed_after_object: Mapped[Optional[float]] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    cue_ball_speed: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    shot_direction: Mapped[gql.ShotDirectionEnum] = mapped_column(
        ShotDirectionEnum,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "shot_direction_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.ShotDirectionEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="cue_object_features"
    )

    # Alias of the module-level instance rather than a second EnumType.
    # Methods such as ``spin_type_by`` resolve the module-level name, because
    # class-body names are not in scope inside methods. Aliasing keeps
    # ``CueObjectFeatures.DecoratedSpinType`` available without a duplicate
    # object that nothing uses.
    DecoratedSpinType = DecoratedSpinType

    @classmethod
    def spin_type_by(
        cls,
        follow_angle_threshold=DEFAULT_FOLLOW_ANGLE,
        draw_angle_threshold=DEFAULT_DRAW_ANGLE,
    ):
        """Return a SQL ``CASE`` classifying spin from ``cue_angle_after_object``.

        The branches are tested in order:

        - DRAW when the angle is >= ``draw_angle_threshold``. DRAW is tested
          first, so it wins if the two ranges overlap.
        - FOLLOW when the angle is <= ``follow_angle_threshold``.
        - CENTER otherwise. This includes rows where the angle is NULL, since
          both comparisons are then NULL and fall through to ``else_``.
        """
        return sa.case(
            (
                cls.cue_angle_after_object >= draw_angle_threshold,
                DecoratedSpinType("DRAW"),
            ),
            (
                cls.cue_angle_after_object <= follow_angle_threshold,
                DecoratedSpinType("FOLLOW"),
            ),
            else_=DecoratedSpinType("CENTER"),
        )

    @classmethod
    def is_straight_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """True when ``cue_object_angle`` is at most ``angle_threshold``.

        Unlike ``is_left_by``/``is_right_by``, this does not consult
        ``shot_direction``. NOTE(review): it assumes ``cue_object_angle`` is a
        non-negative magnitude — confirm.
        """
        return cls.cue_object_angle <= angle_threshold

    @classmethod
    def is_left_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """True when ``shot_direction`` is LEFT and the angle exceeds the threshold."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.LEFT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def is_right_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """True when ``shot_direction`` is RIGHT and the angle exceeds the threshold."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.RIGHT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def __declare_last__(cls):
        """Attach deferred computed direction flags once mappers are configured."""
        cls.is_straight = deferred(cls.is_straight_by(), info=_qb())
        cls.is_left = deferred(cls.is_left_by(), info=_qb())
        cls.is_right = deferred(cls.is_right_by(), info=_qb())
        # XXX: ``spin_type`` is intentionally left as the stored column rather
        # than replaced by the computed ``spin_type_by()`` expression. That
        # way the query builder resolves it as a real column. Someone may
        # revisit this; the computed version could not be made to work with
        # the query builder.
        # cls.spin_type = deferred(cls.spin_type_by())
class PocketingIntentionFeatures(Base):
    """Per-shot features describing the intended pocket and the shot outcome.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.
    """

    __tablename__ = "pocketing_intention_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    target_pocket_angle: Mapped[Optional[float]] = mapped_column(
        sa.FLOAT,
        info=_qb(),
        index=True,
        nullable=True,
    )
    target_pocket_angle_direction: Mapped[Optional[gql.ShotDirectionEnum]] = (
        mapped_column(
            ShotDirectionEnum,
            info=_qb(),
            index=True,
            nullable=True,
        )
    )
    backcut: Mapped[Optional[bool]] = mapped_column(
        sa.BOOLEAN,
        info=_qb(),
        index=True,
        nullable=True,
    )
    target_pocket_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb(),
        index=True,
    )
    # Registers an extra query-builder selectable named "make_percentage".
    make: Mapped[bool] = mapped_column(
        sa.BOOLEAN,
        nullable=True,
        index=True,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "make_percentage",
                        "selectable_constructor": (
                            qbt.QueryBuilderBoolProportionSelectable
                        ),
                    }
                ]
            }
        },
    )
    intended_pocket_type: Mapped[gql.PocketEnum] = mapped_column(
        PocketEnum,
        info=_qb(),
        index=True,
    )
    # Registers an extra query-builder selectable named "average_difficulty".
    difficulty: Mapped[float] = mapped_column(
        sa.FLOAT,
        nullable=True,
        index=True,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "average_difficulty",
                        "selectable_constructor": (
                            qbt.QueryBuilderAverageSelectable
                        ),
                    }
                ]
            }
        },
    )
    # Column default comes from ``config.git_commit_hash``.
    difficulty_git_commit: Mapped[str] = mapped_column(
        sa.VARCHAR(length=150),
        default=config.git_commit_hash,
        nullable=True,
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel",
        back_populates="pocketing_intention_features",
    )
class ErrorFeatures(Base):
    """``errors`` and ``warnings`` text columns attached to a shot.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.
    """

    __tablename__ = "error_features"

    # These two columns are declared by annotation alone (no mapped_column).
    errors: Mapped[str]
    warnings: Mapped[str]

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel",
        back_populates="error_features",
    )
class BankFeatures(Base):
    """Bank-shot features for a shot.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.
    """

    __tablename__ = "bank_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    # Array of wall enum values, exposed to the query builder as
    # "bank_walls_hit".
    # NOTE(review): a range filter on an ARRAY column is unusual — confirm
    # that QueryBuilderRangeFilter is intended here.
    walls_hit = mapped_column(
        ARRAY(WallTypeEnum),
        info=_qb(
            {
                "name": "bank_walls_hit",
                "filter_constructor": qbt.QueryBuilderRangeFilter,
            }
        ),
        index=True,
    )
    bank_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb(),
        index=True,
    )
    # Exposed to the query builder as "bank_distance".
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb({"name": "bank_distance"}),
        index=True,
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel",
        back_populates="bank_features",
    )
class KickFeatures(Base):
    """Kick-shot features for a shot.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.
    """

    __tablename__ = "kick_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    # Indexed but, unlike BankFeatures.walls_hit, not exposed to the query builder.
    walls_hit = mapped_column(
        ARRAY(WallTypeEnum),
        index=True,
    )
    # Exposed to the query builder as "kick_angle".
    angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb({"name": "kick_angle"}),
        index=True,
    )
    # Exposed to the query builder as "kick_distance".
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb({"name": "kick_distance"}),
        index=True,
    )

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel",
        back_populates="kick_features",
    )
# Default threshold compared against ``MissFeatures.miss_angle_in_degrees``
# when classifying a miss as over- or under-cut.
DEFAULT_OVER_UNDER_CUT_THRESHOLD = 5


class MissFeatures(Base):
    """Features describing how a shot missed.

    There is at most one row per shot: ``shot_id`` is both the primary key and
    a cascading foreign key to ``shot.id``.

    ``__declare_last__`` attaches four deferred computed flags:
    ``is_overcut``, ``is_undercut``, ``is_left_miss`` and ``is_right_miss``.
    """

    __tablename__ = "miss_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", ondelete="CASCADE", onupdate="CASCADE"),
        nullable=False,
        primary_key=True,
    )
    # Converted by ``miss_angle_in_degrees`` using a 180/pi factor, so
    # presumably stored in radians — confirm.
    miss_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        info=_qb(),
        index=True,
    )

    @declared_attr
    def miss_angle_in_degrees(cls) -> Mapped[float]:
        """Deferred SQL expression: ``miss_angle`` scaled by 180/pi."""
        return deferred(cls.miss_angle * (180.0 / pi), info=_qb())  # type: ignore

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel",
        back_populates="miss_features",
    )

    @classmethod
    def is_undercut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """True when the miss angle (degrees) is <= ``-angle_threshold``."""
        lower_bound = -1 * angle_threshold
        return cls.miss_angle_in_degrees <= lower_bound

    @classmethod
    def is_overcut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """True when the miss angle (degrees) is >= ``angle_threshold``."""
        upper_bound = angle_threshold
        return cls.miss_angle_in_degrees >= upper_bound

    @classmethod
    def is_miss_in_direction_by(
        cls,
        angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD,
        direction=gql.ShotDirectionEnum.LEFT,
    ):
        """Scalar subquery: did the ball miss toward ``direction``?

        The subquery joins to ``CueObjectFeatures`` on ``shot_id``. It is true
        in either of two cases:

        - ``shot_direction`` equals ``direction`` and the miss angle is at least
          ``angle_threshold``.
        - ``shot_direction`` differs from ``direction`` and the miss angle is at
          most ``-angle_threshold``.

        It yields NULL when there is no matching ``CueObjectFeatures`` row.
        NOTE(review): if ``ShotDirectionEnum`` has members other than
        LEFT/RIGHT, the "differs" branch also covers them — confirm.
        """
        degrees = cls.miss_angle_in_degrees
        same_way = CueObjectFeatures.shot_direction == direction
        other_way = CueObjectFeatures.shot_direction != direction
        missed_toward_direction = sa.or_(
            sa.and_(same_way, degrees >= angle_threshold),
            sa.and_(other_way, degrees <= -angle_threshold),
        )
        per_shot = sa.select(missed_toward_direction).where(
            CueObjectFeatures.shot_id == cls.shot_id
        )
        return per_shot.scalar_subquery()

    @classmethod
    def __declare_last__(cls):
        """Attach deferred computed miss classifications after mapping."""
        cls.is_overcut = deferred(cls.is_overcut_by(), info=_qb())
        cls.is_undercut = deferred(cls.is_undercut_by(), info=_qb())
        # LEFT is assigned before RIGHT, matching the attribute order above.
        for attr_name, miss_direction in (
            ("is_left_miss", gql.ShotDirectionEnum.LEFT),
            ("is_right_miss", gql.ShotDirectionEnum.RIGHT),
        ):
            setattr(
                cls,
                attr_name,
                deferred(
                    cls.is_miss_in_direction_by(direction=miss_direction),
                    info=_qb(),
                ),
            )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment