Created
February 1, 2025 02:32
-
-
Save profsucrose/698c203f45df8e8e4dbbffbcae9821d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from typing import Callable, Any, Literal, Type | |
| from enum import Enum | |
# A single sequence element.
El = int | float | str
# A sequence of elements (the data every sequence operator consumes and produces).
S = list[El]
# Boolean selection matrix as nested lists; apply_sel_inner builds it with one
# row per query position and one column per key position.
BoolMat = list[list[bool]]
def default_for_type(t: Type) -> "El":
    """Return the placeholder element for element type ``t``.

    Args:
        t: An element class (``int``, ``float`` or ``str``). A sample value is
            also tolerated and is mapped to its class first.

    Returns:
        ``0`` for ``int`` (and its subclasses), ``0.0`` for ``float`` and
        ``"_"`` for anything else (expected to be ``str``).
    """
    if not isinstance(t, type):
        # Tolerate being handed a sample value instead of its class.
        t = type(t)
    # `t` is a class, so it has to be tested with issubclass(); an
    # isinstance(t, int) test is always false for a type object and would send
    # every class to the string default.
    if issubclass(t, int):
        return 0
    if issubclass(t, float):
        return 0.0
    return "_"  # is str
def aggregate(mat: "BoolMat", vs: "S") -> "S":
    """Reduce ``vs`` once per row of ``mat`` by averaging the selected values.

    Args:
        mat: Boolean selection matrix; a truthy ``mat[i][j]`` means row ``i``
            selects ``vs[j]``.
        vs: The values being selected from.

    Returns:
        One entry per row of ``mat``: ``0`` when the row selects nothing, the
        selected value itself when it selects exactly one, otherwise the
        arithmetic mean of the selected values.

    Raises:
        AssertionError: If a row selects several values and the first element
            of ``vs`` is neither ``int`` nor ``float``.
    """
    # Only look at vs[0] when there is one, so an empty sequence does not raise
    # IndexError; with no values every row selects nothing and the type is
    # never consulted.
    v_type = type(vs[0]) if vs else None
    # The result is built row by row so its length follows `mat` (one output
    # per mask row) rather than being pre-sized from `vs`.
    aggregated_vs = []
    for mask in mat:
        masked = [v for m, v in zip(mask, vs) if m]
        if not masked:
            aggregated = 0
        elif len(masked) == 1:
            # A single selection is passed through untouched, which is what
            # lets non-numeric values (e.g. characters) be routed around.
            aggregated = masked[0]
        else:
            assert v_type in (float, int), f"Tried to sum type {v_type.__name__}"
            aggregated = sum(masked) / len(masked)
        aggregated_vs.append(aggregated)
    return aggregated_vs
| def apply_sel_inner(sel: "Sel", seq: S) -> BoolMat: | |
| match sel: | |
| case KQ(key, query, pred): | |
| ks = apply_sop(key.op, seq) | |
| qs = apply_sop(query.op, seq) | |
| return [[pred(k, q) for k in ks] for q in qs] | |
| case BinSel(lhs, rhs, pred): | |
| ls = lhs(seq).mat | |
| rs = rhs(seq).mat | |
| return [[pred(l, r) for l, r in zip(lrow, rrow)] for lrow, rrow in zip(ls, rs)] | |
def apply_sel(sel: "Sel", seq: S) -> "Selector":
    """Evaluate selector expression ``sel`` on ``seq`` and wrap the matrix in a ``Selector``."""
    matrix = apply_sel_inner(sel, seq)
    return Selector(matrix)
| def apply_sop(op: Literal["SOp"] | S | El, seq: S) -> S: | |
| match op: | |
| case Symbol(name): | |
| match name: | |
| case "tokens": | |
| return seq | |
| case "indices": | |
| return list(range(len(seq))) | |
| case "length": | |
| return [len(seq)] * len(seq) | |
| case _: | |
| raise f"Unexpected symbol: {name=}" | |
| case BinSeq(lhs, rhs, op): | |
| ls = apply_sop(lhs, seq) | |
| rs = apply_sop(rhs, seq) | |
| if isinstance(op, Arith): | |
| return [op(l, r) for l, r in zip(ls, rs)] | |
| else: # is BinPred | |
| return [int(op(l, r)) for l, r in zip(ls, rs)] | |
| case Un(seq, pred): | |
| return [pred(x) for x in seq] | |
| case UnMap(seq, fn): | |
| return [fn(x) for x in seq] | |
| case Selector(mat): | |
| return aggregate(mat, seq) | |
| case Select(sel, value): | |
| selector = apply_sel(sel, seq) | |
| vs = apply_sop(value, seq) | |
| return apply_sop(selector, vs) | |
| case SelectorWidth(sel): | |
| mat = apply_sel(sel, seq).mat | |
| return [sum(row) for row in mat] | |
| case _ if isinstance(op, (int, float)): | |
| return [op] * len(seq) | |
| case _ if isinstance(op, str): | |
| if len(op) == 1: | |
| return op * len(seq) | |
| else: | |
| assert len(op) == len(seq), f"lengths don't equal: {op=} vs. {seq=}" | |
| return op | |
| case _: | |
| assert False, f"Unexpected op: {op}" | |
class BinElOp(Enum):
    """Common base for the binary element-wise operator enums; declares no members itself."""
| class Arith(BinElOp): | |
| ADD = "+" | |
| SUB = "-" | |
| MUL = "*" | |
| DIV = "/" | |
| def __call__(self, x: El, y: El) -> El: | |
| if x_is_str := isinstance(x, str): | |
| x_str = x | |
| x = ord(x) | |
| if y_is_str := isinstance(y, str): | |
| y_str = y | |
| y = ord(y) | |
| match self: | |
| case Arith.ADD: | |
| z = x + y | |
| case Arith.SUB: | |
| z = x - y | |
| case Arith.MUL: | |
| z = x * y | |
| case Arith.DIV: | |
| z = x / y | |
| if x_is_str or y_is_str: | |
| return chr(z) | |
| else: | |
| return z | |
class BinPred(BinElOp):
    """Binary comparison / logical operators applied to a pair of elements."""

    EQ = "=="
    GT = ">"
    LT = "<"
    GE = ">="
    LE = "<="
    AND = "&"
    OR = "|"

    def __call__(self, x: El, y: El) -> bool:
        """Apply this operator to ``x`` and ``y``.

        The comparison members return the result of the corresponding Python
        comparison. ``AND`` / ``OR`` use Python's ``and`` / ``or``, so they
        return one of the operands rather than a strict ``bool``.
        """
        if self is BinPred.EQ:
            return x == y
        if self is BinPred.GT:
            return x > y
        if self is BinPred.LT:
            return x < y
        if self is BinPred.GE:
            return x >= y
        if self is BinPred.LE:
            return x <= y
        if self is BinPred.AND:
            return x and y
        if self is BinPred.OR:
            return x or y
def SOp_bin_met(op: "BinElOp") -> Callable[["SOp", "SOp"], "BinSeq"]:
    """Build a binary-operator method that wraps both operands in a ``BinSeq`` tagged with ``op``."""

    def method(self: "SOp", rhs: "SOp") -> "BinSeq":
        # `op` is captured from the enclosing call; nothing is evaluated here,
        # only the expression node is created.
        return BinSeq(lhs=self, rhs=rhs, op=op)

    return method
@dataclass
class SOp:
    """Base class for sequence-operator expression nodes.

    Calling an instance evaluates it on a sequence via ``apply_sop``. The
    arithmetic, comparison and ``&`` / ``|`` operators do not compute anything
    themselves: each one returns a ``BinSeq`` node combining the two operands.
    """

    def __call__(self, seq: S) -> S:
        """Evaluate this expression on ``seq``."""
        return apply_sop(self, seq)

    __add__ = SOp_bin_met(Arith.ADD)
    __sub__ = SOp_bin_met(Arith.SUB)
    __mul__ = SOp_bin_met(Arith.MUL)
    # Python 3 dispatches `/` to __truediv__; without it `a / b` raises
    # TypeError. __div__ is kept for callers that invoke it by name.
    __truediv__ = SOp_bin_met(Arith.DIV)
    __div__ = SOp_bin_met(Arith.DIV)
    __eq__ = SOp_bin_met(BinPred.EQ)
    __gt__ = SOp_bin_met(BinPred.GT)
    __lt__ = SOp_bin_met(BinPred.LT)
    __ge__ = SOp_bin_met(BinPred.GE)
    __le__ = SOp_bin_met(BinPred.LE)
    __and__ = SOp_bin_met(BinPred.AND)
    __or__ = SOp_bin_met(BinPred.OR)
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp (so `tokens == "x"` would yield a plain bool instead of a BinSeq).
@dataclass(eq=False)
class Symbol(SOp):
    """Named built-in sequence; ``apply_sop`` recognises "tokens", "indices" and "length"."""

    # Name looked up by apply_sop when the symbol is evaluated.
    name: str

    def __repr__(self):
        return self.name
# Built-in symbols, resolved by name inside apply_sop:
#   tokens  -> the input sequence itself
#   indices -> 0 .. len(seq) - 1
#   length  -> len(seq) repeated at every position
tokens = Symbol("tokens")
indices = Symbol("indices")
length = Symbol("length")
| class UnPred(Enum): | |
| INVERT = "~" | |
| def __call__(self, x: bool) -> bool: | |
| match self: | |
| case PredUn.INVERT: | |
| return not x | |
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp.
@dataclass(eq=False)
class BinSeq(SOp):
    """Element-wise combination of two sequence expressions with a binary operator."""

    # Left operand.
    lhs: SOp
    # Right operand.
    rhs: SOp
    # Operator applied to each pair of elements; apply_sop casts the result to
    # int when this is not an Arith member.
    op: BinElOp
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp.
@dataclass(eq=False)
class Un(SOp):
    """Unary predicate mapped over a sequence expression."""

    # Operand the predicate is mapped over.
    seq: SOp
    # Predicate applied to every element.
    pred: UnPred
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp.
@dataclass(eq=False)
class UnMap(SOp):
    """Arbitrary Python function mapped over a sequence expression.

    Note the field order: the function comes first, then the operand.
    """

    # Function applied to every element.
    fn: Callable[[El], El]
    # Operand the function is mapped over.
    seq: SOp


# Lower-case alias used when writing programs.
smap = UnMap
@dataclass
class Selector(SOp):
    """Evaluated selector: a concrete boolean matrix.

    ``apply_sop`` evaluates a ``Selector`` by aggregating the input sequence
    with ``mat``.
    """

    # The boolean selection matrix.
    mat: BoolMat

    def __repr__(self):
        # Render every cell as 1/0, one bracketed row per line.
        rendered_rows = []
        for row in self.mat:
            cells = ", ".join("1" if x else "0" for x in row)
            rendered_rows.append(f"[{cells}]")
        return "[" + ",\n ".join(rendered_rows) + "]"
@dataclass
class Sel:
    """Base class for selector expressions.

    Calling an instance evaluates it on a sequence via ``apply_sel``; ``&`` and
    ``|`` combine two selectors into a ``BinSel``.
    """

    def __call__(self, seq: S) -> "Selector":
        """Evaluate this selector on ``seq``."""
        return apply_sel(self, seq)

    def __invert__(self) -> Un:
        # NOTE(review): this wraps the selector in `Un`, a sequence-operator
        # node; apply_sel_inner has no case for `Un`, so confirm how an
        # inverted selector is meant to be evaluated.
        return Un(self, UnPred.INVERT)

    def value(self, op: SOp) -> "Select":
        """Pair this selector with the values ``op`` to aggregate."""
        return Select(self, op)

    def __or__(self, rhs: "Sel") -> "BinSel":
        return BinSel(lhs=self, rhs=rhs, pred=BinPred.OR)

    def __and__(self, rhs: "Sel") -> "BinSel":
        return BinSel(lhs=self, rhs=rhs, pred=BinPred.AND)
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp.
@dataclass(eq=False)
class SelectorWidth(SOp):
    """Per-row count of selected cells of a selector (``apply_sop`` sums each matrix row)."""

    # Selector whose rows are counted.
    sel: Sel


# Lower-case alias used when writing programs.
selector_width = SelectorWidth
@dataclass
class KQ(Sel):
    """Selector comparing every key value with every query value.

    Instances are built by ``KeyQuery.bind``, which passes the ``Key`` /
    ``Query`` wrappers themselves; ``apply_sel_inner`` then reads their ``.op``
    to get the sequence expressions. The annotations below reflect that.
    """

    # Wrapper around the key-side expression.
    key: "Key"
    # Wrapper around the query-side expression.
    query: "Query"
    # Predicate called as pred(key_value, query_value) for every cell.
    pred: BinPred
def KeyQuery_bin_met(op: BinPred) -> Callable[["KeyQuery", "KeyQuery"], "KQ"]:
    """Build an operator method that binds a key and a query into a ``KQ`` using ``op``."""

    def method(self: "KeyQuery", other: "KeyQuery") -> "KQ":
        # Delegates to bind(), which sorts out which operand is the key.
        bound = self.bind(other, pred=op)
        return bound

    return method
@dataclass
class KeyQuery:
    """Wrapper marking a sequence expression as the key or query side of a selector.

    Comparing a ``Key`` with a ``Query`` (in either order) through one of the
    operators below yields a ``KQ`` selector.
    """

    # The wrapped sequence expression.
    op: SOp

    def bind(self, other: "KeyQuery", pred: BinPred) -> "KQ":
        """Combine ``self`` and ``other`` into a ``KQ``, key first.

        Raises:
            AssertionError: If the two operands are not one ``Key`` and one
                ``Query``.
        """
        if not isinstance(self, Key):
            # self is Query
            assert isinstance(other, Key), f"Can only attend Query with Key, got {other}"
            return KQ(other, self, pred)
        assert isinstance(other, Query), f"Can only attend Key with Query, got {other}"
        return KQ(self, other, pred)

    __eq__ = KeyQuery_bin_met(BinPred.EQ)
    __gt__ = KeyQuery_bin_met(BinPred.GT)
    __lt__ = KeyQuery_bin_met(BinPred.LT)
    __ge__ = KeyQuery_bin_met(BinPred.GE)
    __le__ = KeyQuery_bin_met(BinPred.LE)
    __and__ = KeyQuery_bin_met(BinPred.AND)
    __or__ = KeyQuery_bin_met(BinPred.OR)
class Key(KeyQuery):
    """Marks the wrapped expression as the key side of a selector."""


# Lower-case alias used when writing programs.
key = Key
class Query(KeyQuery):
    """Marks the wrapped expression as the query side of a selector."""


# Lower-case alias used when writing programs.
query = Query
@dataclass
class BinSel(Sel):
    """Selector formed by combining two selectors cell by cell with a binary predicate."""
    # Left-hand selector operand.
    lhs: Sel
    # Right-hand selector operand.
    rhs: Sel
    # Predicate applied to each pair of corresponding cells (see apply_sel_inner).
    pred: BinPred
@dataclass
class UnSel(Sel):
    """Selector node carrying a unary predicate.

    NOTE(review): nothing in the visible source constructs or evaluates this
    class, and it declares two operands (``lhs``, ``rhs``) even though its
    predicate is unary -- confirm the intended fields before relying on it.
    """
    lhs: Sel
    rhs: Sel
    # Unary predicate associated with this node.
    pred: UnPred
# eq=False: by default @dataclass generates a field-wise __eq__ on this
# subclass, which would shadow the element-wise `==` operator inherited from
# SOp.
@dataclass(eq=False)
class Select(SOp):
    """A selector paired with the values it aggregates.

    ``apply_sop`` evaluates both on the input sequence and then aggregates the
    values with the resulting selection matrix.
    """

    # Selector deciding which positions each output position reads from.
    sel: Sel
    # Expression producing the values that get aggregated.
    value: SOp
| """ | |
| Select( | |
| selector=BinSel( | |
| lhs=KQ( | |
| key=indices, | |
| query=indices, | |
| ), | |
| rhs=KQ( | |
| key=indices, | |
| query=indices, | |
| ), | |
| op=BinOp.OR, | |
| ), | |
| value=BinSeq( | |
| lhs=tokens, | |
| rhs="x", | |
| op=BinOp.EQ | |
| ), | |
| ) | |
| Attention( | |
| key=indices, | |
| query=indices, | |
| value=Bin( | |
| lhs=tokens, | |
| rhs=Constant("x"), | |
| op=BinOp.EQ | |
| ), | |
| ) | |
| """ | |
def sort():
    """Build the expression for a sort of ``tokens``.

    For every position it counts the keys whose token is smaller than the
    query's token, then routes each token to the position equal to that count.
    """
    smaller_count = selector_width(key(tokens) < query(tokens))
    placement = key(smaller_count) == query(indices)
    return placement.value(tokens)
def move():
    """Build the expression that shifts the tail of ``tokens`` by the offset stored at index 0.

    Positions with index <= 0 are kept in place; the remaining positions are
    shifted by the offset and the two parts are merged with ``|``.
    """
    # [0, 0, 0, 1, 0, 0]
    # -> [0, 0, 0, 0, 1, 0]

    # Every query attends to key index 0, i.e. broadcasts tokens[0].
    step = (key(indices) == query(0)).value(tokens)

    # Identity attention restricted to the leading / trailing part.
    leading_sel = (key(indices) == query(indices)) & (key(indices) <= query(0))
    trailing_sel = (key(indices) == query(indices)) & (key(indices) > query(0))
    before_state = leading_sel.value(tokens)
    state = trailing_sel.value(tokens)

    # Query q reads the key k with k + step == q.
    shift_sel = key(indices + step) == query(indices)
    sh_state = shift_sel.value(state)
    return before_state | sh_state
# TODO
def snake(size: int):
    """
    Unfinished sketch (see the TODO above) of a state update expressed with
    selectors.

    Layout sketch kept from the original notes:

    [ <x>, <y>, <dx>, <dy>, <score>,
      0, 0,
      0, 0,
      0, 0,
      0, 0,
      1, 2,
      3, 2,
      0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 1, 0, 0, 0,
      0, 0, 1, 1, 0, 0, 0,
      0, 1, 1, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0,
    ]
    """
    # NOTE(review): `x`, `y` and `board` are not defined anywhere in the
    # visible source and `size` is never used, so calling this function as
    # written raises NameError. It is not called in the visible code.
    new_x = x + 1
    new_y = y + 1
    # Selector that is true where the key's new_x equals the query index and
    # the key index equals the query's new_y.
    head = (key(new_x) == query(indices)) & (key(indices) == query(new_y))
    new_board = board | head
    # Identity selector: each position attends only to itself.
    eye = key(indices) == query(indices)
    def sel_row(where):
        # Selector matching keys whose index equals `where`.
        return key(indices) == query(where)
    def sel(x):
        # Tokens kept only where the identity selector and sel_row(x) agree.
        return (eye & sel_row(x)).value(tokens)
    new_state = (
        (sel(0) * new_x)
        + (sel(1) * new_y)
    )
    return new_state + new_board
# seq = [5, 0.3, -0.3, 10]
# print(sort()(seq))
# Demo input for move(); element 0 is the value move() broadcasts as its offset.
seq = [2, 0, 0, 1, 0, 0]
# First show the expression tree move() builds, then its result on `seq`.
print(move())
print(move()(seq))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment