# Created July 3, 2025 20:40
# Save gist cheery/c30d30bc9b63f7c84f56e4cf0f3f86ce to your computer and use it in GitHub Desktop.
# Rhythm trees
# This file may contain hidden or bidirectional Unicode text that could be interpreted
# or compiled differently than it appears below. To review, open the file in an editor
# that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from collections import defaultdict | |
from dataclasses import dataclass, field | |
from typing import List, Dict, Optional, Callable, Tuple, Any | |
from fractions import Fraction | |
import itertools | |
class RT:
    """Common base class for rhythm-tree nodes.

    ``repr`` is routed through ``str`` so that subclasses only need to
    define ``__str__`` to get a readable representation in both places.
    """

    def __repr__(self):
        # format(self, "") falls back to str(self) for classes that do not
        # override __format__.
        return f"{self}"
class Tree(RT):
    """An internal rhythm-tree node: an ordered list of child nodes.

    The children live in ``self.children``; the sequence-protocol methods
    below (iteration, ``len``, indexing, item assignment) forward to it.
    """

    def __init__(self, children=None):
        # A fresh list per instance when no children are supplied.
        if children is None:
            children = []
        self.children = children

    def __iter__(self):
        return iter(self.children)

    def __len__(self):
        return len(self.children)

    def __getitem__(self, index):
        return self.children[index]

    def __setitem__(self, index, value):
        self.children[index] = value

    def copy(self):
        """Return a new Tree whose children are each child's ``copy()``.

        Nested Trees are duplicated recursively; what happens to other nodes
        is up to their own ``copy`` method.
        """
        return Tree([child.copy() for child in self.children])

    def __str__(self):
        # Children joined by commas inside parentheses, e.g. "(n,s,(n,n))".
        body = ",".join(map(str, self.children))
        return "(" + body + ")"
@dataclass(eq=False)
class Symbol(RT):
    """A named leaf of a rhythm tree.

    ``eq=False`` keeps object identity as equality (and the default
    identity hash), so each Symbol instance acts as a distinct token that
    can be compared with ``is``.
    """

    name: str

    def __str__(self):
        return self.name

    # Defined in the class body, so @dataclass does not generate its own.
    __repr__ = __str__

    def copy(self):
        # Symbols are shared rather than duplicated when a Tree is copied.
        return self
# The two leaf kinds used by val(): `n` starts a new note, while `s`
# extends (ties into) the note that precedes it.
n, s = Symbol("n"), Symbol("s")
def val(rt, duration):
    """Flatten a rhythm tree into the list of durations of its notes.

    ``duration`` is divided equally among the children at every level of a
    Tree. Leaves are read left to right: ``n`` starts a new note; ``s``
    extends the most recent note, or starts the first note if none has been
    started yet. Any other leaf (e.g. an unexpanded Nonterminal) is ignored.

    Args:
        rt: A Tree (or any sized iterable of nodes), or a single ``n``/``s``
            Symbol standing for the whole span.
        duration: Total duration covered by ``rt``. Use ``Fraction`` for
            exact results.

    Returns:
        A list with one duration per note, in order.

    Raises:
        ZeroDivisionError: if ``rt`` or a nested Tree has no children.
    """
    output = []

    def emit(leaf, length):
        # A leading `s` has no earlier note to extend, so it starts one.
        if leaf is n or (not output and leaf is s):
            output.append(length)
        elif leaf is s:
            output[-1] += length

    def visit(node, length):
        length = length / len(node)
        for child in node:
            if isinstance(child, Tree):
                visit(child, length)
            else:
                emit(child, length)

    if isinstance(rt, Symbol):
        # k_best can return a bare leaf (e.g. from a `(w, n)` production).
        # Such a leaf covers the whole duration; previously len(rt) raised
        # TypeError here.
        emit(rt, duration)
    else:
        visit(rt, duration)
    return output
class Nonterminal:
    """A grammar state with a list of weighted productions.

    ``prod`` holds ``(weight, rhs)`` pairs. Each ``rhs`` is either a Symbol
    leaf or a Tree whose children may be further Nonterminals.
    """

    def __init__(self, name, prod=None):
        self.name = name
        if prod is None:
            prod = []
        self.prod = prod

    def copy(self):
        # Nonterminals are shared by identity, never duplicated.
        return self

    def __repr__(self):
        return "%s" % (self.name,)
class Exhausted(Exception):
    """Raised when a nonterminal has no derivation at the requested rank."""
def k_best(k, q):
    """Return up to ``k`` lowest-weight derivations of nonterminal ``q``.

    A derivation's weight is the weight of its production plus the weights
    of the sub-derivations chosen for the production's child nonterminals.
    Derivations are found lazily per nonterminal. Each step pops the cheapest
    pending candidate and pushes its successors, obtained by moving one child
    to its next-best derivation (cf. the lazy k-best algorithm of Huang &
    Chiang, 2005).

    Args:
        k: Maximum number of derivations to return.
        q: Start Nonterminal. Its ``prod`` is a list of ``(weight, rhs)``
            pairs whose ``rhs`` is a Symbol or a Tree of Symbols and
            Nonterminals.

    Returns:
        A list of at most ``k`` ``(weight, tree)`` pairs, in the order found
        (ascending weight when production weights are non-negative). In each
        ``tree`` every Nonterminal is replaced by its chosen sub-derivation.
        Fewer than ``k`` pairs are returned when ``q`` runs out of
        derivations.

    NOTE(review): assumes the grammar reachable from ``q`` is acyclic. A
    nonterminal that can derive itself makes ``initial``/``best`` recurse
    without bound.
    """
    bests = defaultdict(list)  # nonterminal -> [(w, run, rhs)] found so far
    cands = defaultdict(list)  # nonterminal -> pending (w|None, W, run, rhs)
    # nonterminal -> keys of every successor ever generated. Without this,
    # the same index vector is reached from several parents (e.g. (1,1) from
    # both (1,0) and (0,1)) and k-best would yield duplicate derivations.
    seen = defaultdict(set)

    def init_run(rt, init):
        # Collect (child nonterminal, rank 0) for every Nonterminal in rt,
        # in left-to-right order, initialising each child on the way.
        output = []

        def visit(rt):
            for x in rt:
                if isinstance(x, Tree):
                    visit(x)
                elif isinstance(x, Nonterminal):
                    output.append((init(x), 0))

        visit(rt)
        return output

    def initial(q):
        # Seed cands[q] with one candidate per production. Symbol leaves
        # have a known weight; Tree weights are computed lazily (None).
        if q in cands:
            return q
        for w, x in q.prod:
            if isinstance(x, Symbol):
                cands[q].append((w, w, (), x))
            elif isinstance(x, Tree):
                cands[q].append((None, w, init_run(x, initial), x))
        return q

    initial(q)

    def best(k, q):
        # Extend bests[q] until it has a k-th entry or candidates run out.
        while cands[q] and k >= len(bests[q]):
            cand = []
            for w, W, run, x in cands[q]:
                if w is None:
                    try:
                        w = W + sum(best(i, r)[0] for r, i in run)
                    except Exhausted:
                        # A child has no derivation at this rank, hence none
                        # at higher ranks either: drop the candidate.
                        continue
                cand.append((w, W, run, x))
            cand.sort(key=lambda c: c[0])
            if cand:
                w, W, run, x = cand.pop(0)
                bests[q].append((w, run, x))
                for j in range(len(run)):
                    succ = [(r, i + (j == j1)) for j1, (r, i) in enumerate(run)]
                    key = (id(x), W, tuple(i for _, i in succ))
                    if key not in seen[q]:
                        seen[q].add(key)
                        cand.append((None, W, succ, x))
            cands[q] = cand
        if k < len(bests[q]):
            return bests[q][k]
        raise Exhausted

    def rewrite(rt, pattern):
        # Replace each Nonterminal in rt, left to right, by the derivation
        # selected in `pattern` (an iterator over (nonterminal, rank)).
        if isinstance(rt, Symbol):
            return rt
        if isinstance(rt, Tree):
            return Tree([rewrite(x, pattern) for x in rt])
        if isinstance(rt, Nonterminal):
            r, i = next(pattern)
            _, run, x = best(i, r)
            return rewrite(x, iter(run))

    out = []
    for i in range(k):
        try:
            w, run, x = best(i, q)
        except Exhausted:
            break
        out.append((w, rewrite(x, iter(run))))
    return out
def equivalent_rhythms(nt, ioi):
    """Build a grammar of the rhythm trees of ``nt`` that realise ``ioi``.

    Each generated Nonterminal is named ``(head, q, t)``:
      * ``q`` is a nonterminal of the source grammar;
      * ``t`` is the tuple of durations its subtree must cover;
      * ``head`` is ``n`` when the span starts on a new note, or ``s`` when
        it continues a note begun in an earlier span.
    Every generated production keeps the weight of the source production
    it came from. ``k_best`` on the result therefore ranks the matching
    trees. Presumably ``val(tree, sum(ioi))`` reproduces ``ioi`` (compare
    the demo at the bottom of the file); TODO confirm.

    Args:
        nt: Start Nonterminal of the source grammar.
        ioi: Sequence of inter-onset intervals (numbers accepted by
            ``Fraction``).

    Returns:
        The start Nonterminal of the derived grammar.

    NOTE(review): assumes every child of a source production Tree is a
    Nonterminal, since ``produce`` reads ``.prod`` from it.
    """
    tbl = {}  # memo: (head, q, t) -> generated Nonterminal

    def produce(name):
        if name in tbl:
            return tbl[name]
        # Register before expanding so repeated states share one object.
        tbl[name] = state = Nonterminal(name)
        head, q, t = name
        for w, x in q.prod:
            if x is head and len(t) == 1:
                # A single duration realised directly by the matching leaf.
                state.prod.append((w, x))
            elif isinstance(x, Tree) and (p := partition(head, x, t)):
                # A Tree with no children is falsy (Tree.__len__).
                state.prod.append((w, p))
        return state

    def partition(head, q, t):
        # Split the durations `t` into len(q) consecutive spans of equal
        # length, one per child of the Tree `q`. A duration is cut in two
        # where a span boundary falls inside it.
        items = iter(q)
        target = sum(t) / len(q)
        current = 0
        part = []
        parts = []
        for d in t:
            current += d
            part.append(d)
            while current > target:
                # Boundary inside `d`: keep what fits and carry the overflow
                # `p` into the next span as a continuation (`s`).
                p = current - target
                part[-1] -= p
                parts.append(produce((head, next(items), tuple(part))))
                head = s
                current = p
                part = [p]
            if current == target:
                # Span closes exactly on an onset; the next one starts a note.
                parts.append(produce((head, next(items), tuple(part))))
                head = n
                current = 0
                part = []
        return Tree(parts)

    # Work in exact rational arithmetic. partition() relies on `==` to
    # detect span boundaries, and float values (or the float produced by
    # int true-division) can miss them. A missed boundary silently drops
    # material or exhausts `items` (StopIteration).
    return produce((n, nt, tuple(Fraction(d) for d in ioi)))
# Example grammar. Each nonterminal lists (weight, rhs) productions;
# k_best prefers the lowest total weight. Leaves: `n` starts a note and
# `s` continues the previous one. Tree([...]) divides the span equally.
q1, q2, q3, q4, q5 = (Nonterminal(f"q{i}") for i in range(1, 6))

q1.prod += [
    (0.1, n),
    (0.35, Tree([q2, q2])),
    (0.45, Tree([q3, q3, q3])),
]
q2.prod += [
    (0.2, s),
    (0.1, n),
    (0.5, Tree([q4, q4])),
    (0.6, Tree([q5, q5, q5])),
]
q3.prod += [
    (0.2, s),
    (0.1, n),
    (0.5, Tree([q5, q5])),
]
q4.prod += [
    (0.2, s),
    (0.1, n),
    (0.75, Tree([q5, q5, q5])),
]
q5.prod += [
    (0.2, s),
    (0.1, n),
]

# Inter-onset intervals to match; their total length is 2.
f = Fraction(1)
r1 = equivalent_rhythms(q1, [f / 2, f / 3, f / 3, f / 3, f / 2])

# Print the 10 lowest-weight matching trees with their durations.
for w, row in k_best(10, r1):
    print(w, row, val(row, Fraction(2)))
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.