Skip to content

Instantly share code, notes, and snippets.

@kayuksel
Last active June 25, 2026 01:46
Show Gist options
  • Select an option

  • Save kayuksel/cbb304038471befead2c1926610e9ff4 to your computer and use it in GitHub Desktop.

Select an option

Save kayuksel/cbb304038471befead2c1926610e9ff4 to your computer and use it in GitHub Desktop.
EvoForest-WM: Discovered Neuro-Symbolic Foundational World Model for Multivariate Time-Series
import math
import torch
import torch.nn.functional as F
L = 64
def _g(seed):
return torch.Generator().manual_seed(seed)
class WMTransform:
    """The frozen TS-WM patch encoder phi: (B, L) -> (B, 141). Pure closed form over seeded banks.

    ``phi(w)`` (also reachable as ``self(w)``) evaluates 18 feature families on a
    batch of windows ``w`` (one window per row) and concatenates their columns in
    the order listed by ``families``. The only random content is a handful of
    projection banks drawn once in ``__init__`` from fixed integer seeds.

    Notes:
        * The total of 141 columns assumes the default window length ``L=64``:
          ``fftbands`` emits ``ceil((n // 2) / 6)`` bands (6 only for n in 62..73)
          and ``acf_first_min`` needs ``n >= 32`` for its 32 lags.
        * ``w`` must already have the bank ``dtype`` and live on ``device``;
          the convolution families pass it to ``F.conv1d`` together with ``crf_w``.
          Every helper tensor built inside a family follows ``w.dtype``.
        * ``self.L`` is stored for reference only; every family reads the window
          length from ``w.shape[1]``.
    """

    def __init__(self, L=L, device="cpu", dtype=torch.float32):
        """Draw the seeded projection banks and cast them to ``device``/``dtype``.

        Args:
            L: nominal window length; kept as ``self.L`` (not read by any family).
            device: device the banks are moved to; inputs to ``phi`` must be there too.
            dtype: floating dtype of the banks; inputs to ``phi`` must use it too.
        """
        self.L = L
        self.device = device
        # --- seeded random-projection banks (the only stochastic content; fixed integer seeds) ---
        # Every bank draws from its own freshly seeded generator, so no bank depends on
        # another bank's draw order.
        self.trf_mu = torch.rand(12, generator=_g(11))  # Gaussian-window centers in [0, 1)
        # window widths: exp(U[0, 1) * 2.59 - 3.51), log-uniform in roughly [0.030, 0.40)
        self.trf_sig = torch.exp(torch.rand(12, generator=_g(12)) * 2.59 - 3.51)
        # random one-hidden-layer MLP on the 12 stats: 12 outputs, each with 6 hidden units.
        # 3.4641 ~= sqrt(12) and 2.4495 ~= sqrt(6): the sizes of the axes each bank is summed over.
        self.srf_W = torch.randn(12, 6, 12, generator=_g(21)) / 3.4641
        self.srf_b = torch.randn(12, 6, generator=_g(22))
        self.srf_u = torch.randn(12, 6, generator=_g(23)) / 2.4495
        self.crf_w = torch.randn(16, 1, 9, generator=_g(31))  # 16 random conv kernels (size 9)
        for name in ("trf_mu", "trf_sig", "srf_W", "srf_b", "srf_u", "crf_w"):
            setattr(self, name, getattr(self, name).to(device=device, dtype=dtype))
        # (family name, column width) in the exact order phi() concatenates them; widths sum to 141
        self._families = [
            ("stats", 12), ("srf_mlp", 12), ("autocorr", 2), ("spectral", 2), ("turning", 2),
            ("trf_gausswin", 12), ("crf_ppv", 32), ("hilbert_env", 2), ("crf_max", 32),
            ("morphology_updown", 4), ("fftbands", 6), ("perm_entropy", 3), ("curvature", 3),
            ("conv_position", 3), ("ar_residual", 4), ("acf_first_min", 3), ("histogram_mode", 3),
            ("ricker_wavelet", 4),
        ]

    # ----- family 0: stats (12) — also consumed by srf_mlp -----
    def stats(self, w):
        """Family 0 (12 cols): mean, std, skew, excess kurtosis, three quartiles, IQR,
        lag-1 autocorrelation, mean |first difference|, min and max of each row."""
        m = w.mean(1)
        s = w.std(1).clamp(min=1e-8)  # unbiased (Bessel-corrected) std, floored for division
        c = w - m.unsqueeze(1)
        # One quantile call sorts each row once for all three quartiles (q must match w's dtype).
        q = torch.tensor([0.25, 0.5, 0.75], device=w.device, dtype=w.dtype)
        q25, q50, q75 = torch.quantile(w, q, dim=1)
        return torch.stack([
            m, s,
            (c ** 3).mean(1) / (s ** 3 + 1e-8),  # skew (population 3rd moment over unbiased std^3)
            (c ** 4).mean(1) / (s ** 4 + 1e-8) - 3.0,  # excess kurtosis
            q25, q50, q75,
            q75 - q25,  # IQR
            # lag-1 autocorrelation normalised by the norms of both overlapping segments
            (c[:, :-1] * c[:, 1:]).sum(1)
            / ((c[:, :-1] ** 2).sum(1).sqrt() * (c[:, 1:] ** 2).sum(1).sqrt() + 1e-8),
            (w[:, 1:] - w[:, :-1]).abs().mean(1), w.amin(1), w.amax(1),
        ], dim=1)

    # ----- family 1: random-MLP projection of the stats vector (12) -----
    def srf_mlp(self, w, st):
        """Family 1 (12 cols): fixed random ReLU MLP applied to the stats matrix ``st``.

        ``w`` is not used by this family.
        """
        # (B, 12 stats) x (12 outputs, 6 hidden, 12 stats) -> (B, 12 outputs, 6 hidden)
        hidden = torch.relu(torch.einsum("nm,kdm->nkd", st, self.srf_W) + self.srf_b.unsqueeze(0))
        return (hidden * self.srf_u.unsqueeze(0)).sum(2)

    # ----- family 2: autocorrelation at lags 1,2 (2) -----
    def autocorr(self, w):
        """Family 2 (2 cols): lag-1 and lag-2 autocorrelation of the mean-removed row,
        each normalised by the energy of the leading segment only."""
        c = w - w.mean(1, keepdim=True)
        return torch.stack([
            (c[:, :-1] * c[:, 1:]).sum(1) / ((c[:, :-1] ** 2).sum(1) + 1e-8),
            (c[:, :-2] * c[:, 2:]).sum(1) / ((c[:, :-2] ** 2).sum(1) + 1e-8),
        ], dim=1)

    # ----- family 3: spectral centroid + entropy (2) -----
    def spectral(self, w):
        """Family 3 (2 cols): centroid (in bins, counted from the first non-DC bin) and
        Shannon entropy of the normalised rFFT magnitude spectrum without its DC bin."""
        mag = torch.fft.rfft(w, dim=1).abs()[:, 1:]
        mag = mag / (mag.sum(1, keepdim=True) + 1e-8)
        f = torch.arange(mag.shape[1], device=w.device, dtype=mag.dtype)
        return torch.stack([(f[None, :] * mag).sum(1), -(mag * (mag + 1e-12).log()).sum(1)], dim=1)

    # ----- family 4: turning-point rate + fraction-positive-diff (2) -----
    def turning(self, w):
        """Family 4 (2 cols): fraction of sign changes between consecutive differences,
        and fraction of strictly positive differences."""
        dw = w[:, 1:] - w[:, :-1]
        turns = (dw[:, 1:] * dw[:, :-1] < 0).to(w.dtype).mean(1)
        rising = (dw > 0).to(w.dtype).mean(1)
        return torch.stack([turns, rising], dim=1)

    # ----- family 5: Gaussian-window weighted means at 12 time centers (12) -----
    def trf_gausswin(self, w):
        """Family 5 (12 cols): weighted means of each row under 12 normalised Gaussian
        windows centred at ``trf_mu`` with widths ``trf_sig`` (time scaled to [0, 1))."""
        n = w.shape[1]
        t = torch.arange(n, device=w.device, dtype=w.dtype) / n
        win = torch.exp(-((t[None, :] - self.trf_mu[:, None]) ** 2) / (2 * self.trf_sig[:, None] ** 2 + 1e-8))
        win = win / (win.sum(1, keepdim=True) + 1e-8)  # each window's weights sum to ~1
        return w @ win.transpose(0, 1)

    # ----- family 6: conv positive-fraction (PPV) at dilations 2,4 (32) -----
    def crf_ppv(self, w):
        """Family 6 (32 cols): fraction of positive responses of the 16 random kernels
        at dilation 2 (first 16 cols) and dilation 4 (last 16 cols)."""
        x = w.unsqueeze(1)  # (B, 1, n): single input channel
        half = self.crf_w.shape[2] // 2
        # padding = dilation * half keeps the output length equal to n
        feats = [(F.conv1d(x, self.crf_w, padding=d * half, dilation=d) > 0).to(w.dtype).mean(2)
                 for d in (2, 4)]
        return torch.cat(feats, dim=1)

    # ----- family 7: analytic-signal (Hilbert) envelope stats (2) -----
    def hilbert_env(self, w):
        """Family 7 (2 cols): coefficient of variation of the analytic-signal envelope and
        its mean absolute first difference relative to its mean."""
        n = w.shape[1]
        k = torch.arange(n, device=w.device, dtype=torch.float32)
        # FFT-bin weights: 1 at DC, 2 on positive frequencies, 1 at Nyquist (even n), 0 on negative
        h = torch.where(k == 0, 1.0, torch.where(k < n / 2, 2.0, torch.where(k == n // 2, 1.0, 0.0)))
        env = torch.fft.ifft(torch.fft.fft(w, dim=1) * h[None, :], dim=1).abs()
        env_mean = env.mean(1) + 1e-8
        return torch.stack([env.std(1) / env_mean,
                            (env[:, 1:] - env[:, :-1]).abs().mean(1) / env_mean], dim=1)

    # ----- family 8: conv max-pool at dilations 2,4 (32) -----
    def crf_max(self, w):
        """Family 8 (32 cols): max response over time of the 16 random kernels at
        dilation 2 (first 16 cols) and dilation 4 (last 16 cols)."""
        half = self.crf_w.shape[2] // 2
        return torch.cat([F.conv1d(w.unsqueeze(1), self.crf_w, padding=d * half, dilation=d).amax(2)
                          for d in (2, 4)], dim=1)

    # ----- family 9: up/down-stroke morphology asymmetry (4) -----
    def morphology_updown(self, w):
        """Family 9 (4 cols): largest rise, largest fall, log ratio of the two, and log
        ratio of total squared rise to total squared fall."""
        dw = w[:, 1:] - w[:, :-1]
        up = dw.clamp(min=0)
        down = (-dw).clamp(min=0)  # magnitudes of falls; down**2 == dw.clamp(max=0)**2
        up_max, down_max = up.amax(1), down.amax(1)
        return torch.stack([
            up_max, down_max,
            ((up_max + 1e-6) / (down_max + 1e-6)).log(),
            (((up ** 2).sum(1) + 1e-6) / ((down ** 2).sum(1) + 1e-6)).log(),
        ], dim=1)

    # ----- family 10: binned FFT log band power, 6 bands (6) -----
    def fftbands(self, w):
        """Family 10 (6 cols at n=64): log share of non-DC rFFT power in consecutive groups
        of 6 bins; the last group holds the remainder (2 bins at n=64).

        The number of groups is ceil((n // 2) / 6), which equals 6 only for n in 62..73.
        """
        power = torch.fft.rfft(w, dim=1).abs()[:, 1:] ** 2
        share = power / (power.sum(1, keepdim=True) + 1e-8)
        return torch.stack([band.sum(1) for band in share.split(6, dim=1)], dim=1).clamp(min=1e-8).log()

    # ----- family 11: permutation entropy (len-3) + Hjorth mobility/complexity (3) -----
    def perm_entropy(self, w):
        """Family 11 (3 cols): normalised order-3 permutation entropy, Hjorth mobility and
        Hjorth complexity."""
        a, b, c = w[:, :-2], w[:, 1:-1], w[:, 2:]
        # 3 pairwise comparisons -> 8 codes; codes 1 and 6 are logically impossible, so at most
        # 6 patterns occur and dividing by log(6) bounds the entropy by 1.
        code = (a < b).long() * 4 + (b < c).long() * 2 + (a < c).long()
        probs = torch.stack([(code == k).to(w.dtype).mean(1) for k in range(8)], dim=1)
        d1 = w[:, 1:] - w[:, :-1]
        d2 = w[:, 2:] - 2 * w[:, 1:-1] + w[:, :-2]
        mob = ((d1.var(1) + 1e-8) / (w.var(1) + 1e-8)).sqrt()
        comp = ((d2.var(1) + 1e-8) / (d1.var(1) + 1e-8)).sqrt() / (mob + 1e-8)
        return torch.stack([-(probs * (probs + 1e-12).log()).sum(1) / math.log(6.0), mob, comp], dim=1)

    # ----- family 12: local curvature via 2nd difference (3) -----
    def curvature(self, w):
        """Family 12 (3 cols): mean and max |second difference|, and log ratio of squared
        positive to squared negative curvature."""
        c = w[:, 2:] - 2 * w[:, 1:-1] + w[:, :-2]
        return torch.stack([c.abs().mean(1), c.abs().amax(1),
                            (((c.clamp(min=0) ** 2).sum(1) + 1e-6) / ((c.clamp(max=0) ** 2).sum(1) + 1e-6)).log()],
                           dim=1)

    # ----- family 13: conv max-response position summary (3) -----
    def conv_position(self, w):
        """Family 13 (3 cols): mean, std and range of the peak-response positions (scaled
        by n) of the 16 random kernels at dilation 2."""
        half = self.crf_w.shape[2] // 2
        resp = F.conv1d(w.unsqueeze(1), self.crf_w, padding=2 * half, dilation=2)  # (B, 16, n)
        pos = resp.argmax(2).to(w.dtype) / w.shape[1]
        return torch.stack([pos.mean(1), pos.std(1), pos.amax(1) - pos.amin(1)], dim=1)

    # ----- family 14: AR(2)+AR(3) ridge-fit residual energy + coeff norm (4) -----
    def ar_residual(self, w):
        """Family 14 (4 cols): for ridge-regularised AR(2) then AR(3) fits of the
        mean-removed row, log residual variance and log squared coefficient norm."""
        c = w - w.mean(1, keepdim=True)

        def fit(X, y):
            # Per-row normal equations (X^T X + 1e-3 I) beta = X^T y.
            A = X.transpose(1, 2) @ X + 1e-3 * torch.eye(X.shape[2], device=w.device, dtype=w.dtype)
            beta = torch.linalg.solve(A, X.transpose(1, 2) @ y.unsqueeze(2)).squeeze(2)
            resid = (y - (X @ beta.unsqueeze(2)).squeeze(2)).var(1)
            return torch.stack([(resid + 1e-8).log(), (beta ** 2).sum(1).clamp(min=1e-8).log()], 1)

        ar2 = fit(torch.stack([c[:, 1:-1], c[:, :-2]], 2), c[:, 2:])
        ar3 = fit(torch.stack([c[:, 2:-1], c[:, 1:-2], c[:, :-3]], 2), c[:, 3:])
        return torch.cat([ar2, ar3], 1)

    # ----- family 15: first-minimum / first-zero of the autocorrelation function (3) -----
    def acf_first_min(self, w):
        """Family 15 (3 cols) from the ACF at lags 1..32 (requires n >= 32).

        Columns: lag of the first local minimum / n (32 / n when none), first column
        index with ACF < 0 / n (0 when none), and the ACF value at the first minimum
        (ACF at lag 2 when none).
        """
        c = w - w.mean(1, keepdim=True)
        n = w.shape[1]
        max_lag = 32
        energy = (c ** 2).sum(1) + 1e-8
        # acf[:, j] is the autocorrelation at lag j + 1
        acf = torch.stack([(c[:, :n - k] * c[:, k:]).sum(1) / energy for k in range(1, max_lag + 1)], dim=1)
        # ismin[:, i] marks a local minimum at acf column i + 1, i.e. at lag i + 2
        ismin = (acf[:, 1:-1] < acf[:, :-2]) & (acf[:, 1:-1] <= acf[:, 2:])
        has = ismin.any(1)
        fm = ismin.float().argmax(1)  # first True index (0 when there is none)
        no_min = torch.tensor(float(acf.shape[1]), device=w.device, dtype=w.dtype)
        first_min = torch.where(has, fm.to(w.dtype) + 2.0, no_min) / n
        # Descending weights make argmax select the first column where acf < 0.
        # NOTE(review): this is a column index (lag - 1) unlike first_min, and rows without
        # a negative value also give 0 — kept as-is to preserve the frozen features.
        rank = torch.arange(acf.shape[1], 0, -1, device=w.device)
        first_zero = ((acf < 0).cumsum(1).clamp(max=1) * rank[None, :]).argmax(1).to(w.dtype) / n
        # NOTE(review): column fm + 1 is the first minimum itself, not the value after it.
        val_after = acf.gather(1, (fm + 1).clamp(max=acf.shape[1] - 1).unsqueeze(1)).squeeze(1)
        return torch.stack([first_min, first_zero, val_after], dim=1)

    # ----- family 16: distribution mode (soft) + mode-median + central mass (3) -----
    def histogram_mode(self, w):
        """Family 16 (3 cols) on the z-scored row: softmax-weighted mode over a 10-bin
        Gaussian-kernel histogram, mode minus median, and fraction with |z| < 0.5."""
        z = (w - w.mean(1, keepdim=True)) / w.std(1, keepdim=True).clamp(min=1e-6)
        ctr = torch.linspace(-2.5, 2.5, 10, device=w.device, dtype=w.dtype)  # bin centres, spacing 5/9
        bandwidth = 5.0 / 9.0
        kde = (-0.5 * ((ctr[None, None, :] - z.unsqueeze(2)) / bandwidth) ** 2).exp().sum(1)  # (B, 10)
        soft = (4.0 * kde).softmax(1)
        m10 = (soft * ctr[None, :]).sum(1)
        return torch.stack([m10, m10 - z.median(1).values, (z.abs() < 0.5).to(w.dtype).mean(1)], dim=1)

    # ----- family 17: Ricker (Mexican-hat) wavelet PPV at 4 scales (4) -----
    def ricker_wavelet(self, w):
        """Family 17 (4 cols): fraction of positive responses to 15-tap Ricker wavelets at
        scales 1.5, 2.5, 4.0 and 6.0."""
        taps = 15
        # Kernel built in w's dtype: a default-dtype kernel makes F.conv1d fail for float64 input.
        x = torch.linspace(-7.0, 7.0, taps, device=w.device, dtype=w.dtype)
        kernels = []
        for s in (1.5, 2.5, 4.0, 6.0):
            r = (1.0 - (x / s) ** 2) * torch.exp(-(x ** 2) / (2 * s * s))
            kernels.append((r - r.mean()) / (r.abs().sum() + 1e-8))  # zero-mean, scaled by L1 norm of r
        bank = torch.stack(kernels).unsqueeze(1)  # (4, 1, taps)
        return (F.conv1d(w.unsqueeze(1), bank, padding=taps // 2) > 0).to(w.dtype).mean(2)

    # ----- assemble phi in the discovered order -----
    def phi(self, w):
        """Return the feature matrix for windows ``w``: all families concatenated
        column-wise in ``families`` order ((B, 141) at n=64)."""
        st = self.stats(w)
        cols = [
            st,  # stats
            self.srf_mlp(w, st),  # srf_mlp reuses the stats instead of recomputing them
            self.autocorr(w), self.spectral(w), self.turning(w),
            self.trf_gausswin(w), self.crf_ppv(w), self.hilbert_env(w), self.crf_max(w),
            self.morphology_updown(w), self.fftbands(w), self.perm_entropy(w), self.curvature(w),
            self.conv_position(w), self.ar_residual(w), self.acf_first_min(w), self.histogram_mode(w),
            self.ricker_wavelet(w),
        ]
        return torch.cat([c if c.dim() == 2 else c.unsqueeze(1) for c in cols], dim=1)

    __call__ = phi

    @property
    def families(self):
        """Copy of the ordered (family name, column width) list."""
        return list(self._families)

    @property
    def n_cols(self):
        """Total number of output columns declared by ``families`` (141)."""
        return sum(c for _, c in self._families)
@kayuksel

Copy link
Copy Markdown
Author

# Pooling reductions to apply per WMTransform feature family. Multiplying each listed
# family's column width by its number of reductions gives 245 columns in total, matching
# the name. Families not listed (srf_mlp, autocorr, spectral, hilbert_env, perm_entropy,
# acf_first_min) get no entry.
# NOTE(review): the axis these reductions run over (e.g. across patches) is not shown here;
# confirm against the consumer.
POOL_MAP_245 = {
    "stats": ["std", "max"],              # 12 cols x 2 = 24
    "turning": ["std", "max"],            # 2 cols x 2 = 4
    "trf_gausswin": ["mean", "max"],      # 12 cols x 2 = 24
    "crf_ppv": ["mean", "std", "max"],    # 32 cols x 3 = 96
    "crf_max": ["mean", "max"],           # 32 cols x 2 = 64
    "morphology_updown": ["mean"],        # 4 cols x 1 = 4
    "fftbands": ["mean"],                 # 6 cols x 1 = 6
    "curvature": ["mean", "max"],         # 3 cols x 2 = 6
    "conv_position": ["max"],             # 3 cols x 1 = 3
    "ar_residual": ["mean"],              # 4 cols x 1 = 4
    "histogram_mode": ["std", "max"],     # 3 cols x 2 = 6
    "ricker_wavelet": ["std"],            # 4 cols x 1 = 4
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment