Last active
June 25, 2026 01:46
-
-
Save kayuksel/cbb304038471befead2c1926610e9ff4 to your computer and use it in GitHub Desktop.
EvoForest-WM: Discovered Neuro-Symbolic Foundational World Model for Multivariate Time-Series
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import torch | |
| import torch.nn.functional as F | |
# Default window (patch) length; used as the default ``L`` of WMTransform.
L = 64


def _g(seed):
    """Return a fresh ``torch.Generator`` whose state is seeded with ``seed``.

    A new generator is created on every call, so two calls with the same
    seed yield independent generators that produce the same random stream.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
class WMTransform:
    """Frozen TS-WM patch encoder ``phi``: ``(B, L) -> (B, 141)``.

    Every feature is a closed-form function of the input window and of a few
    random banks drawn once in ``__init__`` from fixed integer seeds, so two
    instances built with the same arguments produce identical outputs.

    The 141 columns are the concatenation of 18 feature families (see
    ``families``).  That width holds for the default ``L = 64``: the
    ``fftbands`` family splits the spectrum into chunks of 6 bins and
    ``acf_first_min`` always evaluates lags 1..32, so other window lengths
    may change the column count or fail.

    Input windows are expected to live on the same device and use the same
    dtype as the banks (the ``device`` / ``dtype`` constructor arguments);
    every intermediate tensor follows the input dtype.
    """

    def __init__(self, L=L, device="cpu", dtype=torch.float32):
        """Draw the seeded banks and move them to ``device`` / ``dtype``.

        Args:
            L: nominal window length (stored only; the feature methods read
                the length from the input tensor).
            device: device the banks are moved to.
            dtype: floating dtype the banks are cast to.
        """
        self.L = L
        self.device = device
        self.dtype = dtype
        # --- seeded random-projection banks (the only stochastic content; fixed integer seeds) ---
        # Banks are always sampled on CPU in the default dtype and only then
        # cast, so the drawn values do not depend on ``device``.
        self.trf_mu = torch.rand(12, generator=_g(11))  # Gaussian-window centers in [0, 1)
        # exp(U * 2.59 - 3.51): window widths between exp(-3.51) and exp(-0.92)
        self.trf_sig = torch.exp(torch.rand(12, generator=_g(12)) * 2.59 - 3.51)
        self.srf_W = torch.randn(12, 6, 12, generator=_g(21)) / 3.4641  # 3.4641 ~ sqrt(12)
        self.srf_b = torch.randn(12, 6, generator=_g(22))
        self.srf_u = torch.randn(12, 6, generator=_g(23)) / 2.4495  # 2.4495 ~ sqrt(6)
        self.crf_w = torch.randn(16, 1, 9, generator=_g(31))  # 16 random conv kernels (size 9)
        for k in ("trf_mu", "trf_sig", "srf_W", "srf_b", "srf_u", "crf_w"):
            setattr(self, k, getattr(self, k).to(device=device, dtype=dtype))
        # (name, column width) of the 18 families, in output order
        self._families = [
            ("stats", 12), ("srf_mlp", 12), ("autocorr", 2), ("spectral", 2), ("turning", 2),
            ("trf_gausswin", 12), ("crf_ppv", 32), ("hilbert_env", 2), ("crf_max", 32),
            ("morphology_updown", 4), ("fftbands", 6), ("perm_entropy", 3), ("curvature", 3),
            ("conv_position", 3), ("ar_residual", 4), ("acf_first_min", 3), ("histogram_mode", 3),
            ("ricker_wavelet", 4),
        ]

    # ----- shared helper: dilated random-kernel convolution -----
    def _crf_conv(self, w, dilation):
        """Convolve ``w`` (B, n) with the 16 random kernels -> (B, 16, n).

        Padding is ``dilation * (kernel_size // 2)``, which keeps the output
        length equal to the input length for the odd kernel size used here.
        """
        half = self.crf_w.shape[2] // 2
        return F.conv1d(w.unsqueeze(1), self.crf_w, padding=dilation * half, dilation=dilation)

    # ----- family 0: stats (12) — also consumed by srf_mlp -----
    def stats(self, w):
        """Twelve summary statistics per window -> (B, 12).

        Order: mean, std, skew, excess kurtosis, q25, median, q75, IQR,
        lag-1 correlation, mean absolute first difference, min, max.
        """
        mean = w.mean(1)
        std = w.std(1).clamp(min=1e-8)
        c = w - w.mean(1, keepdim=True)
        # Each quantile is computed once and reused for the IQR.
        q25 = torch.quantile(w, 0.25, dim=1)
        q50 = torch.quantile(w, 0.5, dim=1)
        q75 = torch.quantile(w, 0.75, dim=1)
        head, tail = c[:, :-1], c[:, 1:]
        lag1 = (head * tail).sum(1) / ((head ** 2).sum(1).sqrt() * (tail ** 2).sum(1).sqrt() + 1e-8)
        return torch.stack([
            mean, std,
            (c ** 3).mean(1) / (std ** 3 + 1e-8),        # skew
            (c ** 4).mean(1) / (std ** 4 + 1e-8) - 3.0,  # excess kurtosis
            q25, q50, q75,
            q75 - q25,                                   # IQR
            lag1,
            (w[:, 1:] - w[:, :-1]).abs().mean(1), w.amin(1), w.amax(1),
        ], dim=1)

    # ----- family 1: random-MLP projection of the stats vector (12) -----
    def srf_mlp(self, w, st):
        """Twelve random one-hidden-layer (6 ReLU units) readouts of ``st`` -> (B, 12).

        ``w`` is unused; it is kept so the method signature stays unchanged.
        """
        h = torch.relu(torch.einsum("nm,kdm->nkd", st, self.srf_W) + self.srf_b.unsqueeze(0))
        return (h * self.srf_u.unsqueeze(0)).sum(2)

    # ----- family 2: autocorrelation at lags 1,2 (2) -----
    def autocorr(self, w):
        """Lag-1 and lag-2 autocorrelation of the centered window -> (B, 2).

        Each lag is normalised by the energy of its leading slice only.
        """
        c = w - w.mean(1, keepdim=True)
        return torch.stack([
            (c[:, :-1] * c[:, 1:]).sum(1) / ((c[:, :-1] ** 2).sum(1) + 1e-8),
            (c[:, :-2] * c[:, 2:]).sum(1) / ((c[:, :-2] ** 2).sum(1) + 1e-8),
        ], dim=1)

    # ----- family 3: spectral centroid + entropy (2) -----
    def spectral(self, w):
        """Centroid and entropy of the normalised rFFT magnitudes -> (B, 2).

        The DC bin is dropped and the centroid index starts at 0 for the
        first remaining bin.
        """
        p = torch.fft.rfft(w, dim=1).abs()[:, 1:]
        p = p / (p.sum(1, keepdim=True) + 1e-8)
        f = torch.arange(p.shape[1], device=w.device, dtype=p.dtype)
        return torch.stack([(f[None, :] * p).sum(1), -(p * (p + 1e-12).log()).sum(1)], dim=1)

    # ----- family 4: turning-point rate + fraction-positive-diff (2) -----
    def turning(self, w):
        """Rate of sign changes between consecutive differences, and the
        fraction of strictly positive differences -> (B, 2)."""
        dw = w[:, 1:] - w[:, :-1]
        flips = (dw[:, 1:] * dw[:, :-1] < 0).to(w.dtype).mean(1)
        rising = (dw > 0).to(w.dtype).mean(1)
        return torch.stack([flips, rising], dim=1)

    # ----- family 5: Gaussian-window weighted means at 12 time centers (12) -----
    def trf_gausswin(self, w):
        """Weighted means of ``w`` under 12 normalised Gaussian time windows -> (B, 12)."""
        t = torch.arange(w.shape[1], device=w.device, dtype=w.dtype) / w.shape[1]
        win = torch.exp(-((t[None, :] - self.trf_mu[:, None]) ** 2) / (2 * self.trf_sig[:, None] ** 2 + 1e-8))
        return w @ (win / (win.sum(1, keepdim=True) + 1e-8)).transpose(0, 1)

    # ----- family 6: conv positive-fraction (PPV) at dilations 2,4 (32) -----
    def crf_ppv(self, w):
        """Fraction of positive conv responses per kernel, dilations 2 and 4 -> (B, 32)."""
        return torch.cat(
            [(self._crf_conv(w, d) > 0).to(w.dtype).mean(2) for d in (2, 4)], dim=1)

    # ----- family 7: analytic-signal (Hilbert) envelope stats (2) -----
    def hilbert_env(self, w):
        """Envelope coefficient of variation and normalised mean absolute
        envelope difference -> (B, 2).

        The envelope is the magnitude of the analytic signal, obtained by
        weighting the FFT bins with 1 (DC), 2 (below n/2), 1 (bin n//2) and 0
        (remaining bins) before the inverse FFT.
        """
        n = w.shape[1]
        k = torch.arange(n, device=w.device, dtype=torch.float32)
        h = torch.where(k == 0, 1.0, torch.where(k < n / 2, 2.0, torch.where(k == n // 2, 1.0, 0.0)))
        env = torch.fft.ifft(torch.fft.fft(w, dim=1) * h[None, :], dim=1).abs()
        return torch.stack([env.std(1) / (env.mean(1) + 1e-8),
                            (env[:, 1:] - env[:, :-1]).abs().mean(1) / (env.mean(1) + 1e-8)], dim=1)

    # ----- family 8: conv max-pool at dilations 2,4 (32) -----
    def crf_max(self, w):
        """Maximum conv response per kernel, dilations 2 and 4 -> (B, 32)."""
        return torch.cat([self._crf_conv(w, d).amax(2) for d in (2, 4)], dim=1)

    # ----- family 9: up/down-stroke morphology asymmetry (4) -----
    def morphology_updown(self, w):
        """Largest up-step, largest down-step, and the log-ratios of their
        maxima and of their squared sums -> (B, 4)."""
        dw = w[:, 1:] - w[:, :-1]
        up = dw.clamp(min=0)
        down = (-dw).clamp(min=0)
        # down ** 2 equals dw.clamp(max=0) ** 2, so the clamps are reused.
        return torch.stack([
            up.amax(1), down.amax(1),
            ((up.amax(1) + 1e-6) / (down.amax(1) + 1e-6)).log(),
            (((up ** 2).sum(1) + 1e-6) / ((down ** 2).sum(1) + 1e-6)).log(),
        ], dim=1)

    # ----- family 10: binned FFT log band power, 6 bands (6) -----
    def fftbands(self, w):
        """Log of the normalised spectral power summed over chunks of 6 bins.

        ``split(6)`` fixes the chunk size, not the chunk count: a 64-sample
        window has 32 non-DC bins, giving 6 chunks (the last holds 2 bins).
        """
        p = (torch.fft.rfft(w, dim=1).abs()[:, 1:]) ** 2
        pn = p / (p.sum(1, keepdim=True) + 1e-8)
        return torch.stack([b.sum(1) for b in pn.split(6, dim=1)], dim=1).clamp(min=1e-8).log()

    # ----- family 11: permutation entropy (len-3) + Hjorth mobility/complexity (3) -----
    def perm_entropy(self, w):
        """Order-3 ordinal-pattern entropy, Hjorth mobility and complexity -> (B, 3).

        Patterns are encoded in three comparison bits (8 codes); the entropy
        is normalised by log(6), the number of orderings of three values.
        """
        a, b, c = w[:, :-2], w[:, 1:-1], w[:, 2:]
        code = (a < b).long() * 4 + (b < c).long() * 2 + (a < c).long()
        H = torch.stack([(code == k).to(w.dtype).mean(1) for k in range(8)], dim=1)
        d1 = w[:, 1:] - w[:, :-1]
        d2 = w[:, 2:] - 2 * w[:, 1:-1] + w[:, :-2]
        mob = ((d1.var(1) + 1e-8) / (w.var(1) + 1e-8)).sqrt()
        comp = ((d2.var(1) + 1e-8) / (d1.var(1) + 1e-8)).sqrt() / (mob + 1e-8)
        return torch.stack([-(H * (H + 1e-12).log()).sum(1) / math.log(6.0), mob, comp], dim=1)

    # ----- family 12: local curvature via 2nd difference (3) -----
    def curvature(self, w):
        """Mean and max absolute second difference, plus the log-ratio of
        positive to negative squared curvature -> (B, 3)."""
        c = w[:, 2:] - 2 * w[:, 1:-1] + w[:, :-2]
        return torch.stack([
            c.abs().mean(1), c.abs().amax(1),
            (((c.clamp(min=0) ** 2).sum(1) + 1e-6) / ((c.clamp(max=0) ** 2).sum(1) + 1e-6)).log(),
        ], dim=1)

    # ----- family 13: conv max-response position summary (3) -----
    def conv_position(self, w):
        """Mean, std and range (over the 16 kernels) of the relative position
        of each kernel's maximum dilation-2 response -> (B, 3)."""
        pos = self._crf_conv(w, 2).argmax(2).to(w.dtype) / w.shape[1]
        return torch.stack([pos.mean(1), pos.std(1), pos.amax(1) - pos.amin(1)], dim=1)

    # ----- family 14: AR(2)+AR(3) ridge-fit residual energy + coeff norm (4) -----
    def ar_residual(self, w):
        """Log residual variance and log squared coefficient norm of ridge
        AR(2) and AR(3) fits on the centered window -> (B, 4)."""
        c = w - w.mean(1, keepdim=True)

        def fit(X, y):
            # Ridge-regularised normal equations (lambda = 1e-3), solved per window.
            A = X.transpose(1, 2) @ X + 1e-3 * torch.eye(X.shape[2], device=w.device, dtype=w.dtype)
            beta = torch.linalg.solve(A, X.transpose(1, 2) @ y.unsqueeze(2)).squeeze(2)
            resid = (y - (X @ beta.unsqueeze(2)).squeeze(2)).var(1)
            return torch.stack([(resid + 1e-8).log(), (beta ** 2).sum(1).clamp(min=1e-8).log()], 1)

        ar2 = fit(torch.stack([c[:, 1:-1], c[:, :-2]], 2), c[:, 2:])
        ar3 = fit(torch.stack([c[:, 2:-1], c[:, 1:-2], c[:, :-3]], 2), c[:, 3:])
        return torch.cat([ar2, ar3], 1)

    # ----- family 15: first-minimum / first-zero of the autocorrelation function (3) -----
    def acf_first_min(self, w):
        """First local minimum, first negative entry and an ACF value -> (B, 3).

        The ACF is evaluated at lags 1..32 (column ``i`` holds lag ``i + 1``).

        * ``first_min``: lag of the first local minimum divided by ``n``, or
          ``32 / n`` when there is none.
        * ``first_zero``: 0-based column of the first negative ACF value
          divided by ``n``; it is 0 when no value is negative.
        * ``val_after``: ACF value at column ``fm + 1`` (the first local
          minimum itself when one exists, column 1 otherwise).
        """
        c = w - w.mean(1, keepdim=True)
        n = w.shape[1]
        energy = (c ** 2).sum(1) + 1e-8
        acf = torch.stack([(c[:, :n - k] * c[:, k:]).sum(1) / energy for k in range(1, 33)], dim=1)
        n_lags = acf.shape[1]
        ismin = (acf[:, 1:-1] < acf[:, :-2]) & (acf[:, 1:-1] <= acf[:, 2:])
        has = ismin.any(1)
        fm = ismin.to(w.dtype).argmax(1)
        first_min = torch.where(has, fm.to(w.dtype) + 2.0, acf.new_tensor(float(n_lags))) / n
        # Descending weights make argmax pick the first column flagged negative.
        weights = torch.arange(n_lags, 0, -1, device=w.device)[None, :]
        first_zero = ((acf < 0).cumsum(1).clamp(max=1) * weights).argmax(1).to(w.dtype) / n
        val_after = acf.gather(1, (fm + 1).clamp(max=n_lags - 1).unsqueeze(1)).squeeze(1)
        return torch.stack([first_min, first_zero, val_after], dim=1)

    # ----- family 16: distribution mode (soft) + mode-median + central mass (3) -----
    def histogram_mode(self, w):
        """Soft mode of the z-scored window, soft mode minus median, and the
        fraction of samples with ``|z| < 0.5`` -> (B, 3).

        The soft mode is a softmax-weighted average of 10 bin centers in
        [-2.5, 2.5], with logits 4x the Gaussian-kernel mass at each center.
        """
        z = (w - w.mean(1, keepdim=True)) / w.std(1, keepdim=True).clamp(min=1e-6)
        ctr = torch.linspace(-2.5, 2.5, 10, device=w.device, dtype=w.dtype)
        mass = (-0.5 * ((ctr[None, None, :] - z.unsqueeze(2)) / (5.0 / 9.0)) ** 2).exp().sum(1)
        soft = (4.0 * mass).softmax(1)
        m10 = soft.mul(ctr[None, :]).sum(1)
        return torch.stack([m10, m10 - z.median(1).values, (z.abs() < 0.5).to(w.dtype).mean(1)], dim=1)

    # ----- family 17: Ricker (Mexican-hat) wavelet PPV at 4 scales (4) -----
    def ricker_wavelet(self, w):
        """Fraction of positive responses to 15-tap Ricker kernels at scales
        1.5, 2.5, 4.0 and 6.0 -> (B, 4).

        Kernels are built in the input's dtype: ``F.conv1d`` requires input
        and weight to share a dtype, so float32-only kernels would fail for a
        transform created with another ``dtype``.
        """
        x = torch.linspace(-7.0, 7.0, 15, device=w.device, dtype=w.dtype)
        ks = []
        for s in (1.5, 2.5, 4.0, 6.0):
            r = (1.0 - (x / s) ** 2) * torch.exp(-(x ** 2) / (2 * s * s))
            ks.append((r - r.mean()) / (r.abs().sum() + 1e-8))  # zero-mean, L1-scaled
        W = torch.stack(ks).unsqueeze(1)
        return (F.conv1d(w.unsqueeze(1), W, padding=15 // 2) > 0).to(w.dtype).mean(2)

    # ----- assemble phi in the discovered order -----
    def phi(self, w):
        """Encode a batch of windows ``w`` (B, n) by concatenating all 18
        families column-wise, in the order listed in ``families``."""
        st = self.stats(w)
        cols = [
            st,                   # stats
            self.srf_mlp(w, st),  # srf_mlp
            self.autocorr(w), self.spectral(w), self.turning(w),
            self.trf_gausswin(w), self.crf_ppv(w), self.hilbert_env(w), self.crf_max(w),
            self.morphology_updown(w), self.fftbands(w), self.perm_entropy(w), self.curvature(w),
            self.conv_position(w), self.ar_residual(w), self.acf_first_min(w), self.histogram_mode(w),
            self.ricker_wavelet(w),
        ]
        return torch.cat([c if c.dim() == 2 else c.unsqueeze(1) for c in cols], dim=1)

    __call__ = phi

    @property
    def families(self):
        """Copy of the ``(name, width)`` list describing the output columns."""
        return list(self._families)

    @property
    def n_cols(self):
        """Total declared output width (sum of the family widths)."""
        return sum(c for _, c in self._families)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
# Maps 12 of the 18 WMTransform family names to the pooling statistics
# ("mean", "std", "max") kept for that family.
# Summing (family width x number of listed statistics) over all entries gives
# 24 + 4 + 24 + 96 + 64 + 4 + 6 + 6 + 3 + 4 + 6 + 4 = 245 columns, which matches
# the "245" in the name.
# Families not listed: srf_mlp, autocorr, spectral, hilbert_env, perm_entropy,
# acf_first_min.
# NOTE(review): the code that applies this map is not shown here; presumably
# the statistics pool per-window features over the windows of a series — confirm.
POOL_MAP_245 = {
"stats": ["std", "max"], "turning": ["std", "max"], "trf_gausswin": ["mean", "max"],
"crf_ppv": ["mean", "std", "max"], "crf_max": ["mean", "max"], "morphology_updown": ["mean"],
"fftbands": ["mean"], "curvature": ["mean", "max"], "conv_position": ["max"], "ar_residual": ["mean"],
"histogram_mode": ["std", "max"], "ricker_wavelet": ["std"]}