Working EGGROLL implementations (NumPy and PyTorch) from various LLMs & yours truly
| """ | |
| EGGROLL: Evolution Guided General Optimization via Low-rank Learning | |
| NumPy Implementation - Direct translation of working PyTorch code | |
| Paper: arXiv:2511.16652v1 | |
| """ | |
| import numpy as np | |
| from dataclasses import dataclass | |
| from typing import Tuple, Optional | |
| @dataclass | |
| class EggrollConfig: | |
| rank: int = 1 | |
| sigma: float = 0.1 | |
| learning_rate: float = 0.02 | |
| pop_size: int = 512 | |
| use_rank_transform: bool = True | |
| def centered_rank_transform(fitnesses: np.ndarray) -> np.ndarray: | |
| """ | |
| Appendix D.2 / Table 1: Centered rank transformation. | |
| Maps to [-0.5, 0.5] based on rank. | |
| """ | |
| n = len(fitnesses) | |
| # argsort twice gives ranks | |
| ranks = np.argsort(np.argsort(fitnesses)).astype(np.float64) | |
| return (ranks / (n - 1)) - 0.5 if n > 1 else np.zeros_like(fitnesses) | |
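| # Illustrative example (not part of the original file): for fitnesses | |
| # [0.3, -1.2, 0.7] the ranks are [1, 0, 2], so the transform returns | |
| # [0.0, -0.5, 0.5]; only the ordering matters, not the fitness magnitudes. | |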
| def relu(x: np.ndarray) -> np.ndarray: | |
| return np.maximum(0, x) | |
| class EggrollLinear: | |
| """ | |
| EGGROLL-enabled Linear layer. | |
| Paper Section 4.3: y = xW^T + b + (σ/√r)(xB)A^T | |
| """ | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| self.has_bias = bias | |
| # Kaiming initialization | |
| std = np.sqrt(2.0 / in_features) | |
| self.weight = np.random.randn(out_features, in_features) * std | |
| self.bias = np.zeros(out_features) if bias else None | |
| # Adam state | |
| self.m_w = np.zeros_like(self.weight) | |
| self.v_w = np.zeros_like(self.weight) | |
| self.m_b = np.zeros_like(self.bias) if bias else None | |
| self.v_b = np.zeros_like(self.bias) if bias else None | |
| def forward(self, x: np.ndarray) -> np.ndarray: | |
| """Standard forward: y = xW^T + b""" | |
| out = x @ self.weight.T | |
| if self.bias is not None: | |
| out = out + self.bias | |
| return out | |
| def generate_noise( | |
| self, seed: int, rank: int | |
| ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: | |
| """Generate A, B, and bias noise from seed.""" | |
| # Ensure seed is within valid range for numpy | |
| rng = np.random.RandomState(seed % (2**32 - 1)) | |
| A = rng.randn(self.out_features, rank) | |
| B = rng.randn(self.in_features, rank) | |
| E_b = rng.randn(self.out_features) if self.has_bias else None | |
| return A, B, E_b | |
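| # Note: the draws above are fully determined by `seed`, so calling | |
| # generate_noise again with the same seed reproduces (A, B, E_b) exactly. | |
| # compute_gradient() below relies on this to rebuild each worker's | |
| # perturbation from its seed instead of storing the noise tensors. | |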
| def forward_perturbed( | |
| self, | |
| x: np.ndarray, | |
| A: np.ndarray, | |
| B: np.ndarray, | |
| E_b: Optional[np.ndarray], | |
| sigma: float, | |
| rank: int, | |
| ) -> np.ndarray: | |
| """ | |
| Forward with low-rank perturbation. | |
| Paper Section 4.3: y = xμ + (σ/√r)(xB)A^T | |
| Args: | |
| x: (Batch, In) or (Pop, Batch, In) | |
| A: (Out, Rank) or (Pop, Out, Rank) | |
| B: (In, Rank) or (Pop, In, Rank) | |
| """ | |
| # Base: xW^T + b | |
| base = x @ self.weight.T | |
| if self.bias is not None: | |
| base = base + self.bias | |
| scale = sigma / np.sqrt(rank) | |
| # Perturbation | |
| if A.ndim == 2: | |
| # Single perturbation | |
| xB = x @ B # (..., Rank) | |
| pert = xB @ A.T # (..., Out) | |
| else: | |
| # Batched: A is (Pop, Out, Rank), B is (Pop, In, Rank) | |
| if x.ndim == 2: | |
| # x: (Batch, In) -> broadcast | |
| xB = np.einsum("bi,pir->pbr", x, B) | |
| else: | |
| # x: (Pop, Batch, In) | |
| xB = np.einsum("pbi,pir->pbr", x, B) | |
| pert = np.einsum("pbr,por->pbo", xB, A) | |
| out = base + scale * pert | |
| # Bias perturbation | |
| if E_b is not None: | |
| if E_b.ndim == 1: | |
| out = out + sigma * E_b | |
| else: | |
| # E_b: (Pop, Out) -> (Pop, 1, Out) for broadcast | |
| out = out + sigma * E_b[:, np.newaxis, :] | |
| return out | |
| class EggrollMLP: | |
| """Simple 2-layer MLP with EGGROLL.""" | |
| def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): | |
| self.layer1 = EggrollLinear(in_dim, hidden_dim) | |
| self.layer2 = EggrollLinear(hidden_dim, out_dim) | |
| self.layers = [self.layer1, self.layer2] | |
| def forward(self, x: np.ndarray) -> np.ndarray: | |
| """Standard forward pass with mean parameters.""" | |
| h = relu(self.layer1.forward(x)) | |
| return self.layer2.forward(h) | |
| def forward_population( | |
| self, | |
| x: np.ndarray, | |
| seeds: np.ndarray, | |
| config: EggrollConfig, | |
| ) -> np.ndarray: | |
| """ | |
| Batched forward for entire population. | |
| Paper Section 4.3: Efficient batched inference. | |
| """ | |
| pop_size = len(seeds) | |
| # Generate all noise | |
| all_noise = [] | |
| for layer_idx, layer in enumerate(self.layers): | |
| A_batch = np.empty((pop_size, layer.out_features, config.rank)) | |
| B_batch = np.empty((pop_size, layer.in_features, config.rank)) | |
| Eb_batch = ( | |
| np.empty((pop_size, layer.out_features)) if layer.has_bias else None | |
| ) | |
| for i, seed in enumerate(seeds): | |
| # Derive layer-specific seed (avoid overflow) | |
| layer_seed = (int(seed) + layer_idx * 997) % (2**32 - 1) | |
| A, B, E_b = layer.generate_noise(layer_seed, config.rank) | |
| A_batch[i] = A | |
| B_batch[i] = B | |
| if Eb_batch is not None: | |
| Eb_batch[i] = E_b | |
| all_noise.append((A_batch, B_batch, Eb_batch)) | |
| # Forward pass | |
| # Layer 1 | |
| A1, B1, Eb1 = all_noise[0] | |
| h = self.layer1.forward_perturbed(x, A1, B1, Eb1, config.sigma, config.rank) | |
| h = relu(h) # (Pop, Batch, Hidden) | |
| # Layer 2 | |
| A2, B2, Eb2 = all_noise[1] | |
| out = self.layer2.forward_perturbed(h, A2, B2, Eb2, config.sigma, config.rank) | |
| return out # (Pop, Batch, Out) | |
| def compute_gradient( | |
| layer: EggrollLinear, | |
| seeds: np.ndarray, | |
| fitnesses: np.ndarray, | |
| layer_idx: int, | |
| config: EggrollConfig, | |
| ) -> Tuple[np.ndarray, Optional[np.ndarray]]: | |
| """ | |
| Compute EGGROLL pseudo-gradient for one layer. | |
| Paper Eq. 8: Δμ ∝ (1/N) Σ E_i f_i | |
| With E_i = (1/√r) A_i B_i^T | |
| Paper Section 4.3 optimization: | |
| "calculate the expression as (A ⊙ f)B^T" | |
| """ | |
| pop_size = len(seeds) | |
| rank = config.rank | |
| sigma = config.sigma | |
| # Reconstruct noise and accumulate | |
| A_all = np.empty((pop_size, layer.out_features, rank)) | |
| B_all = np.empty((pop_size, layer.in_features, rank)) | |
| Eb_all = np.empty((pop_size, layer.out_features)) if layer.has_bias else None | |
| for i, seed in enumerate(seeds): | |
| layer_seed = (int(seed) + layer_idx * 997) % (2**32 - 1) | |
| A, B, E_b = layer.generate_noise(layer_seed, rank) | |
| A_all[i] = A | |
| B_all[i] = B | |
| if Eb_all is not None: | |
| Eb_all[i] = E_b | |
| # Paper Section 4.3: accumulate (A ⊙ f) B^T over the population | |
| # fitnesses: (Pop,) -> (Pop, 1, 1) | |
| f_view = fitnesses.reshape(-1, 1, 1) | |
| A_weighted = A_all * f_view # (Pop, Out, Rank) | |
| # Sum: Σ A_i B_i^T f_i = einsum('por,pir->oi', A_weighted, B_all) | |
| grad_w = np.einsum("por,pir->oi", A_weighted, B_all) | |
| # Scale: 1/(N * σ * √r) per original PyTorch code | |
| scale = 1.0 / (pop_size * sigma * np.sqrt(rank)) | |
| grad_w *= scale | |
| # Bias gradient | |
| grad_b = None | |
| if Eb_all is not None: | |
| # (Pop,) @ (Pop, Out) -> (Out,) | |
| grad_b = fitnesses @ Eb_all | |
| grad_b *= 1.0 / (pop_size * sigma) | |
| return grad_w, grad_b | |
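| # Sanity-check sketch (illustrative): the einsum above is equivalent to the | |
| # naive per-worker accumulation | |
| #   grad_w ≈ (1 / (pop_size * sigma * sqrt(rank))) * Σ_i f_i * (A_i @ B_i.T) | |
| # i.e. each worker's rank-r outer product is weighted by its shaped fitness. | |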
| def adam_update( | |
| param: np.ndarray, | |
| grad: np.ndarray, | |
| m: np.ndarray, | |
| v: np.ndarray, | |
| lr: float, | |
| t: int, | |
| beta1: float = 0.9, | |
| beta2: float = 0.999, | |
| eps: float = 1e-8, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """Adam optimizer step.""" | |
| m = beta1 * m + (1 - beta1) * grad | |
| v = beta2 * v + (1 - beta2) * (grad**2) | |
| m_hat = m / (1 - beta1**t) | |
| v_hat = v / (1 - beta2**t) | |
| # Gradient ascent: we maximize fitness, so the Adam step is added rather than subtracted | |
| param = param + lr * m_hat / (np.sqrt(v_hat) + eps) | |
| return param, m, v | |
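| # Note: unlike a typical in-place optimizer, adam_update returns fresh | |
| # (param, m, v) arrays; callers are expected to reassign them, as train_xor | |
| # does below. | |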
| def train_xor(run_id: int, verbose: bool = True) -> bool: | |
| """Train on XOR problem.""" | |
| if verbose: | |
| print(f"\n--- Run {run_id}: XOR Training ---") | |
| # Data | |
| inputs = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) | |
| targets = np.array([[0.0], [1.0], [1.0], [0.0]]) | |
| # Model | |
| np.random.seed(run_id * 12345) # Different init per run | |
| model = EggrollMLP(2, 16, 1) | |
| # Config | |
| config = EggrollConfig( | |
| rank=1, | |
| sigma=0.1, | |
| learning_rate=0.02, | |
| pop_size=512, | |
| use_rank_transform=True, | |
| ) | |
| t = 0 # Adam timestep | |
| for gen in range(301): | |
| # Generate seeds for this generation (avoid overflow) | |
| base_seed = ((run_id * 10000) + gen * config.pop_size) % (2**31) | |
| seeds = np.arange(base_seed, base_seed + config.pop_size) | |
| # Population forward | |
| pop_outputs = model.forward_population(inputs, seeds, config) | |
| # pop_outputs: (Pop, Batch, 1) | |
| # Fitness: negative MSE | |
| diff = pop_outputs - targets[np.newaxis, :, :] | |
| mse = (diff**2).mean(axis=(1, 2)) | |
| fitnesses = -mse | |
| # Fitness shaping | |
| if config.use_rank_transform: | |
| shaped_fitness = centered_rank_transform(fitnesses) | |
| else: | |
| shaped_fitness = fitnesses | |
| # Update each layer | |
| t += 1 | |
| for layer_idx, layer in enumerate(model.layers): | |
| grad_w, grad_b = compute_gradient( | |
| layer, seeds, shaped_fitness, layer_idx, config | |
| ) | |
| # Adam update (maximize, so we add) | |
| layer.weight, layer.m_w, layer.v_w = adam_update( | |
| layer.weight, grad_w, layer.m_w, layer.v_w, config.learning_rate, t | |
| ) | |
| if grad_b is not None: | |
| layer.bias, layer.m_b, layer.v_b = adam_update( | |
| layer.bias, grad_b, layer.m_b, layer.v_b, config.learning_rate, t | |
| ) | |
| # Evaluate mean parameters | |
| if gen % 20 == 0 or gen < 5: | |
| pred = model.forward(inputs) | |
| loss = ((pred - targets) ** 2).mean() | |
| if verbose: | |
| print(f"Gen {gen:3d}: MSE = {loss:.4f}") | |
| if loss < 0.005: | |
| if verbose: | |
| print(f"Run {run_id}: Converged at Gen {gen}!") | |
| return True | |
| # Final check | |
| pred = model.forward(inputs) | |
| loss = ((pred - targets) ** 2).mean() | |
| if loss < 0.01: | |
| if verbose: | |
| print(f"Run {run_id}: Converged at Final (Loss {loss:.4f})") | |
| return True | |
| if verbose: | |
| print(f"Run {run_id}: Failed (Loss {loss:.4f})") | |
| return False | |
| if __name__ == "__main__": | |
| successes = 0 | |
| total = 5 | |
| for i in range(total): | |
| if train_xor(i + 1): | |
| successes += 1 | |
| print(f"\n{'=' * 40}") | |
| print(f"Results: {successes}/{total} converged") |
| """ | |
| EGGROLL: Evolution Guided General Optimization via Low-rank Learning | |
| PyTorch Implementation - Based on audited NumPy version | |
| Paper: arXiv:2511.16652v1 [cs.LG] | |
| Key equations (with paper references): | |
| - Section 4.1: E = (1/√r) AB^T (low-rank perturbation) | |
| - Section 4.3: y = xμ + (σ/√r)(xB)A^T (efficient forward) | |
| - Proposition 1: ∇J = -(1/σ) E[E·f] (ES gradient) | |
| - Section 4.3: Σ E_i f_i = (A ⊙ f)B^T (efficient update) | |
| - Table 1: optimizer ∈ {adam, adamw, sgd} (Adam is valid) | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Tuple, Optional | |
| import torch | |
| import torch.nn as nn | |
| @dataclass | |
| class EggrollConfig: | |
| """ | |
| Configuration for EGGROLL training. | |
| Hyperparameters from Table 1. | |
| """ | |
| rank: int = 1 # r, Paper shows r=1 is sufficient (Section 5, Figure 3) | |
| sigma: float = 0.1 # Noise scale σ (Table 1: 0.05, 0.2, 0.5) | |
| learning_rate: float = 0.02 # α (Table 1: 1e-3, 1e-2, 1e-1) | |
| pop_size: int = 512 # N_workers (Paper tests up to 262,144) | |
| use_rank_transform: bool = True # Table 1: rank_transform | |
| def centered_rank_transform(fitnesses: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Centered Rank Transformation (Table 1: rank_transform). | |
| Maps fitness values to [-0.5, 0.5] based on rank ordering. | |
| Args: | |
| fitnesses: Tensor of shape (PopSize,) | |
| Returns: | |
| Tensor of shape (PopSize,) with values in [-0.5, 0.5] | |
| """ | |
| n = len(fitnesses) | |
| ranks = torch.argsort(torch.argsort(fitnesses)).float() | |
| return (ranks / (n - 1)) - 0.5 if n > 1 else torch.zeros_like(fitnesses) | |
| def get_layer_seed(base_seed: int, layer_idx: int) -> int: | |
| """ | |
| Derive layer-specific seed to decorrelate perturbations across layers. | |
| Paper Algorithm 1: "workers with known random seeds ς" | |
| """ | |
| return (base_seed + layer_idx * 997) % (2**31) | |
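| # Example: get_layer_seed(1000, 0) == 1000 and get_layer_seed(1000, 1) == 1997, | |
| # so the two layers of the MLP never draw from the same generator seed for a | |
| # given worker. | |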
| class EggrollLinear(nn.Module): | |
| """ | |
| EGGROLL-enabled Linear layer. | |
| Paper Section 4.3: | |
| "x_i(μ + σE_i) = x_iμ + (σ/√r)(x_i B_i)A_i^T" | |
| """ | |
| def __init__(self, in_features: int, out_features: int, bias: bool = True): | |
| super().__init__() | |
| self.in_features = in_features | |
| self.out_features = out_features | |
| self.has_bias = bias | |
| # Mean parameters μ (what we optimize) | |
| self.weight = nn.Parameter(torch.empty(out_features, in_features)) | |
| if bias: | |
| self.bias = nn.Parameter(torch.empty(out_features)) | |
| else: | |
| self.register_parameter("bias", None) | |
| self._reset_parameters() | |
| def _reset_parameters(self): | |
| nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | |
| if self.bias is not None: | |
| fan_in = self.weight.size(1) | |
| bound = 1 / math.sqrt(fan_in) | |
| nn.init.uniform_(self.bias, -bound, bound) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Standard forward with mean parameters (inference).""" | |
| out = x @ self.weight.T | |
| if self.bias is not None: | |
| out = out + self.bias | |
| return out | |
| def generate_noise( | |
| self, | |
| seed: int, | |
| rank: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |
| """ | |
| Generate A, B, and bias noise deterministically from seed. | |
| Paper Section 4.1: "sample A ∈ R^{m×r} and B ∈ R^{n×r}" | |
| Paper Section 4: "counter-based deterministic RNG to reconstruct noise on demand" | |
| """ | |
| g = torch.Generator(device="cpu").manual_seed(seed) | |
| A = torch.randn(self.out_features, rank, generator=g, dtype=dtype).to(device) | |
| B = torch.randn(self.in_features, rank, generator=g, dtype=dtype).to(device) | |
| E_b = None | |
| if self.has_bias: | |
| E_b = torch.randn(self.out_features, generator=g, dtype=dtype).to(device) | |
| return A, B, E_b | |
| def forward_perturbed( | |
| self, | |
| x: torch.Tensor, | |
| A: torch.Tensor, | |
| B: torch.Tensor, | |
| E_b: Optional[torch.Tensor], | |
| sigma: float, | |
| rank: int, | |
| ) -> torch.Tensor: | |
| """ | |
| Forward pass with low-rank perturbation. | |
| Paper Section 4.3: | |
| "y = xμ + (σ/√r)(xB)A^T" | |
| Args: | |
| x: Input (Batch, In) or (Pop, Batch, In) | |
| A: (Pop, Out, Rank) | |
| B: (Pop, In, Rank) | |
| E_b: (Pop, Out) or None | |
| sigma: Noise scale σ | |
| rank: Rank r | |
| Returns: | |
| Output tensor (Pop, Batch, Out) | |
| """ | |
| # Base computation: xW^T + b | |
| base = x @ self.weight.T | |
| if self.bias is not None: | |
| base = base + self.bias | |
| # Paper Section 4.3: scale is σ/√r | |
| scale = sigma / math.sqrt(rank) | |
| # Perturbation: (σ/√r)(xB)A^T | |
| if x.dim() == 2: | |
| # x: (Batch, In) -> broadcast across population | |
| # B: (Pop, In, Rank) | |
| xB = torch.einsum("bi,pir->pbr", x, B) | |
| else: | |
| # x: (Pop, Batch, In) | |
| xB = torch.einsum("pbi,pir->pbr", x, B) | |
| # A: (Pop, Out, Rank) | |
| perturbation = torch.einsum("pbr,por->pbo", xB, A) | |
| out = base + scale * perturbation | |
| # Bias perturbation | |
| if E_b is not None: | |
| # E_b: (Pop, Out) -> (Pop, 1, Out) for broadcast | |
| out = out + sigma * E_b.unsqueeze(1) | |
| return out | |
| class EggrollMLP(nn.Module): | |
| """ | |
| Simple MLP with EGGROLL-enabled layers. | |
| """ | |
| def __init__(self, in_dim: int, hidden_dim: int, out_dim: int): | |
| super().__init__() | |
| self.layer1 = EggrollLinear(in_dim, hidden_dim) | |
| self.layer2 = EggrollLinear(hidden_dim, out_dim) | |
| self.layers = [self.layer1, self.layer2] | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Standard forward pass with mean parameters.""" | |
| h = torch.relu(self.layer1(x)) | |
| return self.layer2(h) | |
| def forward_population( | |
| self, | |
| x: torch.Tensor, | |
| seeds: torch.Tensor, | |
| config: EggrollConfig, | |
| ) -> torch.Tensor: | |
| """ | |
| Batched forward for entire population. | |
| Paper Section 4.3: "batches a population of low-rank adapters | |
| and shares the base activations" | |
| Args: | |
| x: Input (Batch, In) | |
| seeds: Population seeds (Pop,) | |
| config: EGGROLL config | |
| Returns: | |
| Outputs (Pop, Batch, Out) | |
| """ | |
| pop_size = len(seeds) | |
| device = x.device | |
| dtype = x.dtype | |
| # Pre-generate all noise for all layers | |
| all_noise = [] | |
| for layer_idx, layer in enumerate(self.layers): | |
| A_batch = torch.empty( | |
| pop_size, layer.out_features, config.rank, device=device, dtype=dtype | |
| ) | |
| B_batch = torch.empty( | |
| pop_size, layer.in_features, config.rank, device=device, dtype=dtype | |
| ) | |
| Eb_batch = None | |
| if layer.has_bias: | |
| Eb_batch = torch.empty( | |
| pop_size, layer.out_features, device=device, dtype=dtype | |
| ) | |
| for i, seed in enumerate(seeds): | |
| layer_seed = get_layer_seed(int(seed.item()), layer_idx) | |
| A, B, E_b = layer.generate_noise(layer_seed, config.rank, device, dtype) | |
| A_batch[i] = A | |
| B_batch[i] = B | |
| if Eb_batch is not None: | |
| Eb_batch[i] = E_b | |
| all_noise.append((A_batch, B_batch, Eb_batch)) | |
| # Forward pass through layers | |
| # Layer 1 + ReLU | |
| A1, B1, Eb1 = all_noise[0] | |
| h = self.layer1.forward_perturbed(x, A1, B1, Eb1, config.sigma, config.rank) | |
| h = torch.relu(h) # (Pop, Batch, Hidden) | |
| # Layer 2 | |
| A2, B2, Eb2 = all_noise[1] | |
| out = self.layer2.forward_perturbed(h, A2, B2, Eb2, config.sigma, config.rank) | |
| return out # (Pop, Batch, Out) | |
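| # Note on efficiency: for layer 1 the base term x @ W^T is shared by the whole | |
| # population and only the rank-r term (xB)A^T is per-member; after the ReLU the | |
| # activations differ per member, so layer 2's base term is per-member as well. | |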
| def compute_eggroll_gradient( | |
| layer: EggrollLinear, | |
| seeds: torch.Tensor, | |
| fitnesses: torch.Tensor, | |
| layer_idx: int, | |
| config: EggrollConfig, | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: | |
| """ | |
| Compute EGGROLL pseudo-gradient for one layer. | |
| Paper Eq. 8: "μ_{t+1} = μ_t + (α/N) Σ E_{i,t} f_i" | |
| With E_i = (1/√r) A_i B_i^T | |
| Paper Proposition 1: "∇J = -(1/σ) E[E·f]" | |
| Paper Section 4.3 optimization: | |
| "calculate the expression as (A ⊙ f)B^T" | |
| Args: | |
| layer: The EGGROLL linear layer | |
| seeds: Population seeds (Pop,) | |
| fitnesses: Shaped fitness values (Pop,) | |
| layer_idx: Index for seed derivation | |
| config: EGGROLL config | |
| Returns: | |
| (grad_weight, grad_bias) tensors | |
| """ | |
| pop_size = len(seeds) | |
| rank = config.rank | |
| sigma = config.sigma | |
| device = layer.weight.device | |
| dtype = layer.weight.dtype | |
| # Reconstruct noise | |
| A_all = torch.empty(pop_size, layer.out_features, rank, device=device, dtype=dtype) | |
| B_all = torch.empty(pop_size, layer.in_features, rank, device=device, dtype=dtype) | |
| Eb_all = None | |
| if layer.has_bias: | |
| Eb_all = torch.empty(pop_size, layer.out_features, device=device, dtype=dtype) | |
| for i, seed in enumerate(seeds): | |
| layer_seed = get_layer_seed(int(seed.item()), layer_idx) | |
| A, B, E_b = layer.generate_noise(layer_seed, rank, device, dtype) | |
| A_all[i] = A | |
| B_all[i] = B | |
| if Eb_all is not None: | |
| Eb_all[i] = E_b | |
| # Paper Section 4.3: "(A ⊙ f)B^T" | |
| # fitnesses: (Pop,) -> (Pop, 1, 1) | |
| f_view = fitnesses.view(-1, 1, 1) | |
| A_weighted = A_all * f_view # (Pop, Out, Rank) | |
| # Σ A_i B_i^T f_i via einsum | |
| grad_w = torch.einsum("por,pir->oi", A_weighted, B_all) | |
| # Scale: 1/(N·σ·√r) from Proposition 1 | |
| # The 1/σ comes from ∇J = -(1/σ)E[E·f] | |
| # The 1/√r comes from E = (1/√r)AB^T | |
| scale = 1.0 / (pop_size * sigma * math.sqrt(rank)) | |
| grad_w = grad_w * scale | |
| # Bias gradient | |
| grad_b = None | |
| if Eb_all is not None: | |
| # (Pop,) @ (Pop, Out) -> (Out,) | |
| grad_b = fitnesses @ Eb_all | |
| grad_b = grad_b * (1.0 / (pop_size * sigma)) | |
| return grad_w, grad_b | |
| class EggrollOptimizer: | |
| """ | |
| EGGROLL optimizer with Adam (Table 1: optimizer ∈ {adam, adamw, sgd}). | |
| Handles pseudo-gradient computation and parameter updates. | |
| """ | |
| def __init__(self, model: EggrollMLP, config: EggrollConfig): | |
| self.model = model | |
| self.config = config | |
| # Adam state for each layer | |
| self.state = {} | |
| for idx, layer in enumerate(model.layers): | |
| self.state[f"layer{idx}_w"] = { | |
| "m": torch.zeros_like(layer.weight), | |
| "v": torch.zeros_like(layer.weight), | |
| } | |
| if layer.bias is not None: | |
| self.state[f"layer{idx}_b"] = { | |
| "m": torch.zeros_like(layer.bias), | |
| "v": torch.zeros_like(layer.bias), | |
| } | |
| self.t = 0 | |
| self.beta1 = 0.9 | |
| self.beta2 = 0.999 | |
| self.eps = 1e-8 | |
| def _adam_update( | |
| self, | |
| param: torch.Tensor, | |
| grad: torch.Tensor, | |
| state: dict, | |
| ) -> None: | |
| """In-place Adam update.""" | |
| state["m"].mul_(self.beta1).add_(grad, alpha=1 - self.beta1) | |
| state["v"].mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2) | |
| m_hat = state["m"] / (1 - self.beta1**self.t) | |
| v_hat = state["v"] / (1 - self.beta2**self.t) | |
| # Maximize fitness: add gradient (not subtract) | |
| param.data.add_( | |
| m_hat / (v_hat.sqrt() + self.eps), alpha=self.config.learning_rate | |
| ) | |
| def step(self, seeds: torch.Tensor, fitnesses: torch.Tensor) -> None: | |
| """ | |
| Perform one EGGROLL optimization step. | |
| Paper Algorithm 1: | |
| "μ ← μ + α · (1/N_workers) · Σ E_j · f_j" | |
| Args: | |
| seeds: Population seeds used in forward pass (Pop,) | |
| fitnesses: Raw fitness values (Pop,) | |
| """ | |
| # Fitness shaping (Table 1: rank_transform) | |
| if self.config.use_rank_transform: | |
| shaped_fitness = centered_rank_transform(fitnesses) | |
| else: | |
| shaped_fitness = fitnesses | |
| self.t += 1 | |
| # Update each layer | |
| for layer_idx, layer in enumerate(self.model.layers): | |
| grad_w, grad_b = compute_eggroll_gradient( | |
| layer, seeds, shaped_fitness, layer_idx, self.config | |
| ) | |
| # Adam update for weight | |
| self._adam_update( | |
| layer.weight, | |
| grad_w, | |
| self.state[f"layer{layer_idx}_w"], | |
| ) | |
| # Adam update for bias | |
| if grad_b is not None: | |
| self._adam_update( | |
| layer.bias, | |
| grad_b, | |
| self.state[f"layer{layer_idx}_b"], | |
| ) | |
| def train_xor(run_id: int, verbose: bool = True) -> bool: | |
| """ | |
| Train on XOR problem using EGGROLL. | |
| Args: | |
| run_id: Run identifier for different random seeds | |
| verbose: Whether to print progress | |
| Returns: | |
| True if converged, False otherwise | |
| """ | |
| if verbose: | |
| print(f"\n--- Run {run_id}: XOR Training ---") | |
| # Device setup | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Data | |
| inputs = torch.tensor( | |
| [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], device=device | |
| ) | |
| targets = torch.tensor([[0.0], [1.0], [1.0], [0.0]], device=device) | |
| # Model initialization with different seed per run | |
| torch.manual_seed(run_id * 12345) | |
| model = EggrollMLP(2, 16, 1).to(device) | |
| # Config | |
| config = EggrollConfig( | |
| rank=1, | |
| sigma=0.1, | |
| learning_rate=0.02, | |
| pop_size=512, | |
| use_rank_transform=True, | |
| ) | |
| optimizer = EggrollOptimizer(model, config) | |
| if verbose: | |
| print( | |
| f"Config: rank={config.rank}, σ={config.sigma}, lr={config.learning_rate}" | |
| ) | |
| print(f"Population size: {config.pop_size}") | |
| for gen in range(301): | |
| # Generate seeds for this generation | |
| base_seed = ((run_id * 10000) + gen * config.pop_size) % (2**31) | |
| seeds = torch.arange(base_seed, base_seed + config.pop_size, device=device) | |
| # Population forward pass | |
| with torch.no_grad(): | |
| pop_outputs = model.forward_population(inputs, seeds, config) | |
| # pop_outputs: (Pop, Batch, 1) | |
| # Fitness: negative MSE (maximize) | |
| diff = pop_outputs - targets.unsqueeze(0) | |
| mse = (diff**2).mean(dim=(1, 2)) | |
| fitnesses = -mse | |
| # Optimization step | |
| optimizer.step(seeds, fitnesses) | |
| # Evaluate mean parameters | |
| if gen % 20 == 0: | |
| with torch.no_grad(): | |
| pred = model(inputs) | |
| loss = ((pred - targets) ** 2).mean().item() | |
| if verbose: | |
| print(f"Gen {gen:3d}: MSE = {loss:.4f}") | |
| if loss < 0.005: | |
| if verbose: | |
| print(f"Run {run_id}: Converged at Gen {gen}!") | |
| return True | |
| # Final check | |
| with torch.no_grad(): | |
| pred = model(inputs) | |
| loss = ((pred - targets) ** 2).mean().item() | |
| if loss < 0.01: | |
| if verbose: | |
| print(f"Run {run_id}: Converged at Final (Loss {loss:.4f})") | |
| return True | |
| if verbose: | |
| print(f"Run {run_id}: Failed (Loss {loss:.4f})") | |
| return False | |
| if __name__ == "__main__": | |
| successes = 0 | |
| total = 5 | |
| for i in range(total): | |
| if train_xor(i + 1): | |
| successes += 1 | |
| print(f"\n{'=' * 40}") | |
| print(f"Results: {successes}/{total} converged") |