# Gist by @profsucrose, created February 1, 2025
# (gist ID 698c203f45df8e8e4dbbffbcae9821d7).
from dataclasses import dataclass
from typing import Callable, Any, Literal, Type
from enum import Enum
# Element type that can appear at a single sequence position.
El = int | float | str
# A sequence: one element per position.
S = list[El]
# Boolean selector matrix: one row per query position, one column per key position.
BoolMat = list[list[bool]]
def default_for_type(t: Type) -> "El":
    """Return the neutral default element for element type ``t``.

    ``int`` (and subclasses such as ``bool``) -> ``0``, ``float`` -> ``0.0``,
    anything else -> ``"_"``.

    Args:
        t: The element *type* (e.g. ``int``), not an element value.

    Returns:
        The default element for that type.
    """
    # t is a type object, so the check must be issubclass; isinstance(t, int)
    # would ask whether the class itself is an int, which is never true.
    if isinstance(t, type) and issubclass(t, int):
        return 0
    if isinstance(t, type) and issubclass(t, float):
        return 0.0
    return "_"
def aggregate(mat: "BoolMat", vs: "S") -> "S":
    """Aggregate ``vs`` through a selector matrix, one result per query row.

    For every row ``mask`` of ``mat``, the values ``vs[k]`` with ``mask[k]``
    true are gathered, and:

    * nothing selected -> ``0``
    * exactly one      -> that value unchanged (so strings pass through)
    * several          -> their arithmetic mean (int/float values only)

    Args:
        mat: Boolean matrix with one row per query position and one column
            per key position.
        vs: The value at each key position.

    Returns:
        A list with one aggregated value per row of ``mat``.

    Raises:
        AssertionError: if several values are selected and one of them is not
            an ``int`` or ``float``.
    """
    aggregated_vs = []
    # The output has one entry per row (query), so build it row by row rather
    # than pre-sizing from len(vs).
    for mask in mat:
        masked = [v for m, v in zip(mask, vs) if m]
        if not masked:
            aggregated = 0
        elif len(masked) == 1:
            aggregated = masked[0]
        else:
            for v in masked:
                assert type(v) in (float, int), f"Tried to sum type {type(v).__name__}"
            aggregated = sum(masked) / len(masked)
        aggregated_vs.append(aggregated)
    return aggregated_vs
def apply_sel_inner(sel: "Sel", seq: S) -> BoolMat:
match sel:
case KQ(key, query, pred):
ks = apply_sop(key.op, seq)
qs = apply_sop(query.op, seq)
return [[pred(k, q) for k in ks] for q in qs]
case BinSel(lhs, rhs, pred):
ls = lhs(seq).mat
rs = rhs(seq).mat
return [[pred(l, r) for l, r in zip(lrow, rrow)] for lrow, rrow in zip(ls, rs)]
def apply_sel(sel: "Sel", seq: S) -> "Selector":
    """Evaluate ``sel`` on ``seq`` and wrap the resulting matrix in a ``Selector``."""
    matrix = apply_sel_inner(sel, seq)
    return Selector(matrix)
def apply_sop(op: Literal["SOp"] | S | El, seq: S) -> S:
match op:
case Symbol(name):
match name:
case "tokens":
return seq
case "indices":
return list(range(len(seq)))
case "length":
return [len(seq)] * len(seq)
case _:
raise f"Unexpected symbol: {name=}"
case BinSeq(lhs, rhs, op):
ls = apply_sop(lhs, seq)
rs = apply_sop(rhs, seq)
if isinstance(op, Arith):
return [op(l, r) for l, r in zip(ls, rs)]
else: # is BinPred
return [int(op(l, r)) for l, r in zip(ls, rs)]
case Un(seq, pred):
return [pred(x) for x in seq]
case UnMap(seq, fn):
return [fn(x) for x in seq]
case Selector(mat):
return aggregate(mat, seq)
case Select(sel, value):
selector = apply_sel(sel, seq)
vs = apply_sop(value, seq)
return apply_sop(selector, vs)
case SelectorWidth(sel):
mat = apply_sel(sel, seq).mat
return [sum(row) for row in mat]
case _ if isinstance(op, (int, float)):
return [op] * len(seq)
case _ if isinstance(op, str):
if len(op) == 1:
return op * len(seq)
else:
assert len(op) == len(seq), f"lengths don't equal: {op=} vs. {seq=}"
return op
case _:
assert False, f"Unexpected op: {op}"
class BinElOp(Enum):
pass
class Arith(BinElOp):
ADD = "+"
SUB = "-"
MUL = "*"
DIV = "/"
def __call__(self, x: El, y: El) -> El:
if x_is_str := isinstance(x, str):
x_str = x
x = ord(x)
if y_is_str := isinstance(y, str):
y_str = y
y = ord(y)
match self:
case Arith.ADD:
z = x + y
case Arith.SUB:
z = x - y
case Arith.MUL:
z = x * y
case Arith.DIV:
z = x / y
if x_is_str or y_is_str:
return chr(z)
else:
return z
class BinPred(BinElOp):
    """Element-wise binary predicate.

    The comparisons return ``bool``. ``AND`` and ``OR`` use Python's ``and``
    and ``or``, so they return one of the operands (e.g. ``OR(0, 2)`` is
    ``2``), not necessarily a ``bool``.
    """

    EQ = "=="
    GT = ">"
    LT = "<"
    GE = ">="
    LE = "<="
    AND = "&"
    OR = "|"

    def __call__(self, x: El, y: El) -> bool:
        if self is BinPred.EQ:
            return x == y
        if self is BinPred.GT:
            return x > y
        if self is BinPred.LT:
            return x < y
        if self is BinPred.GE:
            return x >= y
        if self is BinPred.LE:
            return x <= y
        if self is BinPred.AND:
            return x and y
        if self is BinPred.OR:
            return x or y
def SOp_bin_met(op: "BinElOp") -> Callable[["SOp", "SOp"], "BinSeq"]:
    """Return an operator dunder that builds ``BinSeq(self, rhs, op)``.

    Assigned to ``SOp``'s operator slots so that, e.g., ``a + b`` produces an
    unevaluated ``BinSeq`` node instead of computing a value.
    """
    return lambda self, rhs: BinSeq(self, rhs, op)
@dataclass
class SOp:
    """Base class for symbolic sequence operations.

    Arithmetic, comparison and logical operators do not compute anything. They
    build ``BinSeq`` nodes. Calling an SOp on a concrete sequence evaluates it
    with ``apply_sop``.

    Subclasses decorated with ``@dataclass`` must pass ``eq=False``. Otherwise
    the generated ``__eq__`` replaces the operator-building one defined here.
    """

    def __call__(self, seq: S) -> S:
        return apply_sop(self, seq)

    __add__ = SOp_bin_met(Arith.ADD)
    __sub__ = SOp_bin_met(Arith.SUB)
    __mul__ = SOp_bin_met(Arith.MUL)
    # Python 3 dispatches `/` to __truediv__; __div__ is kept for compatibility.
    __truediv__ = SOp_bin_met(Arith.DIV)
    __div__ = __truediv__
    __eq__ = SOp_bin_met(BinPred.EQ)
    __gt__ = SOp_bin_met(BinPred.GT)
    __lt__ = SOp_bin_met(BinPred.LT)
    __ge__ = SOp_bin_met(BinPred.GE)
    __le__ = SOp_bin_met(BinPred.LE)
    __and__ = SOp_bin_met(BinPred.AND)
    __or__ = SOp_bin_met(BinPred.OR)
# eq=False keeps SOp's operator-building __eq__ (so `tokens == "x"` builds a
# BinSeq); a dataclass-generated __eq__ would otherwise replace it.
@dataclass(eq=False)
class Symbol(SOp):
    """A primitive input SOp identified by ``name``.

    ``apply_sop`` recognises ``"tokens"`` (the input itself), ``"indices"``
    (``0 .. n-1``) and ``"length"`` (``n`` at every position).
    """

    name: str

    def __repr__(self):
        return self.name


indices = Symbol("indices")
length = Symbol("length")
tokens = Symbol("tokens")
class UnPred(Enum):
INVERT = "~"
def __call__(self, x: bool) -> bool:
match self:
case PredUn.INVERT:
return not x
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class BinSeq(SOp):
    """Element-wise binary operation between two operands.

    Built by SOp's operators. Either side may be an SOp or a constant that
    ``apply_sop`` can broadcast.
    """

    lhs: SOp
    rhs: SOp
    op: BinElOp  # an Arith (value result) or BinPred (result stored as int)
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class Un(SOp):
    """Element-wise unary predicate ``pred`` applied to the values of ``seq``."""

    seq: SOp
    pred: UnPred
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class UnMap(SOp):
    """Element-wise application of an arbitrary function ``fn`` to an SOp.

    Field order is ``(fn, seq)``, so ``smap(f, s)`` maps ``f`` over ``s``.
    """

    fn: Callable[[El], El]
    seq: SOp


smap = UnMap
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class Selector(SOp):
    """A concrete boolean selector matrix (rows = query positions, columns = keys).

    Applied to a value sequence, it aggregates the selected values per row.
    """

    mat: BoolMat

    def __repr__(self):
        # Render as a 0/1 grid, one row per line.
        return "[" + ",\n ".join("[" + ", ".join("1" if x else "0" for x in row) + "]" for row in self.mat) + "]"
@dataclass
class Sel:
    """Base class for symbolic selector (attention-pattern) expressions.

    Calling a selector on a sequence evaluates it to a concrete ``Selector``
    via ``apply_sel``. ``&`` and ``|`` combine selectors cell by cell.
    """

    __call__ = apply_sel

    def value(self, op: SOp) -> "Select":
        """Aggregate ``op`` through this selector."""
        return Select(self, op)

    def __and__(self, rhs: "Sel") -> "BinSel":
        return BinSel(self, rhs, BinPred.AND)

    def __or__(self, rhs: "Sel") -> "BinSel":
        return BinSel(self, rhs, BinPred.OR)

    def __invert__(self) -> "Un":
        # NOTE(review): wraps a selector in `Un`, which is an SOp node, not a
        # Sel; evaluating the result does not appear to be supported by
        # apply_sop/apply_sel_inner as written — confirm intent.
        return Un(self, UnPred.INVERT)
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class SelectorWidth(SOp):
    """Number of selected key positions in each query row of ``sel``."""

    sel: Sel


selector_width = SelectorWidth
@dataclass
class KQ(Sel):
    """Selector built by comparing a key with a query.

    Evaluated into a matrix with one row per query position and one column per
    key position; each cell is ``pred(key_value, query_value)``.
    """

    key: "Key"  # the key side; its ``.op`` is evaluated on the input
    query: "Query"  # the query side; its ``.op`` is evaluated on the input
    pred: BinPred
def KeyQuery_bin_met(op: BinPred) -> Callable[["KeyQuery", "KeyQuery"], "KQ"]:
    """Return a comparison dunder that pairs a key and a query with ``op``.

    Validation and key/query ordering are delegated to ``KeyQuery.bind``.
    """
    return lambda self, other: self.bind(other, op)
@dataclass
class KeyQuery:
    """One side of an attention comparison: a key or a query wrapping an SOp.

    Comparing a ``Key`` with a ``Query`` (``==``, ``<``, ``>``, ``<=``,
    ``>=``, ``&``, ``|``) computes nothing. It builds a ``KQ`` selector that
    is evaluated as ``pred(key_value, query_value)``.
    """

    op: SOp  # the SOp (or a constant, as in ``query(0)``) whose values are compared

    def bind(self, other: "KeyQuery", pred: BinPred) -> "KQ":
        """Pair ``self`` with ``other`` into a ``KQ`` selector using ``pred``.

        ``KQ`` always stores the key first. When ``self`` is the query (as in
        ``query(a) < key(b)``), the order-sensitive predicates are mirrored so
        the selector keeps the meaning that was written: ``q < k`` is stored
        as ``k > q``. ``==``, ``&`` and ``|`` are left unchanged.

        Raises:
            AssertionError: if the two sides are not one Key and one Query.
        """
        if isinstance(self, Key):
            assert isinstance(other, Query), f"Can only attend Key with Query, got {other}"
            return KQ(self, other, pred)
        # self is Query
        assert isinstance(other, Key), f"Can only attend Query with Key, got {other}"
        mirrored = {
            BinPred.GT: BinPred.LT,
            BinPred.LT: BinPred.GT,
            BinPred.GE: BinPred.LE,
            BinPred.LE: BinPred.GE,
        }
        return KQ(other, self, mirrored.get(pred, pred))

    __eq__ = KeyQuery_bin_met(BinPred.EQ)
    __gt__ = KeyQuery_bin_met(BinPred.GT)
    __lt__ = KeyQuery_bin_met(BinPred.LT)
    __ge__ = KeyQuery_bin_met(BinPred.GE)
    __le__ = KeyQuery_bin_met(BinPred.LE)
    __and__ = KeyQuery_bin_met(BinPred.AND)
    __or__ = KeyQuery_bin_met(BinPred.OR)


class Key(KeyQuery):
    """The key side of an attention comparison."""


key = Key


class Query(KeyQuery):
    """The query side of an attention comparison."""


query = Query
@dataclass
class BinSel(Sel):
    """Cell-by-cell combination of two selectors (built by ``&`` / ``|`` on ``Sel``)."""

    lhs: Sel
    rhs: Sel
    pred: BinPred  # applied to each pair of cells from the two evaluated matrices
@dataclass
class UnSel(Sel):
    """Unary selector node.

    NOTE(review): this declares two operands (``lhs``, ``rhs``) with a unary
    ``UnPred``, which looks unfinished. It is not constructed in this file as
    shown, and ``apply_sel_inner`` has no case for it — confirm intent.
    """

    lhs: Sel
    rhs: Sel
    pred: UnPred
# eq=False keeps SOp's operator-building __eq__ instead of a generated one.
@dataclass(eq=False)
class Select(SOp):
    """Aggregate ``value`` through selector ``sel`` (built by ``Sel.value``).

    On evaluation, ``sel`` becomes a concrete matrix and, for each query row,
    the values of ``value`` at the selected key positions are aggregated.
    """

    sel: Sel
    value: SOp
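"""
Design sketch (not executed): the intended node tree for a select-and-aggregate
program, written with this file's class and field names.

Select(
    sel=BinSel(
        lhs=KQ(
            key=indices,
            query=indices,
        ),
        rhs=KQ(
            key=indices,
            query=indices,
        ),
        pred=BinPred.OR,
    ),
    value=BinSeq(
        lhs=tokens,
        rhs="x",
        op=BinPred.EQ
    ),
)

An alternative, more compact notation was also sketched (`Attention`, `Bin`
and `Constant` are hypothetical and not defined in this file):

Attention(
    key=indices,
    query=indices,
    value=Bin(
        lhs=tokens,
        rhs=Constant("x"),
        op=BinPred.EQ
    ),
)
"""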
def sort():
    """Build a program that sorts the input tokens in ascending order.

    Each position's rank is the number of tokens that come before it in
    (value, index) order. The token with rank ``i`` is then moved to position
    ``i``. Ties between equal tokens are broken by their original index, so
    duplicates get distinct ranks and every output position selects exactly
    one token. Counting only ``key(tokens) < query(tokens)`` would give equal
    tokens the same rank and leave some positions empty, which aggregate to 0.

    Returns:
        An SOp; call it on a sequence to get the sorted sequence.
    """
    smaller = key(tokens) < query(tokens)
    # Among equal tokens, the one at the smaller index counts as "before".
    tie_before = (key(tokens) == query(tokens)) & (key(indices) < query(indices))
    argsort = selector_width(smaller | tie_before)
    return (key(argsort) == query(indices)).value(tokens)
def move():
    """Build a program that shifts the cells after position 0 by a step stored at position 0.

    Position 0 holds the step ``d`` and keeps its value. Every other position
    ``i`` takes the cell value from position ``i - d``. The result is 0 when
    that source is position 0 itself or out of range. The two parts are
    merged with ``|``. Example from the original notes (cells only):

        [0, 0, 0, 1, 0, 0]
     -> [0, 0, 0, 0, 1, 0]

    Returns:
        An SOp; call it on a sequence to apply the move.
    """
    same_pos = key(indices) == query(indices)
    # Every query selects key position 0, broadcasting the step to all positions.
    direction = (key(indices) == query(0)).value(tokens)
    header = (same_pos & (key(indices) <= query(0))).value(tokens)
    cells = (same_pos & (key(indices) > query(0))).value(tokens)
    # Query i picks the key k with k + d == i, i.e. the cell from i - d.
    shifted = (key(indices + direction) == query(indices)).value(cells)
    return header | shifted
# TODO: unfinished; see the NOTE(review) inside snake().
def snake(size: int):
    """
    Sketch of one step of a snake game as a program (work in progress).

    Intended input layout, per the original notes (meaning of the middle
    pairs is not stated here; presumably body segments — confirm):

    [ <x>, <y>, <dx>, <dy>, <score>,
    0, 0,
    0, 0,
    0, 0,
    0, 0,
    1, 2,
    3, 2,
    0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 1, 0, 0, 0,
    0, 0, 1, 1, 0, 0, 0,
    0, 1, 1, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0,
    ]
    """
    # NOTE(review): `x`, `y` and `board` are not defined anywhere in this file
    # as shown, so calling snake() raises NameError; `size` is also unused.
    new_x = x + 1
    new_y = y + 1
    # Selector marking the cell matched by the new head coordinates.
    head = (key(new_x) == query(indices)) & (key(indices) == query(new_y))
    new_board = board | head
    # Diagonal selector: each position selects only itself.
    eye = key(indices) == query(indices)

    def sel_row(where):
        # Selects key position `where` for every query.
        return key(indices) == query(where)

    def sel(x):
        # The token at position x, only at position x (0 elsewhere).
        return (eye & sel_row(x)).value(tokens)

    # Write the new head coordinates into positions 0 and 1.
    new_state = (
        (sel(0) * new_x)
        + (sel(1) * new_y)
    )
    return new_state + new_board
if __name__ == "__main__":
    # Sorting demo (disabled):
    # seq = [5, 0.3, -0.3, 10]
    # print(sort()(seq))

    # Move demo: position 0 holds the step (2); the cells after it are shifted.
    seq = [2, 0, 0, 1, 0, 0]
    print(move())  # the symbolic program tree
    print(move()(seq))  # the program evaluated on seq
# End of gist.