@pszemraj
Last active November 25, 2025 01:09
working eggroll impl from various LLMs & yours truly
"""
EGGROLL: Evolution Guided General Optimization via Low-rank Learning
NumPy Implementation - Direct translation of working PyTorch code
Paper: arXiv:2511.16652v1
"""
import numpy as np
from dataclasses import dataclass
from typing import Tuple, Optional
@dataclass
class EggrollConfig:
    rank: int = 1
    sigma: float = 0.1
    learning_rate: float = 0.02
    pop_size: int = 512
    use_rank_transform: bool = True


def centered_rank_transform(fitnesses: np.ndarray) -> np.ndarray:
    """
    Appendix D.2 / Table 1: Centered rank transformation.
    Maps to [-0.5, 0.5] based on rank.
    """
    n = len(fitnesses)
    # argsort twice gives ranks
    ranks = np.argsort(np.argsort(fitnesses)).astype(np.float64)
    return (ranks / (n - 1)) - 0.5 if n > 1 else np.zeros_like(fitnesses)
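
# Illustrative example (not in the original gist): for fitnesses [3.0, 1.0, 2.0],
# argsort-of-argsort yields ranks [2, 0, 1], so the transform returns
# [0.5, -0.5, 0.0] -- the best member maps to +0.5 and the worst to -0.5,
# independent of the fitness scale.
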
def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(0, x)


class EggrollLinear:
    """
    EGGROLL-enabled Linear layer.
    Paper Section 4.3: y = xW^T + b + (σ/√r)(xB)A^T
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias
        # Kaiming initialization
        std = np.sqrt(2.0 / in_features)
        self.weight = np.random.randn(out_features, in_features) * std
        self.bias = np.zeros(out_features) if bias else None
        # Adam state
        self.m_w = np.zeros_like(self.weight)
        self.v_w = np.zeros_like(self.weight)
        self.m_b = np.zeros_like(self.bias) if bias else None
        self.v_b = np.zeros_like(self.bias) if bias else None

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Standard forward: y = xW^T + b"""
        out = x @ self.weight.T
        if self.bias is not None:
            out = out + self.bias
        return out

    def generate_noise(
        self, seed: int, rank: int
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Generate A, B, and bias noise from seed."""
        # Ensure seed is within valid range for numpy
        rng = np.random.RandomState(seed % (2**32 - 1))
        A = rng.randn(self.out_features, rank)
        B = rng.randn(self.in_features, rank)
        E_b = rng.randn(self.out_features) if self.has_bias else None
        return A, B, E_b

    def forward_perturbed(
        self,
        x: np.ndarray,
        A: np.ndarray,
        B: np.ndarray,
        E_b: Optional[np.ndarray],
        sigma: float,
        rank: int,
    ) -> np.ndarray:
        """
        Forward with low-rank perturbation.
        Paper Section 4.3: y = xμ + (σ/√r)(xB)A^T

        Args:
            x: (Batch, In) or (Pop, Batch, In)
            A: (Out, Rank) or (Pop, Out, Rank)
            B: (In, Rank) or (Pop, In, Rank)
        """
        # Base: xW^T + b
        base = x @ self.weight.T
        if self.bias is not None:
            base = base + self.bias
        scale = sigma / np.sqrt(rank)
        # Perturbation
        if A.ndim == 2:
            # Single perturbation
            xB = x @ B  # (..., Rank)
            pert = xB @ A.T  # (..., Out)
        else:
            # Batched: A is (Pop, Out, Rank), B is (Pop, In, Rank)
            if x.ndim == 2:
                # x: (Batch, In) -> broadcast
                xB = np.einsum("bi,pir->pbr", x, B)
            else:
                # x: (Pop, Batch, In)
                xB = np.einsum("pbi,pir->pbr", x, B)
            pert = np.einsum("pbr,por->pbo", xB, A)
        out = base + scale * pert
        # Bias perturbation
        if E_b is not None:
            if E_b.ndim == 1:
                out = out + sigma * E_b
            else:
                # E_b: (Pop, Out) -> (Pop, 1, Out) for broadcast
                out = out + sigma * E_b[:, np.newaxis, :]
        return out
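
# Sanity-check sketch (not part of the original gist): the low-rank path above
# should match a dense forward through the explicitly perturbed weight
# W + (σ/√r)·A·B^T, since x(W + (σ/√r)AB^T)^T = xW^T + (σ/√r)(xB)A^T.
# The helper name below is hypothetical.
def _check_lowrank_equivalence(seed: int = 0) -> bool:
    rng = np.random.RandomState(seed)
    layer = EggrollLinear(8, 4)
    x = rng.randn(3, 8)
    A, B, E_b = layer.generate_noise(seed=123, rank=2)
    sigma, rank = 0.1, 2
    fast = layer.forward_perturbed(x, A, B, E_b, sigma, rank)
    dense_w = layer.weight + (sigma / np.sqrt(rank)) * (A @ B.T)
    slow = x @ dense_w.T + layer.bias + sigma * E_b
    return np.allclose(fast, slow)
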
class EggrollMLP:
    """Simple 2-layer MLP with EGGROLL."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        self.layer1 = EggrollLinear(in_dim, hidden_dim)
        self.layer2 = EggrollLinear(hidden_dim, out_dim)
        self.layers = [self.layer1, self.layer2]

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Standard forward pass with mean parameters."""
        h = relu(self.layer1.forward(x))
        return self.layer2.forward(h)

    def forward_population(
        self,
        x: np.ndarray,
        seeds: np.ndarray,
        config: EggrollConfig,
    ) -> np.ndarray:
        """
        Batched forward for entire population.
        Paper Section 4.3: Efficient batched inference.
        """
        pop_size = len(seeds)
        # Generate all noise
        all_noise = []
        for layer_idx, layer in enumerate(self.layers):
            A_batch = np.empty((pop_size, layer.out_features, config.rank))
            B_batch = np.empty((pop_size, layer.in_features, config.rank))
            Eb_batch = (
                np.empty((pop_size, layer.out_features)) if layer.has_bias else None
            )
            for i, seed in enumerate(seeds):
                # Derive layer-specific seed (avoid overflow)
                layer_seed = (int(seed) + layer_idx * 997) % (2**32 - 1)
                A, B, E_b = layer.generate_noise(layer_seed, config.rank)
                A_batch[i] = A
                B_batch[i] = B
                if Eb_batch is not None:
                    Eb_batch[i] = E_b
            all_noise.append((A_batch, B_batch, Eb_batch))
        # Forward pass
        # Layer 1
        A1, B1, Eb1 = all_noise[0]
        h = self.layer1.forward_perturbed(x, A1, B1, Eb1, config.sigma, config.rank)
        h = relu(h)  # (Pop, Batch, Hidden)
        # Layer 2
        A2, B2, Eb2 = all_noise[1]
        out = self.layer2.forward_perturbed(h, A2, B2, Eb2, config.sigma, config.rank)
        return out  # (Pop, Batch, Out)
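
# Usage sketch (illustrative, not part of the original gist): with the XOR
# inputs of shape (4, 2), hidden_dim=16, out_dim=1, and pop_size=512,
# forward_population(inputs, seeds, config) returns a (512, 4, 1) array --
# one prediction batch per population member, all evaluated in a single
# batched pass.
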
def compute_gradient(
    layer: EggrollLinear,
    seeds: np.ndarray,
    fitnesses: np.ndarray,
    layer_idx: int,
    config: EggrollConfig,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Compute EGGROLL pseudo-gradient for one layer.
    Paper Eq. 8: Δμ ∝ (1/N) Σ E_i f_i
    With E_i = (1/√r) A_i B_i^T
    Paper Section 4.3 optimization:
        "calculate the expression as (A ⊙ f)B^T"
    """
    pop_size = len(seeds)
    rank = config.rank
    sigma = config.sigma
    # Reconstruct noise and accumulate
    A_all = np.empty((pop_size, layer.out_features, rank))
    B_all = np.empty((pop_size, layer.in_features, rank))
    Eb_all = np.empty((pop_size, layer.out_features)) if layer.has_bias else None
    for i, seed in enumerate(seeds):
        layer_seed = (int(seed) + layer_idx * 997) % (2**32 - 1)
        A, B, E_b = layer.generate_noise(layer_seed, rank)
        A_all[i] = A
        B_all[i] = B
        if Eb_all is not None:
            Eb_all[i] = E_b
    # Paper Section 4.3: (A ⊙ f)B^T
    # fitnesses: (Pop,) -> (Pop, 1, 1)
    f_view = fitnesses.reshape(-1, 1, 1)
    A_weighted = A_all * f_view  # (Pop, Out, Rank)
    # Sum: Σ A_i B_i^T f_i = einsum('por,pir->oi', A_weighted, B_all)
    grad_w = np.einsum("por,pir->oi", A_weighted, B_all)
    # Scale: 1/(N * σ * √r) per original PyTorch code
    scale = 1.0 / (pop_size * sigma * np.sqrt(rank))
    grad_w *= scale
    # Bias gradient
    grad_b = None
    if Eb_all is not None:
        # (Pop,) @ (Pop, Out) -> (Out,)
        grad_b = fitnesses @ Eb_all
        grad_b *= 1.0 / (pop_size * sigma)
    return grad_w, grad_b
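
# Sanity-check sketch (not part of the original gist): the einsum update above
# should equal the naive sum over population members,
#   (1/(Nσ)) Σ_i f_i · (1/√r) A_i B_i^T.
# The helper name below is hypothetical.
def _check_gradient_equivalence(seed: int = 0) -> bool:
    rng = np.random.RandomState(seed)
    layer = EggrollLinear(8, 4)
    config = EggrollConfig(rank=2, sigma=0.1, pop_size=16)
    seeds = np.arange(16)
    fitnesses = rng.randn(16)
    grad_w, _ = compute_gradient(layer, seeds, fitnesses, layer_idx=0, config=config)
    naive = np.zeros_like(layer.weight)
    for i, s in enumerate(seeds):
        layer_seed = (int(s) + 0 * 997) % (2**32 - 1)
        A, B, _ = layer.generate_noise(layer_seed, config.rank)
        naive += fitnesses[i] * (A @ B.T) / np.sqrt(config.rank)
    naive /= len(seeds) * config.sigma
    return np.allclose(grad_w, naive)
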
def adam_update(
    param: np.ndarray,
    grad: np.ndarray,
    m: np.ndarray,
    v: np.ndarray,
    lr: float,
    t: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Adam optimizer step."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * (grad**2)
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # Add gradient (we're maximizing fitness, grad is already positive direction)
    param = param + lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


def train_xor(run_id: int, verbose: bool = True) -> bool:
    """Train on XOR problem."""
    if verbose:
        print(f"\n--- Run {run_id}: XOR Training ---")
    # Data
    inputs = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    targets = np.array([[0.0], [1.0], [1.0], [0.0]])
    # Model
    np.random.seed(run_id * 12345)  # Different init per run
    model = EggrollMLP(2, 16, 1)
    # Config
    config = EggrollConfig(
        rank=1,
        sigma=0.1,
        learning_rate=0.02,
        pop_size=512,
        use_rank_transform=True,
    )
    t = 0  # Adam timestep
    for gen in range(301):
        # Generate seeds for this generation (avoid overflow)
        base_seed = ((run_id * 10000) + gen * config.pop_size) % (2**31)
        seeds = np.arange(base_seed, base_seed + config.pop_size)
        # Population forward
        pop_outputs = model.forward_population(inputs, seeds, config)
        # pop_outputs: (Pop, Batch, 1)
        # Fitness: negative MSE
        diff = pop_outputs - targets[np.newaxis, :, :]
        mse = (diff**2).mean(axis=(1, 2))
        fitnesses = -mse
        # Fitness shaping
        if config.use_rank_transform:
            shaped_fitness = centered_rank_transform(fitnesses)
        else:
            shaped_fitness = fitnesses
        # Update each layer
        t += 1
        for layer_idx, layer in enumerate(model.layers):
            grad_w, grad_b = compute_gradient(
                layer, seeds, shaped_fitness, layer_idx, config
            )
            # Adam update (maximize, so we add)
            layer.weight, layer.m_w, layer.v_w = adam_update(
                layer.weight, grad_w, layer.m_w, layer.v_w, config.learning_rate, t
            )
            if grad_b is not None:
                layer.bias, layer.m_b, layer.v_b = adam_update(
                    layer.bias, grad_b, layer.m_b, layer.v_b, config.learning_rate, t
                )
        # Evaluate mean parameters
        if gen % 20 == 0 or gen < 5:
            pred = model.forward(inputs)
            loss = ((pred - targets) ** 2).mean()
            if verbose:
                print(f"Gen {gen:3d}: MSE = {loss:.4f}")
            if loss < 0.005:
                if verbose:
                    print(f"Run {run_id}: Converged at Gen {gen}!")
                return True
    # Final check
    pred = model.forward(inputs)
    loss = ((pred - targets) ** 2).mean()
    if loss < 0.01:
        if verbose:
            print(f"Run {run_id}: Converged at Final (Loss {loss:.4f})")
        return True
    if verbose:
        print(f"Run {run_id}: Failed (Loss {loss:.4f})")
    return False


if __name__ == "__main__":
    successes = 0
    total = 5
    for i in range(total):
        if train_xor(i + 1):
            successes += 1
    print(f"\n{'=' * 40}")
    print(f"Results: {successes}/{total} converged")

"""
EGGROLL: Evolution Guided General Optimization via Low-rank Learning
PyTorch Implementation - Based on audited NumPy version
Paper: arXiv:2511.16652v1 [cs.LG]
Key equations (with paper references):
- Section 4.1: E = (1/√r) AB^T (low-rank perturbation)
- Section 4.3: y = xμ + (σ/√r)(xB)A^T (efficient forward)
- Proposition 1: ∇J = -(1/σ) E[E·f] (ES gradient)
- Section 4.3: Σ E_i f_i = (A ⊙ f)B^T (efficient update)
- Table 1: optimizer ∈ {adam, adamw, sgd} (Adam is valid)
"""
import math
from dataclasses import dataclass
from typing import Tuple, Optional
import torch
import torch.nn as nn
@dataclass
class EggrollConfig:
    """
    Configuration for EGGROLL training.
    Hyperparameters from Table 1.
    """

    rank: int = 1  # r, Paper shows r=1 is sufficient (Section 5, Figure 3)
    sigma: float = 0.1  # Noise scale σ (Table 1: 0.05, 0.2, 0.5)
    learning_rate: float = 0.02  # α (Table 1: 1e-3, 1e-2, 1e-1)
    pop_size: int = 512  # N_workers (Paper tests up to 262,144)
    use_rank_transform: bool = True  # Table 1: rank_transform


def centered_rank_transform(fitnesses: torch.Tensor) -> torch.Tensor:
    """
    Centered Rank Transformation (Table 1: rank_transform).
    Maps fitness values to [-0.5, 0.5] based on rank ordering.

    Args:
        fitnesses: Tensor of shape (PopSize,)
    Returns:
        Tensor of shape (PopSize,) with values in [-0.5, 0.5]
    """
    n = len(fitnesses)
    ranks = torch.argsort(torch.argsort(fitnesses)).float()
    return (ranks / (n - 1)) - 0.5 if n > 1 else torch.zeros_like(fitnesses)


def get_layer_seed(base_seed: int, layer_idx: int) -> int:
    """
    Derive layer-specific seed to decorrelate perturbations across layers.
    Paper Algorithm 1: "workers with known random seeds ς"
    """
    return (base_seed + layer_idx * 997) % (2**31)


class EggrollLinear(nn.Module):
    """
    EGGROLL-enabled Linear layer.
    Paper Section 4.3:
        "x_i(μ + σE_i) = x_iμ + (σ/√r)(x_i B_i)A_i^T"
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias
        # Mean parameters μ (what we optimize)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in = self.weight.size(1)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard forward with mean parameters (inference)."""
        out = x @ self.weight.T
        if self.bias is not None:
            out = out + self.bias
        return out

    def generate_noise(
        self,
        seed: int,
        rank: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Generate A, B, and bias noise deterministically from seed.
        Paper Section 4.1: "sample A ∈ R^{m×r} and B ∈ R^{n×r}"
        Paper Section 4: "counter-based deterministic RNG to reconstruct noise on demand"
        """
        g = torch.Generator(device="cpu").manual_seed(seed)
        A = torch.randn(self.out_features, rank, generator=g, dtype=dtype).to(device)
        B = torch.randn(self.in_features, rank, generator=g, dtype=dtype).to(device)
        E_b = None
        if self.has_bias:
            E_b = torch.randn(self.out_features, generator=g, dtype=dtype).to(device)
        return A, B, E_b
    def forward_perturbed(
        self,
        x: torch.Tensor,
        A: torch.Tensor,
        B: torch.Tensor,
        E_b: Optional[torch.Tensor],
        sigma: float,
        rank: int,
    ) -> torch.Tensor:
        """
        Forward pass with low-rank perturbation.
        Paper Section 4.3:
            "y = xμ + (σ/√r)(xB)A^T"

        Args:
            x: Input (Batch, In) or (Pop, Batch, In)
            A: (Pop, Out, Rank)
            B: (Pop, In, Rank)
            E_b: (Pop, Out) or None
            sigma: Noise scale σ
            rank: Rank r
        Returns:
            Output tensor (Pop, Batch, Out)
        """
        # Base computation: xW^T + b
        base = x @ self.weight.T
        if self.bias is not None:
            base = base + self.bias
        # Paper Section 4.3: scale is σ/√r
        scale = sigma / math.sqrt(rank)
        # Perturbation: (σ/√r)(xB)A^T
        if x.dim() == 2:
            # x: (Batch, In) -> broadcast across population
            # B: (Pop, In, Rank)
            xB = torch.einsum("bi,pir->pbr", x, B)
        else:
            # x: (Pop, Batch, In)
            xB = torch.einsum("pbi,pir->pbr", x, B)
        # A: (Pop, Out, Rank)
        perturbation = torch.einsum("pbr,por->pbo", xB, A)
        out = base + scale * perturbation
        # Bias perturbation
        if E_b is not None:
            # E_b: (Pop, Out) -> (Pop, 1, Out) for broadcast
            out = out + sigma * E_b.unsqueeze(1)
        return out
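
# Illustrative check (not in the original gist): because the generator is
# re-seeded from the same integer, a worker and the coordinator reconstruct
# identical noise without ever communicating the tensors themselves.
# The helper name below is hypothetical.
def _check_noise_determinism() -> bool:
    layer = EggrollLinear(8, 4)
    cpu, f32 = torch.device("cpu"), torch.float32
    A1, B1, Eb1 = layer.generate_noise(42, rank=1, device=cpu, dtype=f32)
    A2, B2, Eb2 = layer.generate_noise(42, rank=1, device=cpu, dtype=f32)
    return torch.equal(A1, A2) and torch.equal(B1, B2) and torch.equal(Eb1, Eb2)
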
class EggrollMLP(nn.Module):
    """
    Simple MLP with EGGROLL-enabled layers.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.layer1 = EggrollLinear(in_dim, hidden_dim)
        self.layer2 = EggrollLinear(hidden_dim, out_dim)
        self.layers = [self.layer1, self.layer2]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard forward pass with mean parameters."""
        h = torch.relu(self.layer1(x))
        return self.layer2(h)

    def forward_population(
        self,
        x: torch.Tensor,
        seeds: torch.Tensor,
        config: EggrollConfig,
    ) -> torch.Tensor:
        """
        Batched forward for entire population.
        Paper Section 4.3: "batches a population of low-rank adapters
        and shares the base activations"

        Args:
            x: Input (Batch, In)
            seeds: Population seeds (Pop,)
            config: EGGROLL config
        Returns:
            Outputs (Pop, Batch, Out)
        """
        pop_size = len(seeds)
        device = x.device
        dtype = x.dtype
        # Pre-generate all noise for all layers
        all_noise = []
        for layer_idx, layer in enumerate(self.layers):
            A_batch = torch.empty(
                pop_size, layer.out_features, config.rank, device=device, dtype=dtype
            )
            B_batch = torch.empty(
                pop_size, layer.in_features, config.rank, device=device, dtype=dtype
            )
            Eb_batch = None
            if layer.has_bias:
                Eb_batch = torch.empty(
                    pop_size, layer.out_features, device=device, dtype=dtype
                )
            for i, seed in enumerate(seeds):
                layer_seed = get_layer_seed(int(seed.item()), layer_idx)
                A, B, E_b = layer.generate_noise(layer_seed, config.rank, device, dtype)
                A_batch[i] = A
                B_batch[i] = B
                if Eb_batch is not None:
                    Eb_batch[i] = E_b
            all_noise.append((A_batch, B_batch, Eb_batch))
        # Forward pass through layers
        # Layer 1 + ReLU
        A1, B1, Eb1 = all_noise[0]
        h = self.layer1.forward_perturbed(x, A1, B1, Eb1, config.sigma, config.rank)
        h = torch.relu(h)  # (Pop, Batch, Hidden)
        # Layer 2
        A2, B2, Eb2 = all_noise[1]
        out = self.layer2.forward_perturbed(h, A2, B2, Eb2, config.sigma, config.rank)
        return out  # (Pop, Batch, Out)


def compute_eggroll_gradient(
    layer: EggrollLinear,
    seeds: torch.Tensor,
    fitnesses: torch.Tensor,
    layer_idx: int,
    config: EggrollConfig,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Compute EGGROLL pseudo-gradient for one layer.
    Paper Eq. 8: "μ_{t+1} = μ_t + (α/N) Σ E_{i,t} f_i"
    With E_i = (1/√r) A_i B_i^T
    Paper Proposition 1: "∇J = -(1/σ) E[E·f]"
    Paper Section 4.3 optimization:
        "calculate the expression as (A ⊙ f)B^T"

    Args:
        layer: The EGGROLL linear layer
        seeds: Population seeds (Pop,)
        fitnesses: Shaped fitness values (Pop,)
        layer_idx: Index for seed derivation
        config: EGGROLL config
    Returns:
        (grad_weight, grad_bias) tensors
    """
    pop_size = len(seeds)
    rank = config.rank
    sigma = config.sigma
    device = layer.weight.device
    dtype = layer.weight.dtype
    # Reconstruct noise
    A_all = torch.empty(pop_size, layer.out_features, rank, device=device, dtype=dtype)
    B_all = torch.empty(pop_size, layer.in_features, rank, device=device, dtype=dtype)
    Eb_all = None
    if layer.has_bias:
        Eb_all = torch.empty(pop_size, layer.out_features, device=device, dtype=dtype)
    for i, seed in enumerate(seeds):
        layer_seed = get_layer_seed(int(seed.item()), layer_idx)
        A, B, E_b = layer.generate_noise(layer_seed, rank, device, dtype)
        A_all[i] = A
        B_all[i] = B
        if Eb_all is not None:
            Eb_all[i] = E_b
    # Paper Section 4.3: "(A ⊙ f)B^T"
    # fitnesses: (Pop,) -> (Pop, 1, 1)
    f_view = fitnesses.view(-1, 1, 1)
    A_weighted = A_all * f_view  # (Pop, Out, Rank)
    # Σ A_i B_i^T f_i via einsum
    grad_w = torch.einsum("por,pir->oi", A_weighted, B_all)
    # Scale: 1/(N·σ·√r) from Proposition 1
    # The 1/σ comes from ∇J = -(1/σ)E[E·f]
    # The 1/√r comes from E = (1/√r)AB^T
    scale = 1.0 / (pop_size * sigma * math.sqrt(rank))
    grad_w = grad_w * scale
    # Bias gradient
    grad_b = None
    if Eb_all is not None:
        # (Pop,) @ (Pop, Out) -> (Out,)
        grad_b = fitnesses @ Eb_all
        grad_b = grad_b * (1.0 / (pop_size * sigma))
    return grad_w, grad_b
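
# Sanity-check sketch (not in the original gist): the batched "(A ⊙ f)B^T"
# einsum above should match the naive per-member sum
#   (1/(Nσ)) Σ_i f_i · (1/√r) A_i B_i^T.
# The helper name below is hypothetical.
def _check_update_equivalence() -> bool:
    layer = EggrollLinear(8, 4)
    config = EggrollConfig(rank=2, sigma=0.1, pop_size=16)
    seeds = torch.arange(16)
    fitnesses = torch.randn(16)
    grad_w, _ = compute_eggroll_gradient(layer, seeds, fitnesses, 0, config)
    naive = torch.zeros_like(layer.weight)
    for i in range(len(seeds)):
        layer_seed = get_layer_seed(int(seeds[i].item()), 0)
        A, B, _ = layer.generate_noise(
            layer_seed, config.rank, layer.weight.device, layer.weight.dtype
        )
        naive += fitnesses[i] * (A @ B.T) / math.sqrt(config.rank)
    naive /= len(seeds) * config.sigma
    return torch.allclose(grad_w, naive)
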
class EggrollOptimizer:
    """
    EGGROLL optimizer with Adam (Table 1: optimizer ∈ {adam, adamw, sgd}).
    Handles pseudo-gradient computation and parameter updates.
    """

    def __init__(self, model: EggrollMLP, config: EggrollConfig):
        self.model = model
        self.config = config
        # Adam state for each layer
        self.state = {}
        for idx, layer in enumerate(model.layers):
            self.state[f"layer{idx}_w"] = {
                "m": torch.zeros_like(layer.weight),
                "v": torch.zeros_like(layer.weight),
            }
            if layer.bias is not None:
                self.state[f"layer{idx}_b"] = {
                    "m": torch.zeros_like(layer.bias),
                    "v": torch.zeros_like(layer.bias),
                }
        self.t = 0
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps = 1e-8

    def _adam_update(
        self,
        param: torch.Tensor,
        grad: torch.Tensor,
        state: dict,
    ) -> None:
        """In-place Adam update."""
        state["m"].mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        state["v"].mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        m_hat = state["m"] / (1 - self.beta1**self.t)
        v_hat = state["v"] / (1 - self.beta2**self.t)
        # Maximize fitness: add gradient (not subtract)
        param.data.add_(
            m_hat / (v_hat.sqrt() + self.eps), alpha=self.config.learning_rate
        )

    def step(self, seeds: torch.Tensor, fitnesses: torch.Tensor) -> None:
        """
        Perform one EGGROLL optimization step.
        Paper Algorithm 1:
            "μ ← μ + α · (1/N_workers) · Σ E_j · f_j"

        Args:
            seeds: Population seeds used in forward pass (Pop,)
            fitnesses: Raw fitness values (Pop,)
        """
        # Fitness shaping (Table 1: rank_transform)
        if self.config.use_rank_transform:
            shaped_fitness = centered_rank_transform(fitnesses)
        else:
            shaped_fitness = fitnesses
        self.t += 1
        # Update each layer
        for layer_idx, layer in enumerate(self.model.layers):
            grad_w, grad_b = compute_eggroll_gradient(
                layer, seeds, shaped_fitness, layer_idx, self.config
            )
            # Adam update for weight
            self._adam_update(
                layer.weight,
                grad_w,
                self.state[f"layer{layer_idx}_w"],
            )
            # Adam update for bias
            if grad_b is not None:
                self._adam_update(
                    layer.bias,
                    grad_b,
                    self.state[f"layer{layer_idx}_b"],
                )


def train_xor(run_id: int, verbose: bool = True) -> bool:
    """
    Train on XOR problem using EGGROLL.

    Args:
        run_id: Run identifier for different random seeds
        verbose: Whether to print progress
    Returns:
        True if converged, False otherwise
    """
    if verbose:
        print(f"\n--- Run {run_id}: XOR Training ---")
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Data
    inputs = torch.tensor(
        [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], device=device
    )
    targets = torch.tensor([[0.0], [1.0], [1.0], [0.0]], device=device)
    # Model initialization with different seed per run
    torch.manual_seed(run_id * 12345)
    model = EggrollMLP(2, 16, 1).to(device)
    # Config
    config = EggrollConfig(
        rank=1,
        sigma=0.1,
        learning_rate=0.02,
        pop_size=512,
        use_rank_transform=True,
    )
    optimizer = EggrollOptimizer(model, config)
    if verbose:
        print(
            f"Config: rank={config.rank}, σ={config.sigma}, lr={config.learning_rate}"
        )
        print(f"Population size: {config.pop_size}")
    for gen in range(301):
        # Generate seeds for this generation
        base_seed = ((run_id * 10000) + gen * config.pop_size) % (2**31)
        seeds = torch.arange(base_seed, base_seed + config.pop_size, device=device)
        # Population forward pass
        with torch.no_grad():
            pop_outputs = model.forward_population(inputs, seeds, config)
        # pop_outputs: (Pop, Batch, 1)
        # Fitness: negative MSE (maximize)
        diff = pop_outputs - targets.unsqueeze(0)
        mse = (diff**2).mean(dim=(1, 2))
        fitnesses = -mse
        # Optimization step
        optimizer.step(seeds, fitnesses)
        # Evaluate mean parameters
        if gen % 20 == 0:
            with torch.no_grad():
                pred = model(inputs)
                loss = ((pred - targets) ** 2).mean().item()
            if verbose:
                print(f"Gen {gen:3d}: MSE = {loss:.4f}")
            if loss < 0.005:
                if verbose:
                    print(f"Run {run_id}: Converged at Gen {gen}!")
                return True
    # Final check
    with torch.no_grad():
        pred = model(inputs)
        loss = ((pred - targets) ** 2).mean().item()
    if loss < 0.01:
        if verbose:
            print(f"Run {run_id}: Converged at Final (Loss {loss:.4f})")
        return True
    if verbose:
        print(f"Run {run_id}: Failed (Loss {loss:.4f})")
    return False


if __name__ == "__main__":
    successes = 0
    total = 5
    for i in range(total):
        if train_xor(i + 1):
            successes += 1
    print(f"\n{'=' * 40}")
    print(f"Results: {successes}/{total} converged")