@lastforkbender
Created April 19, 2026 01:36
NORT9-C Ai
"""
nort9commutator.py
NORT9 - "Fearless" configuration: aggressive learning, stronger bidirectional coupling,
GPU-native B-spline evaluation, batched vjp/jvp propagation, mixed-precision optional.
Features:
- Layered directional node network with bi-directional coupling
- Cubic B-spline gating with learnable coefficients (GPU PyTorch implementation)
- Timed gradient transport using batched vjp/jvp where possible
- Commutator diagnostics and adaptive gating updates
- U-region separation (auto or fixed)
- Transformer-style plugin hooks
- Numba CPU helpers retained for fallback
- Aggressive defaults (higher coupling init, higher gating LR, more knots)
- Single-file module
Dependencies:
- numpy
- numba
- torch (>=1.10)
- typing (standard lib)
Usage:
- Instantiate via build_nort9(..., fearless=True) for aggressive defaults.
- Call model.forward(...), then model.master_backward_and_step(...) to train.
  A minimal sanity run is included in the __main__ block at the bottom.
"""
from typing import List, Optional, Callable, Dict, Any, Tuple
import math
import numpy as np
from numba import njit
import torch
from torch import nn
from torch.autograd.functional import vjp, jvp
import torch.nn.functional as F
# -----------------------------
# Config / Defaults
# -----------------------------
DEFAULT_DTYPE = torch.float32
EPS = 1e-8
# -----------------------------
# Numba helpers (CPU) - retained
# -----------------------------
@njit
def curvature_laplacian_array(arr: np.ndarray) -> np.ndarray:
    """Per-row L2 norm of the discrete second difference (1D Laplacian) of a 2D array."""
    n = arr.shape[0]
    out = np.empty(n, dtype=np.float64)
    for i in range(n):
        # one-sided differences at the boundaries, central second difference inside
        if i == 0:
            diff = arr[1] - arr[0]
        elif i == n - 1:
            diff = arr[n - 1] - arr[n - 2]
        else:
            diff = arr[i + 1] - 2.0 * arr[i] + arr[i - 1]
        s = 0.0
        for j in range(diff.shape[0]):
            s += diff[j] * diff[j]
        out[i] = math.sqrt(s + 1e-12)
    return out
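
# Illustrative sketch (not called on import): on a linear ramp the interior
# second differences vanish, so interior curvature entries are ~0, while a
# kink produces a spike. The (n_rows, n_features) shape is an assumption
# matching how compute_layer_curvatures() calls this helper.
def _demo_curvature_laplacian():
    ramp = np.linspace(0.0, 1.0, 8).reshape(8, 1)
    kinked = ramp.copy()
    kinked[4:] *= 3.0  # introduce a slope discontinuity at row 4
    print("ramp curvature:  ", curvature_laplacian_array(ramp))
    print("kinked curvature:", curvature_laplacian_array(kinked))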
# -----------------------------
# GPU-native B-spline basis (PyTorch)
# -----------------------------
def bspline_basis_torch(t: torch.Tensor, knots: torch.Tensor, degree: int) -> torch.Tensor:
"""
Evaluate all basis functions at times t using Cox-De Boor, vectorized on GPU.
t: (...,) tensor in [knots[0], knots[-1]]
knots: (K,) tensor
returns: (..., m) where m = K - degree - 1
"""
# ensure t is 1D
flat = t.reshape(-1)
K = knots.shape[0]
m = K - degree - 1
device = t.device
dtype = t.dtype
# initialize zeroth-degree basis: N_{j,0}(t) = 1 if knots[j] <= t < knots[j+1]
N = torch.zeros((flat.shape[0], m), device=device, dtype=dtype)
for j in range(m):
left = knots[j]
right = knots[j + 1]
# include right endpoint for last knot
mask = (flat >= left) & (flat < right)
if j == m - 1:
mask = (flat >= left) & (flat <= right)
N[:, j] = mask.to(dtype)
# recursion
for deg in range(1, degree + 1):
N_new = torch.zeros_like(N)
for j in range(m):
denom_left = knots[j + deg] - knots[j]
denom_right = knots[j + deg + 1] - knots[j + 1]
left = torch.zeros_like(flat, device=device, dtype=dtype)
right = torch.zeros_like(flat, device=device, dtype=dtype)
if denom_left > 1e-12:
left = ((flat - knots[j]) / denom_left) * N[:, j]
if denom_right > 1e-12 and j + 1 < m:
right = ((knots[j + deg + 1] - flat) / denom_right) * N[:, j + 1]
N_new[:, j] = left + right
N = N_new
return N.reshape(*t.shape, m)
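
# Sanity sketch (assumption: a standard clamped knot vector built by hand here,
# not the one BSplineGating constructs): a degree-3 clamped B-spline basis
# should form a partition of unity, i.e. the basis values at any t sum to ~1.
def _check_bspline_partition_of_unity():
    degree = 3
    knots = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0])
    ts = torch.linspace(0.0, 1.0, 11)
    basis = bspline_basis_torch(ts, knots, degree)  # shape (11, 7)
    assert torch.allclose(basis.sum(dim=-1), torch.ones_like(ts), atol=1e-5)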
# -----------------------------
# PyTorch components
# -----------------------------
class DirectionalLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int, use_bias: bool = True,
                 device=None, dtype=DEFAULT_DTYPE, coupling_scale: float = 1e-2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype))
        else:
            self.register_parameter("bias", None)
        # directional coupling matrices; stronger init when fearless
        self.c_fw = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        self.c_bw = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
        self.coupling_scale = coupling_scale
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(max(1, fan_in))
            nn.init.uniform_(self.bias, -bound, bound)
        nn.init.normal_(self.c_fw, mean=0.0, std=self.coupling_scale)
        nn.init.normal_(self.c_bw, mean=0.0, std=self.coupling_scale)

    def forward(self, a_prev: torch.Tensor, a_next: Optional[torch.Tensor] = None) -> torch.Tensor:
        z = F.linear(a_prev, self.weight, self.bias)
        z = z + F.linear(a_prev, self.c_fw, None)
        if a_next is not None:
            # note: c_bw also maps in_features -> out_features, so a_next must
            # share a_prev's width in this implementation
            z = z + F.linear(a_next, self.c_bw, None)
        return z
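
# Shape sketch (illustrative only): both coupling matrices consume
# in_features-wide inputs, so a_next is assumed to have the same width as
# a_prev, matching F.linear(a_next, self.c_bw) above.
def _demo_directional_layer():
    layer = DirectionalLayer(in_features=4, out_features=3)
    a_prev = torch.randn(2, 4)
    a_next = torch.randn(2, 4)  # same width as a_prev in this implementation
    z = layer(a_prev, a_next)
    assert z.shape == (2, 3)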
class BSplineGating(nn.Module):
    def __init__(self, n_knots: int = 16, degree: int = 3, t_min: float = 0.0, t_max: float = 1.0,
                 init_alpha_scale: float = 1e-2, device=None, dtype=DEFAULT_DTYPE):
        super().__init__()
        self.degree = degree
        self.t_min = float(t_min)
        self.t_max = float(t_max)
        self.n_knots = int(n_knots)
        # clamped knot vector: repeated endpoints around a uniform interior grid
        interior_count = max(2, self.n_knots - 2 * (degree + 1) + 2)
        interior = torch.linspace(self.t_min, self.t_max, interior_count, device=device, dtype=dtype)
        left = torch.full((degree + 1,), self.t_min, device=device, dtype=dtype)
        right = torch.full((degree + 1,), self.t_max, device=device, dtype=dtype)
        knots = torch.cat([left, interior, right])
        self.register_buffer("knots", knots)
        m = knots.shape[0] - degree - 1
        self.alpha = nn.Parameter(torch.randn(m, device=device, dtype=dtype) * init_alpha_scale)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t may be a scalar or a batched tensor; basis evaluation runs on GPU
        basis = bspline_basis_torch(t, self.knots, self.degree)  # shape (..., m)
        alpha = self.alpha.view(*(1,) * (basis.dim() - 1), -1) if basis.dim() > 1 else self.alpha
        logits = (basis * alpha).sum(dim=-1)
        return torch.sigmoid(logits)
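
# Usage sketch: with the default knot layout the gate maps any batch of times
# in [t_min, t_max] to values strictly inside (0, 1) via a sigmoid over the
# spline logits.
def _demo_bspline_gating():
    gating = BSplineGating()
    g = gating(torch.tensor([0.0, 0.5, 1.0]))
    assert g.shape == (3,)
    assert bool(((g > 0.0) & (g < 1.0)).all())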
# -----------------------------
# NORT9 Core Model (Fearless)
# -----------------------------
class NORT9(nn.Module):
    def __init__(self,
                 layer_sizes: List[int],
                 activation: Callable[[torch.Tensor], torch.Tensor] = F.gelu,
                 device: Optional[torch.device] = None,
                 dtype=DEFAULT_DTYPE,
                 use_numba: bool = True,
                 bspline_knots: int = 16,
                 bspline_degree: int = 3,
                 region_split: Optional[Tuple[int, int]] = None,
                 fearless: bool = True,
                 mixed_precision: bool = True):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype
        self.activation = activation
        self.use_numba = use_numba
        self.n_layers = len(layer_sizes) - 1
        self.layers = nn.ModuleList()
        coupling_scale = 5e-2 if fearless else 1e-3
        for i in range(1, len(layer_sizes)):
            self.layers.append(DirectionalLayer(layer_sizes[i - 1], layer_sizes[i],
                                                device=device, dtype=dtype, coupling_scale=coupling_scale))
        self.gatings = nn.ModuleList()
        for _ in range(self.n_layers):
            self.gatings.append(BSplineGating(n_knots=bspline_knots, degree=bspline_degree,
                                              device=device, dtype=dtype,
                                              init_alpha_scale=(5e-2 if fearless else 1e-2)))
        # region split (S, M, E): fixed if given, else a thirds-based default
        if region_split is not None:
            self.region_split = region_split
        else:
            s_end = max(0, self.n_layers // 3 - 1)
            m_end = max(s_end, 2 * self.n_layers // 3 - 1)
            self.region_split = (s_end, m_end)
        self.plugin_hooks: Dict[int, Callable[[torch.Tensor], torch.Tensor]] = {}
        self._optim: Optional[torch.optim.Optimizer] = None
        self.fearless = fearless
        # flag is recorded but not yet wired into forward/backward
        self.mixed_precision = mixed_precision and torch.cuda.is_available()

    def register_plugin(self, layer_idx: int, fn: Callable[[torch.Tensor], torch.Tensor]):
        self.plugin_hooks[layer_idx] = fn

    def set_optimizer(self, optim: torch.optim.Optimizer):
        self._optim = optim
    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None,
                record_activations: bool = True) -> Dict[str, Any]:
        a = x
        activations = []
        gates = []
        # accept None, Python floats, or tensors for t (commutator_norm passes a float)
        if t is None:
            t = torch.tensor(0.0, device=self.device, dtype=self.dtype)
        elif not isinstance(t, torch.Tensor):
            t = torch.tensor(float(t), device=self.device, dtype=self.dtype)
        for idx, layer in enumerate(self.layers):
            g = self.gatings[idx](t)
            gates.append(g)  # gates weight gradient transport; they do not gate a here
            z = layer(a, None)
            if idx in self.plugin_hooks:
                z = self.plugin_hooks[idx](z)
            a = self.activation(z)
            activations.append({'z': z, 'a': a})
        return {'out': a, 'activations': activations, 'gates': gates}
    def get_region_indices(self) -> Dict[str, List[int]]:
        s_end, m_end = self.region_split
        S = list(range(0, s_end + 1))
        M = list(range(s_end + 1, m_end + 1))
        E = list(range(m_end + 1, self.n_layers))
        return {'S': S, 'M': M, 'E': E}

    def compute_layer_curvatures(self, activations: List[Dict[str, torch.Tensor]]) -> List[float]:
        curvs = []
        for layer_act in activations:
            a = layer_act['a'].detach().cpu().numpy()
            mean_a = a.mean(axis=0)  # batch-mean activation per layer
            arr = mean_a.reshape(mean_a.shape[0], 1) if mean_a.ndim == 1 else mean_a
            curvs.append(float(curvature_laplacian_array(arr).mean()))
        return curvs
    def auto_region_split_by_curvature(self, activations: List[Dict[str, torch.Tensor]]):
        # crude 1D k-means (k=3) over per-layer curvatures, seeded at percentiles
        curvs = np.array(self.compute_layer_curvatures(activations))
        k = 3
        centers = np.percentile(curvs, [0, 50, 100])
        for _ in range(10):
            dists = np.abs(curvs[:, None] - centers[None, :])
            labels = dists.argmin(axis=1)
            for j in range(k):
                if (labels == j).any():
                    centers[j] = curvs[labels == j].mean()
        order = centers.argsort()
        label_to_region = {center_idx: region_idx for region_idx, center_idx in enumerate(order)}
        region_labels = np.array([label_to_region[l] for l in labels])
        S_idx = int(np.max(np.where(region_labels == 0)[0])) if (region_labels == 0).any() else 0
        M_idx = int(np.max(np.where(region_labels == 1)[0])) if (region_labels == 1).any() else min(S_idx + 1, self.n_layers - 1)
        self.region_split = (S_idx, M_idx)
    def _transport_kernel(self, layer_from: int, layer_to: int, sigma: float) -> float:
        # Gaussian decay in layer distance: exp(-d^2 / (2 sigma^2))
        d = float(layer_to - layer_from)
        return math.exp(-(d * d) / (2.0 * (sigma * sigma) + 1e-12))

    def _batched_vjp_chain(self, from_idx: int, to_idx: int, vec: torch.Tensor,
                           activations: List[Dict[str, torch.Tensor]]) -> Optional[torch.Tensor]:
        """
        Compute a vjp from the activation at from_idx down to the activation at
        to_idx by successive batched vjps. Returns the gradient on
        activations[to_idx]['a'], or None on failure.
        """
        current = vec
        for step in range(from_idx, to_idx, -1):  # walk backwards one layer at a time
            a_prev = activations[step - 1]['a']
            layer = self.layers[step]

            def f_ap(ap, layer=layer):  # bind layer per step to avoid late-binding surprises
                return self.activation(layer(ap, None))

            try:
                # torch.autograd.functional.vjp(func, inputs, v) returns (func(inputs), v^T J)
                _, vjp_res = vjp(f_ap, a_prev, current, create_graph=False)
                current = vjp_res
            except RuntimeError:
                return None
        return current
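
    # Minimal sketch of the vjp call used in _batched_vjp_chain: for
    # f(x) = x**2 evaluated at x = [1, 2] with cotangent v = [1, 1],
    # torch.autograd.functional.vjp returns (f(x), v^T J_f(x)) = ([1, 4], [2, 4]).
    @staticmethod
    def _demo_vjp():
        x = torch.tensor([1.0, 2.0])
        v = torch.ones(2)
        out, grad = vjp(lambda z: z ** 2, x, v)
        assert torch.allclose(out, torch.tensor([1.0, 4.0]))
        assert torch.allclose(grad, torch.tensor([2.0, 4.0]))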
    def master_backward_and_step(self,
                                 loss: torch.Tensor,
                                 activations: List[Dict[str, torch.Tensor]],
                                 gates: List[torch.Tensor],
                                 transport_times: List[float],
                                 gamma: float = 1.0,
                                 sigma: Optional[float] = None,
                                 region_weights: Optional[Dict[str, float]] = None,
                                 clip_grad_norm: Optional[float] = 1.0,
                                 optimizer: Optional[torch.optim.Optimizer] = None):
        if optimizer is None:
            optimizer = self._optim
        if optimizer is None:
            raise ValueError("No optimizer provided or set via set_optimizer().")
        n_layers = self.n_layers
        if sigma is None:
            sigma = max(1.0, n_layers / 8.0)
        device = self.device
        dtype = self.dtype
        optimizer.zero_grad(set_to_none=True)
        # seed transport with the loss gradient on the final activation
        a_end = activations[-1]['a']
        grad_a_end = torch.autograd.grad(loss, a_end, retain_graph=True, allow_unused=True)[0]
        transported_grads: List[Optional[torch.Tensor]] = [None] * n_layers
        if grad_a_end is not None:
            transported_grads[-1] = grad_a_end.detach()
        # transport gradients backwards across layers, decayed in time and layer distance
        for t in transport_times:
            for m in range(n_layers - 1, -1, -1):
                g_m = transported_grads[m]
                if g_m is None:
                    continue
                for l in range(0, m):
                    g_l = gates[l] if isinstance(gates[l], torch.Tensor) else torch.tensor(float(gates[l]), device=device, dtype=dtype)
                    time_decay = math.exp(-gamma * abs(t))
                    pair_decay = self._transport_kernel(l, m, sigma)
                    weight = float(g_l.item()) * pair_decay * time_decay
                    if weight < 1e-6:
                        continue
                    current_grad = g_m * weight
                    # vjp chain from layer m down to layer l
                    vjp_res = self._batched_vjp_chain(m, l, current_grad, activations)
                    if vjp_res is not None:
                        if transported_grads[l] is None:
                            transported_grads[l] = vjp_res.detach()
                        else:
                            transported_grads[l] = transported_grads[l] + vjp_res.detach()
        # inject transported grads into the graph via surrogate scalars:
        # backprop of (a * g.detach()).sum() delivers g as the gradient on a
        surrogate_scalars = []
        for idx, g in enumerate(transported_grads):
            if g is None:
                continue
            a_tensor = activations[idx]['a']
            surrogate_scalars.append((a_tensor * g).sum())
        if surrogate_scalars:
            total = sum(surrogate_scalars)
            total.backward()
        else:
            loss.backward()
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.parameters(), clip_grad_norm)
        optimizer.step()
        return transported_grads
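
    # Sketch of the surrogate-scalar trick used above: for a detached target
    # gradient g on an activation a, backprop of (a * g).sum() delivers exactly
    # g as dL/da, so upstream parameters receive the transported gradient.
    @staticmethod
    def _demo_surrogate_injection():
        w = torch.ones(3, requires_grad=True)
        a = w * 2.0                        # activation depending on w
        g = torch.tensor([1.0, 0.5, 0.0])  # transported gradient (detached)
        (a * g).sum().backward()
        assert torch.allclose(w.grad, 2.0 * g)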
    def commutator_norm(self, x: torch.Tensor, t: float = 0.0) -> float:
        # finite-difference probe of output sensitivity along a small random direction
        eps = 1e-3
        probe = torch.randn_like(x) * 1e-3
        out1 = self.forward(x + eps * probe, t)['out']
        out2 = self.forward(x, t)['out']
        fd = (out1 - out2) / eps
        return float(fd.norm().item())
# -----------------------------
# Builder with "fearless" defaults
# -----------------------------
def build_nort9(layer_sizes: List[int],
                lr_weights: float = 5e-4,
                lr_gating: float = 2e-3,
                device: Optional[torch.device] = None,
                fearless: bool = True) -> Tuple[NORT9, torch.optim.Optimizer]:
    model = NORT9(layer_sizes=layer_sizes, device=device, fearless=fearless)
    # separate parameter groups: gating alphas get their own (higher) learning rate
    gating_params = []
    rest_params = []
    for name, p in model.named_parameters():
        if 'gatings' in name and 'alpha' in name:
            gating_params.append(p)
        else:
            rest_params.append(p)
    optim = torch.optim.AdamW([
        {'params': rest_params, 'lr': lr_weights},
        {'params': gating_params, 'lr': lr_gating}
    ], betas=(0.9, 0.999), weight_decay=1e-2 if fearless else 1e-3)
    model.set_optimizer(optim)
    return model, optim
# -----------------------------
# Minimal test run
# -----------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, optim = build_nort9([32, 64, 64, 16], device=device, fearless=True)
    x = torch.randn(8, 32, device=device)
    out_dict = model.forward(x, t=torch.tensor(0.2, device=device))
    out = out_dict['out']
    target = torch.zeros_like(out)
    loss = F.mse_loss(out, target)
    transported = model.master_backward_and_step(loss, out_dict['activations'], out_dict['gates'],
                                                 transport_times=[0.0, 0.3, 0.6, 1.0], gamma=1.0,
                                                 clip_grad_norm=1.0)
    print("Fearless sanity run complete. Loss:", float(loss.item()))