Last active
April 2, 2026 05:43
-
-
Save CypherpunkSamurai/30b5ebb8eb7329421f0bdf90aca166e9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pygame | |
| import random | |
| import sys | |
| import os | |
| import urllib.request | |
| import math | |
| import numpy as np | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # SPIKING NEURAL NETWORK ─ LIF + RSTDP with eligibility traces | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class LIFNeuron:
    """Population of leaky integrate-and-fire neurons advanced in lock-step.

    For each of the ``n`` neurons the object tracks a membrane potential
    (``v``), the spike flags produced by the latest step (``spikes``) and a
    decaying spike trace (``trace``) that the STDP rule reads.
    """

    def __init__(self, n, tau_m=20.0, v_thresh=1.0, v_reset=0.0, v_rest=0.0):
        # Fixed parameters.
        self.n = n
        self.tau_m = tau_m          # membrane time constant (ms)
        self.v_thresh = v_thresh    # potential at which a neuron fires
        self.v_reset = v_reset      # potential assigned right after a spike
        self.v_rest = v_rest        # potential the membrane leaks towards
        self.tau_trace = 20.0       # decay constant of the spike trace

        # Per-neuron state.
        self.v = np.full(n, v_rest, dtype=np.float32)
        self.spikes = np.zeros(n, dtype=np.float32)
        self.trace = np.zeros(n, dtype=np.float32)

    def step(self, I_in, dt=1.0):
        """Advance every neuron by ``dt`` and return the new spike vector.

        ``I_in`` is the input current, one entry per neuron.
        """
        # Leak towards v_rest, then add the scaled input current.
        leak = np.exp(-dt / self.tau_m)
        membrane = (
            self.v * leak
            + self.v_rest * (1 - leak)
            + I_in * (dt / self.tau_m)
        )

        # Neurons at or above threshold fire and are clamped to v_reset.
        fired = membrane >= self.v_thresh
        self.spikes = fired.astype(np.float32)
        self.v = np.where(fired, self.v_reset, membrane)

        # The trace decays exponentially and is bumped by each new spike.
        self.trace = self.trace * np.exp(-dt / self.tau_trace) + self.spikes
        return self.spikes

    def reset_state(self):
        """Return potentials to rest and clear spikes and traces in place."""
        self.v.fill(self.v_rest)
        self.spikes.fill(0)
        self.trace.fill(0)
class SNNLayer:
    """Dense feedforward synapse matrix carrying R-STDP eligibility traces.

    ``W`` and ``e`` both have shape ``(n_post, n_pre)``; ``e`` accumulates
    recent pre/post spike pairings and is converted into weight changes only
    when ``apply_dopamine`` is called.
    """

    def __init__(self, n_pre, n_post, tau_e=100.0, w_min=-2.0, w_max=2.0):
        self.n_pre, self.n_post = n_pre, n_post
        self.tau_e = tau_e                   # eligibility trace decay (ms)
        self.w_min, self.w_max = w_min, w_max

        # Xavier-like initial weights: normal samples scaled by fan-in + fan-out.
        scale = np.sqrt(2.0 / (n_pre + n_post))
        self.W = np.random.randn(n_post, n_pre).astype(np.float32) * scale

        # One eligibility value per synapse, initially zero.
        self.e = np.zeros((n_post, n_pre), dtype=np.float32)

    def forward(self, pre_spikes):
        """Return the input current delivered to the post layer, shape (n_post,)."""
        return np.matmul(self.W, pre_spikes)

    def update_eligibility(self, pre_trace, pre_spikes, post_trace, post_spikes, dt=1.0):
        """Decay the eligibility traces and add this step's STDP contribution.

        Potentiation: ``+A_plus  * post_spike_j * pre_trace_i`` (pre before post).
        Depression:   ``-A_minus * post_trace_j * pre_spike_i`` (post before pre).
        """
        A_plus = 0.02
        A_minus = 0.02

        potentiation = np.outer(post_spikes, pre_trace) * A_plus
        depression = np.outer(post_trace, pre_spikes) * A_minus

        retained = self.e * np.exp(-dt / self.tau_e)
        self.e = retained + potentiation - depression

    def apply_dopamine(self, dopamine, eta=0.01):
        """Turn eligibility into a weight change gated by the reward signal."""
        updated = self.W + eta * dopamine * self.e
        # Keep every weight inside [w_min, w_max].
        self.W = np.clip(updated, self.w_min, self.w_max)

    def reset_state(self):
        """Zero the eligibility traces in place (weights are untouched)."""
        self.e.fill(0.0)
class SNNAgent:
    """
    Online RL SNN agent using LIF neurons + RSTDP with eligibility traces.

    Architecture:
        Input (5 rate-coded neurons) → Hidden (16 LIF) → Output (2 LIF: [no-flap, flap])

    Learning:
        Three-factor RSTDP:
            ΔW = η × d(t) × e_ij(t)
        where d(t) is the dopamine/reward trace and e_ij is the synaptic
        eligibility trace capturing recent pre/post spike coincidences.

    Inputs (rate-coded, normalised 0→1):
        0: bird_y / screen_height
        1: (bird_vy + 10) / 20       normalised velocity
        2: dist_to_next_pipe / screen_w
        3: pipe_center_y / screen_h
        4: gap_top_y / screen_h      (top of the gap)
    """
    N_IN = 5
    N_HID = 16
    N_OUT = 2
    T_SIM = 8               # simulation timesteps per game-frame
    FIRE_RATE_MAX = 0.8     # per-step spike probability for an input of 1.0

    def __init__(self, screen_w, screen_h):
        self.screen_w = screen_w
        self.screen_h = screen_h

        self.layer1 = SNNLayer(self.N_IN, self.N_HID, tau_e=80.0)
        self.layer2 = SNNLayer(self.N_HID, self.N_OUT, tau_e=80.0)
        self.hid = LIFNeuron(self.N_HID, tau_m=15.0, v_thresh=1.0)
        self.out = LIFNeuron(self.N_OUT, tau_m=15.0, v_thresh=1.0)

        # Dopamine trace
        self.dopamine = 0.0
        self.tau_dopa = 60.0    # ms – dopamine decay

        # Spikes produced during the most recent frame
        self.last_out_spikes = np.zeros(self.N_OUT)
        self.last_hid_spikes = np.zeros(self.N_HID)

        # History for visualisation (one column per frame, newest last)
        self.spike_history_hid = np.zeros((self.N_HID, 60), dtype=np.float32)
        self.spike_history_out = np.zeros((self.N_OUT, 60), dtype=np.float32)
        self.dopa_history = np.zeros(60, dtype=np.float32)

        self.score_history = []
        self.total_reward = 0.0
        self.frames_alive = 0
        self.episode_count = 0
        self.episode_rewards = []

        self.eta = 0.005            # learning rate
        self.explore_eps = 0.05     # random exploration rate

        # Generator used for input spikes and exploration
        self._rng = np.random.default_rng()

    # ── encoding ──────────────────────────────────────────────────────────────
    def _encode(self, obs):
        """Rate-code observation into spike probabilities."""
        return np.clip(obs, 0.0, 1.0).astype(np.float32) * self.FIRE_RATE_MAX

    # ── forward pass over T_SIM steps ─────────────────────────────────────────
    def forward(self, obs):
        """Simulate ``T_SIM`` network steps for one frame and pick an action.

        Also updates both layers' eligibility traces and the spike histories.

        Returns:
            bool: True to flap, False otherwise.
        """
        rates = self._encode(obs)
        out_acc = np.zeros(self.N_OUT)

        for _ in range(self.T_SIM):
            # Random input spikes: each input fires with probability rates[i]
            in_spikes = (self._rng.random(self.N_IN) < rates).astype(np.float32)

            # Layer 1 forward
            I_hid = self.layer1.forward(in_spikes)
            hid_spikes = self.hid.step(I_hid)

            # Layer 2 forward
            I_out = self.layer2.forward(hid_spikes)
            out_spikes = self.out.step(I_out)

            # Update eligibility traces
            self.layer1.update_eligibility(
                in_spikes, in_spikes,   # pre: spikes≈trace for input layer
                self.hid.trace, hid_spikes
            )
            self.layer2.update_eligibility(
                self.hid.trace, hid_spikes,
                self.out.trace, out_spikes
            )
            out_acc += out_spikes

        self.last_out_spikes = out_acc
        # Hidden spikes of the final simulation step only
        self.last_hid_spikes = self.hid.spikes.copy()

        # Roll spike history
        self.spike_history_hid = np.roll(self.spike_history_hid, -1, axis=1)
        self.spike_history_hid[:, -1] = self.last_hid_spikes
        self.spike_history_out = np.roll(self.spike_history_out, -1, axis=1)
        self.spike_history_out[:, -1] = self.last_out_spikes > 0

        # Decision: flap if output neuron 1 fires more than neuron 0,
        # with epsilon exploration. Both branches return a plain bool.
        if self._rng.random() < self.explore_eps:
            return bool(self._rng.integers(0, 2) == 1)
        return bool(out_acc[1] > out_acc[0])

    # ── reward / learning ─────────────────────────────────────────────────────
    def give_reward(self, r):
        """Inject reward as dopamine pulse (dopamine is clipped to [-5, 5])."""
        self.dopamine = np.clip(self.dopamine + r, -5.0, 5.0)
        self.total_reward += r

    def learn_step(self, dt=1.0):
        """
        Apply dopamine-gated weight updates via eligibility traces.
        Called every game frame.

        The weight update and dopamine decay are skipped while dopamine is
        negligible, but the dopamine history and ``frames_alive`` are still
        advanced so that both stay aligned with the frame count.
        """
        if abs(self.dopamine) >= 1e-4:
            self.layer1.apply_dopamine(self.dopamine, eta=self.eta)
            self.layer2.apply_dopamine(self.dopamine, eta=self.eta)
            # Dopamine decay
            self.dopamine *= np.exp(-dt / self.tau_dopa)

        # Log
        self.dopa_history = np.roll(self.dopa_history, -1)
        self.dopa_history[-1] = self.dopamine
        self.frames_alive += 1

    def on_death(self, score):
        """Record the finished episode and partially reset the network state."""
        self.episode_count += 1
        self.episode_rewards.append(self.total_reward)
        self.score_history.append(score)
        # Both lists are trimmed together to the latest 200 episodes
        if len(self.episode_rewards) > 200:
            self.episode_rewards.pop(0)
            self.score_history.pop(0)
        self.total_reward = 0.0
        self.frames_alive = 0
        # Partial reset of neuron states (keep weights / traces)
        self.hid.reset_state()
        self.out.reset_state()
        # Decay eligibility traces on death
        self.layer1.e *= 0.1
        self.layer2.e *= 0.1

    # ── observation builder (called from Game) ────────────────────────────────
    def observe(self, bird_y, bird_vy, pipes, screen_w, screen_h):
        """Build the 5-element float32 observation vector for ``forward``.

        ``pipes`` is an iterable of objects exposing ``x``, ``height``,
        ``GAP`` and ``pipe_width``. Pipes whose right edge is at or left of
        x=50 are ignored. With no pipe ahead, the defaults
        ``dist=1.0, centre=0.5, top=0.4`` are used.
        """
        # Pipes whose right edge is still past x=50
        ahead = [(p.x, p.height, p.height + p.GAP)
                 for p in pipes if p.x + p.pipe_width > 50]
        if ahead:
            # Tuples compare by x first, so min() yields the nearest pipe
            px, ptop, pbot = min(ahead)
            dist = (px - 50) / screen_w
            pcen = (ptop + pbot) / 2 / screen_h
            ptop_n = ptop / screen_h
        else:
            dist, pcen, ptop_n = 1.0, 0.5, 0.4

        obs = np.array([
            bird_y / screen_h,
            (bird_vy + 10) / 20,
            dist,
            pcen,
            ptop_n,
        ], dtype=np.float32)
        return obs
| # ───────────────────────────────────────────────────────────────────────────── | |
| # ORIGINAL GAME ASSETS / HELPERS (unchanged) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def cache_assets():
    """Download any missing sprite and audio files into ./sprites and ./audio.

    Files that already exist locally are left alone. Downloads are
    best-effort: a failure is reported on stdout and the asset is skipped,
    and any partially written file is removed so the download is retried on
    the next run. Prints the number of downloaded files when it is non-zero.
    """
    sprites = [
        "0.png","1.png","2.png","3.png","4.png","5.png","6.png","7.png","8.png","9.png",
        "background-day.png","background-night.png","base.png",
        "bluebird-downflap.png","bluebird-midflap.png","bluebird-upflap.png",
        "gameover.png","message.png","pipe-green.png","pipe-red.png",
        "redbird-downflap.png","redbird-midflap.png","redbird-upflap.png",
        "yellowbird-downflap.png","yellowbird-midflap.png","yellowbird-upflap.png",
    ]
    audio_files = ["die.ogg","hit.ogg","point.ogg","swoosh.ogg","wing.ogg"]
    sprite_base = "https://raw.githubusercontent.com/samuelcust/flappy-bird-assets/master/sprites/"
    audio_base = "https://raw.githubusercontent.com/samuelcust/flappy-bird-assets/master/audio/"

    os.makedirs("sprites", exist_ok=True)
    os.makedirs("audio", exist_ok=True)

    downloaded = 0
    groups = (
        ("sprites", sprite_base, sprites),
        ("audio", audio_base, audio_files),
    )
    for folder, base_url, names in groups:
        for name in names:
            path = f"{folder}/{name}"
            if os.path.exists(path):
                continue
            try:
                urllib.request.urlretrieve(base_url + name, path)
            except Exception as exc:
                # Exception (not a bare except) keeps Ctrl-C and SystemExit
                # working while still treating any download error as non-fatal.
                print(f"Could not download {name}: {exc}")
                # Drop a truncated file so the exists() check above does not
                # mistake it for a complete asset on the next run.
                try:
                    os.remove(path)
                except OSError:
                    pass  # nothing was written, or it cannot be removed
            else:
                downloaded += 1

    if downloaded:
        print(f"Downloaded {downloaded} assets")
def load_sound(name):
    """Return a ``pygame.mixer.Sound`` for ``audio/<name>.ogg``, or None.

    None is returned when the file does not exist or when pygame cannot load
    it (for example an unreadable file, or a mixer error reported as
    ``pygame.error``).
    """
    path = f"audio/{name}.ogg"
    if not os.path.exists(path):
        return None
    try:
        return pygame.mixer.Sound(path)
    except (pygame.error, OSError):
        # Only loading failures are tolerated; other exceptions propagate.
        return None
def load_image(name):
    """Return ``sprites/<name>.png`` as a per-pixel-alpha Surface, or None.

    None is returned when the file does not exist or when pygame cannot load
    or convert it (reported as ``pygame.error`` or an ``OSError``).
    """
    path = f"sprites/{name}.png"
    if not os.path.exists(path):
        return None
    try:
        return pygame.image.load(path).convert_alpha()
    except (pygame.error, OSError):
        # Only loading failures are tolerated; other exceptions propagate.
        return None
| # ───────────────────────────────────────────────────────────────────────────── | |
| # EFFECTS | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class Particle:
    """A single square particle that falls under gravity and fades out."""

    def __init__(self, x, y, color, velocity, lifetime, size=2):
        self.x = x
        self.y = y
        self.color = color
        self.vx = velocity[0]
        self.vy = velocity[1]
        # max_lifetime is kept so draw() can derive the fade fraction.
        self.lifetime = self.max_lifetime = lifetime
        self.size = size
        self.gravity = 0.2

    def update(self):
        """Advance one frame; return True while the particle is still alive."""
        self.vy += self.gravity
        self.x += self.vx
        self.y += self.vy
        self.lifetime -= 1
        return self.lifetime > 0

    def draw(self, surface, shake_x=0, shake_y=0):
        """Blit the particle, offset by the screen shake, with fading alpha."""
        remaining = self.lifetime / self.max_lifetime
        alpha = int(255 * remaining)
        # Only the RGB part of the colour is used; alpha comes from the fade.
        red, green, blue = self.color[:3]
        square = pygame.Surface((self.size, self.size), pygame.SRCALPHA)
        square.fill((red, green, blue, alpha))
        position = (int(self.x + shake_x), int(self.y + shake_y))
        surface.blit(square, position)
class ParticleSystem:
    """Owns the list of live particles and the helpers that spawn them."""

    def __init__(self):
        self.particles = []

    def spawn_explosion(self, x, y, color=(255,200,0), count=20):
        """Emit ``count`` particles from (x, y) in random directions."""
        for _ in range(count):
            angle = random.uniform(0, 2*math.pi)
            speed = random.uniform(2,8)
            velocity = (math.cos(angle)*speed, math.sin(angle)*speed)
            lifetime = random.randint(20,40)
            size = random.choice([2,3,4])
            self.particles.append(Particle(x, y, color, velocity, lifetime, size))

    def spawn_sparkle(self, x, y, color=(255,255,255)):
        """Emit one particle with a small sideways drift and an upward start."""
        drift = random.uniform(-1,1)
        rise = random.uniform(-2,-0.5)
        self.particles.append(Particle(x, y, color, (drift, rise), 30, 2))

    def spawn_bird_trail(self, x, y):
        """Emit a trail particle on roughly 30% of calls."""
        if random.random() >= 0.3:
            return
        vx = random.uniform(-1,-3)
        vy = random.uniform(-0.5,0.5)
        self.particles.append(Particle(x, y, (255,255,100,150), (vx, vy), 15, 2))

    def update(self):
        """Step every particle and keep only those that report being alive."""
        survivors = []
        for particle in self.particles:
            if particle.update():
                survivors.append(particle)
        self.particles = survivors

    def draw(self, surface, sx=0, sy=0):
        """Draw every particle with the given shake offset."""
        for particle in self.particles:
            particle.draw(surface, sx, sy)

    def clear(self):
        """Drop all particles."""
        self.particles = []
class ScreenShake:
    """Decaying random offset used to shake the rendered scene."""

    def __init__(self):
        self.shake_amount = 0
        self.shake_decay = 0.9

    def shake(self, amount=10):
        """Start (or restart) a shake with the given amplitude."""
        self.shake_amount = amount

    def update(self):
        """Shrink the amplitude by the decay factor; snap to 0 below 0.5."""
        self.shake_amount *= self.shake_decay
        if self.shake_amount < 0.5:
            self.shake_amount = 0

    def get_offset(self):
        """Return a random (dx, dy) within ±amplitude, or (0, 0) when idle."""
        amplitude = self.shake_amount
        if amplitude == 0:
            return (0,0)
        dx = random.uniform(-amplitude, amplitude)
        dy = random.uniform(-amplitude, amplitude)
        return (dx, dy)
class ScorePopup:
    """Floating "+N" label that rises, changes size and fades over 30 frames."""

    def __init__(self, x, y, score=1):
        self.x = x
        self.y = y
        self.lifetime = 30
        self.max_lifetime = 30
        self.scale = 1.0
        self.score = score

    def update(self):
        """Advance one frame; return True while the popup should stay visible."""
        self.lifetime -= 1
        progress = 1 - (self.lifetime / self.max_lifetime)
        self.y -= 1
        # Grow during the first 30% of the lifetime, then shrink again.
        if progress < 0.3:
            self.scale = 1.0 + progress*1.5
        else:
            self.scale = 1.45 - (progress-0.3)*0.7
        return self.lifetime > 0

    def draw(self, surface, font):
        """Render the label centred on (x, y) at the current scale and alpha."""
        fade = int(255 * (self.lifetime / self.max_lifetime))
        label = font.render(f"+{self.score}", True, (255,255,100))
        label.set_alpha(fade)
        target_size = (int(label.get_width()*self.scale),
                       int(label.get_height()*self.scale))
        scaled = pygame.transform.scale(label, target_size)
        destination = scaled.get_rect(center=(self.x, self.y))
        surface.blit(scaled, destination)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # BIRD & PIPE (same as original) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class Bird:
    """The player-controlled bird: physics, flap animation and rendering."""

    def __init__(self, sw, sh):
        self.screen_width = sw
        self.screen_height = sh

        # Position and motion.
        self.x = 50
        self.y = sh // 2
        self.velocity = 0
        self.gravity = 0.25
        self.jump_strength = -4.5
        self.angle = 0
        self.target_angle = 0

        # Animation frames; if the first sprite is missing, all three frames
        # are replaced by the drawn fallback.
        sprite_names = ("yellowbird-downflap", "yellowbird-midflap", "yellowbird-upflap")
        self.frames = [load_image(sprite) for sprite in sprite_names]
        if self.frames[0] is None:
            self.frames = [self._fallback() for _ in range(3)]
        self.current_frame = 0
        self.animation_timer = 0
        self.flap_timer = 0

        self.rect = pygame.Rect(self.x, self.y, 34, 24)

        # Idle bobbing used before the game starts.
        self.bounce_offset = 0
        self.bounce_speed = 0.05
        self.bounce_time = 0

    def _fallback(self):
        """Return a 34x24 yellow ellipse used when sprites are unavailable."""
        placeholder = pygame.Surface((34,24), pygame.SRCALPHA)
        pygame.draw.ellipse(placeholder, (255,255,0), (0,0,34,24))
        return placeholder

    def jump(self):
        """Give the bird an upward impulse and start the fast flap animation."""
        self.velocity = self.jump_strength
        self.target_angle = -25
        self.flap_timer = 10

    def update(self, game_started=False, base_height=112):
        """Advance physics (or idle bobbing) and the wing animation by a frame."""
        if not game_started:
            # Idle: bob up and down and rock slightly.
            self.bounce_time += self.bounce_speed
            self.bounce_offset = math.sin(self.bounce_time) * 5
            self.target_angle = math.sin(self.bounce_time*2) * 5
        else:
            # Playing: gravity drives the velocity, tilt follows the velocity.
            self.velocity += self.gravity
            self.y += self.velocity
            self.target_angle = max(-25, min(90, self.velocity*3))

        # Ease the drawn angle towards the target.
        self.angle += (self.target_angle - self.angle) * 0.1

        # Wings cycle every 3 frames right after a flap, every 5 otherwise.
        self.animation_timer += 1
        flapping = self.flap_timer > 0
        frame_period = 3 if flapping else 5
        if self.animation_timer >= frame_period:
            self.animation_timer = 0
            self.current_frame = (self.current_frame+1)%3
        if flapping:
            self.flap_timer -= 1

        self.rect.y = int(self.y + self.bounce_offset)

    def draw(self, surface):
        """Draw a faint glow and the current frame rotated by the bird's angle."""
        halo = pygame.Surface((50,50), pygame.SRCALPHA)
        pygame.draw.ellipse(halo, (255,255,100,30), (8,8,34,34))
        surface.blit(halo, (self.x-8, self.y+self.bounce_offset-5))

        frame = self.frames[self.current_frame]
        if frame:
            rotated = pygame.transform.rotate(frame, -self.angle)
            centre = (self.x+17, self.y+12+self.bounce_offset)
            surface.blit(rotated, rotated.get_rect(center=centre))
class Pipe:
    """A top/bottom pipe pair with a fixed-size gap that scrolls left."""

    GAP = 100
    VELOCITY = 2

    def __init__(self, x, sh, base_height=112):
        self.x = x
        self.screen_height = sh
        self.base_height = base_height

        # Height of the top pipe; the gap begins right below it.
        self.height = random.randint(50, sh - self.GAP - base_height - 50)
        self.passed = False
        self.flash_timer = 0

        # Sprite dimensions, or fixed defaults when the sprite is unavailable.
        self.pipe_img = load_image("pipe-green")
        if self.pipe_img:
            self.pipe_width = self.pipe_img.get_width()
            self.pipe_height = self.pipe_img.get_height()
        else:
            self.pipe_width = 52
            self.pipe_height = 320

        gap_bottom = self.height + self.GAP
        self.top_rect = pygame.Rect(x, 0, self.pipe_width, self.height)
        self.bottom_rect = pygame.Rect(
            x, gap_bottom,
            self.pipe_width, sh-self.height-self.GAP-base_height)

    def update(self):
        """Scroll left by VELOCITY, move both rects and tick the flash timer."""
        self.x -= self.VELOCITY
        for rect in (self.top_rect, self.bottom_rect):
            rect.x = self.x
        if self.flash_timer > 0:
            self.flash_timer -= 1

    def draw(self, surface, flash=False):
        """Draw both pipes, brightened while flashing; plain rects without sprite."""
        if not self.pipe_img:
            pygame.draw.rect(surface, (0,200,0), self.top_rect)
            pygame.draw.rect(surface, (0,200,0), self.bottom_rect)
            return

        flipped = pygame.transform.flip(self.pipe_img, False, True)
        placements = (
            (flipped, self.height-self.pipe_height),
            (self.pipe_img, self.height+self.GAP),
        )
        highlighted = flash or self.flash_timer > 0
        for image, top_y in placements:
            if highlighted:
                # Brighten a copy so the shared sprite is not modified.
                image = image.copy()
                image.fill((255,255,255,100), special_flags=pygame.BLEND_RGBA_ADD)
            surface.blit(image, (self.x, top_y))

    def flash(self):
        """Highlight the pipe for the next 5 updates."""
        self.flash_timer = 5

    def off_screen(self):
        """Return True once the pipe is entirely left of x=0."""
        return self.x < -self.pipe_width
| # ───────────────────────────────────────────────────────────────────────────── | |
| # SNN VISUALISER | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class SNNVisualiser:
    """
    Overlay panel showing real-time SNN activity.
    Docked to the right of the game window.

    The panel is rendered into ``self.surf`` (``W`` pixels wide, ``game_h``
    tall); ``draw`` returns that surface and leaves blitting to the caller.
    """
    W = 260   # panel width in pixels

    def __init__(self, game_h):
        """Create the fonts and the reusable per-pixel-alpha panel surface."""
        self.game_h = game_h
        self.font_sm = pygame.font.Font(None, 18)
        self.font_md = pygame.font.Font(None, 22)
        self.surf = pygame.Surface((self.W, game_h), pygame.SRCALPHA)

    def draw(self, agent: SNNAgent):
        """Redraw the whole panel from ``agent``'s current state and return it.

        Sections are laid out top to bottom using ``y`` as a running cursor,
        so the order of the statements below determines the layout.
        """
        s = self.surf
        s.fill((10, 10, 25, 220))
        y = 8

        # ── Title ──
        t = self.font_md.render("SNN Agent [RSTDP + e-trace]", True, (150,220,255))
        s.blit(t, (6, y)); y += 22

        # ── Episode stats ──
        ep_lbl = self.font_sm.render(
            f"Episode {agent.episode_count} η={agent.eta:.4f} ε={agent.explore_eps:.2f}",
            True, (180,180,180))
        s.blit(ep_lbl, (6, y)); y += 16
        if agent.score_history:
            best = max(agent.score_history)
            # Mean over the (up to) 10 most recent episodes
            avg = sum(agent.score_history[-10:]) / len(agent.score_history[-10:])
            stat = self.font_sm.render(f"Best: {best} Avg10: {avg:.1f}", True, (255,200,80))
            s.blit(stat, (6, y))
        y += 18

        # ── Separator ──
        pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 6

        # ── Dopamine trace bar ──
        d_lbl = self.font_sm.render("Dopamine (reward trace)", True, (200,150,255))
        s.blit(d_lbl, (6, y)); y += 14
        bar_w = self.W - 12; bar_h = 20
        pygame.draw.rect(s, (30,20,50), (6, y, bar_w, bar_h))
        # Dopamine is scaled by 5.0 and clipped, so the bar fills at ±5
        d_norm = np.clip(agent.dopamine / 5.0, -1, 1)
        mid = 6 + bar_w // 2
        # Positive values grow right of the midline, negative values left
        if d_norm > 0:
            pygame.draw.rect(s, (80,200,80), (mid, y+2, int(d_norm*(bar_w//2)), bar_h-4))
        elif d_norm < 0:
            w = int(-d_norm * (bar_w//2))
            pygame.draw.rect(s, (200,60,60), (mid-w, y+2, w, bar_h-4))
        pygame.draw.line(s, (120,120,120), (mid, y), (mid, y+bar_h))
        dv = self.font_sm.render(f"{agent.dopamine:+.3f}", True, (220,220,220))
        s.blit(dv, (6, y+4)); y += bar_h + 6

        # ── Hidden layer spike raster (last 60 frames) ──
        pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 4
        hl = self.font_sm.render(f"Hidden spikes ({SNNAgent.N_HID} neurons × 60 frames)", True, (150,200,150))
        s.blit(hl, (6, y)); y += 13
        raster_h = SNNAgent.N_HID * 4   # 4 px per neuron row (3 px cell + 1 px gap)
        pygame.draw.rect(s, (15,20,30), (6, y, bar_w, raster_h))
        hist = agent.spike_history_hid  # (N_HID, 60)
        cell_w = bar_w / 60
        for ni in range(SNNAgent.N_HID):
            for ti in range(60):
                if hist[ni, ti] > 0:
                    # Colour varies with the neuron index
                    col = (
                        int(60 + 195 * (ni / SNNAgent.N_HID)),
                        int(200 - 100 * (ni / SNNAgent.N_HID)),
                        255
                    )
                    pygame.draw.rect(s, col,
                        (6 + int(ti*cell_w), y + ni*4, max(1, int(cell_w)), 3))
        y += raster_h + 6

        # ── Output neurons ──
        pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 4
        ol = self.font_sm.render("Output neurons", True, (255,200,100))
        s.blit(ol, (6, y)); y += 14
        labels = ["NO-FLAP", "FLAP"]
        colors_n = [(80,150,255), (255,150,80)]
        colors_act = [(160,200,255),(255,210,100)]
        # Newest column of the output history: non-zero if the neuron fired
        last_out = agent.spike_history_out[:, -1]
        for i in range(SNNAgent.N_OUT):
            col = colors_act[i] if last_out[i] else colors_n[i]
            acc = agent.last_out_spikes[i]
            pygame.draw.rect(s, (30,30,50), (6, y, bar_w, 16))
            # Bar length is the spike count as a fraction of T_SIM, capped at 1
            fw = int(min(acc / SNNAgent.T_SIM, 1.0) * bar_w)
            pygame.draw.rect(s, col, (6, y, fw, 16))
            nl = self.font_sm.render(f"{labels[i]} {acc:.0f}/{SNNAgent.T_SIM}", True, (230,230,230))
            s.blit(nl, (10, y+2))
            y += 19
        y += 4

        # ── Weight norms ──
        pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 4
        wl = self.font_sm.render("Weight L2 norms", True, (150,180,220))
        s.blit(wl, (6, y)); y += 14
        for lname, layer in [("L1", agent.layer1), ("L2", agent.layer2)]:
            w_norm = float(np.linalg.norm(layer.W))
            e_norm = float(np.linalg.norm(layer.e))
            bar_full = bar_w - 40
            # The bar saturates once the weight norm reaches 5.0
            wn = min(w_norm / 5.0, 1.0)
            pygame.draw.rect(s, (20,30,50), (6, y, bar_full, 10))
            pygame.draw.rect(s, (80,140,220), (6, y, int(wn*bar_full), 10))
            txt = self.font_sm.render(f"{lname} W:{w_norm:.2f} e:{e_norm:.3f}", True, (180,200,220))
            s.blit(txt, (6, y+12)); y += 26

        # ── Score plot ──
        if len(agent.score_history) > 1:
            pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 4
            pl = self.font_sm.render("Score per episode", True, (150,200,150))
            s.blit(pl, (6, y)); y += 13
            plot_h = 50; plot_w = bar_w
            pygame.draw.rect(s, (10,20,15), (6, y, plot_w, plot_h))
            scores = agent.score_history[-60:]
            if scores:
                # Vertical scale; at least 1 so all-zero scores do not divide by 0
                mx = max(max(scores), 1)
                pts = []
                for i, sc in enumerate(scores):
                    px_x = 6 + int(i / max(len(scores)-1, 1) * (plot_w-1))
                    px_y = y + plot_h - 1 - int(sc / mx * (plot_h-2))
                    pts.append((px_x, px_y))
                if len(pts) > 1:
                    pygame.draw.lines(s, (80,255,120), False, pts, 1)
                # running average (trailing window of up to 5 episodes)
                if len(scores) >= 5:
                    avg_pts = []
                    for i in range(len(scores)):
                        window = scores[max(0,i-4):i+1]
                        av = sum(window)/len(window)
                        px_x = 6 + int(i / max(len(scores)-1,1) * (plot_w-1))
                        px_y = y + plot_h - 1 - int(av / mx * (plot_h-2))
                        avg_pts.append((px_x, px_y))
                    pygame.draw.lines(s, (255,200,60), False, avg_pts, 1)
            y += plot_h + 4

        # ── Controls hint ──
        # Drawn in the bottom 50 px, and only if the sections above left room
        if y < self.game_h - 50:
            y = self.game_h - 50
            pygame.draw.line(s, (60,80,120), (4, y), (self.W-4, y)); y += 4
            for hint in ["[A] toggle AI", "[S] toggle speed", "[R] reset weights"]:
                ht = self.font_sm.render(hint, True, (120,130,150))
                s.blit(ht, (6, y)); y += 14

        return s
| # ───────────────────────────────────────────────────────────────────────────── | |
| # GAME | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| class Game: | |
def __init__(self, sw, sh):
    """Set up one game session: world state, assets, effects, fonts and agent."""
    self.screen_width = sw
    self.screen_height = sh

    # World state.
    self.bird = Bird(sw, sh)
    self.pipes = []
    self.score = 0
    self.high_score = 0
    self.game_over = False
    self.spawn_timer = 0
    self.started = False

    # Visual effects.
    self.particles = ParticleSystem()
    self.screen_shake = ScreenShake()
    self.score_popups = []
    self.flash_screen = 0

    # Sounds, keyed by name; entries are None when loading fails.
    self.sounds = {}
    for sound_name in ["wing","hit","die","point"]:
        self.sounds[sound_name] = load_sound(sound_name)

    # Background and ground, with generated fallbacks when sprites are missing.
    self.bg_day = load_image("background-day")
    self.bg = self.bg_day or self._solid_bg((135,206,235))
    self.base = load_image("base") or self._fallback_base()
    self.base_width = self.base.get_width()
    self.base_height = self.base.get_height()
    self.gameover_img = load_image("gameover")
    self.message_img = load_image("message")
    self.base_x = 0
    self.cloud_offset = 0
    self.clouds = self._gen_clouds()

    # Fonts and digit sprites.
    self.font = pygame.font.Font(None, 48)
    self.small_font = pygame.font.Font(None, 36)
    self.big_font = pygame.font.Font(None, 72)
    self.ui_font = pygame.font.Font(None, 24)
    self._load_numbers()

    self.game_over_timer = 0
    self.bird_falling = False
    self.game_surface = pygame.Surface((sw, sh))

    # ── SNN stuff ──
    self.agent = SNNAgent(sw, sh)
    self.ai_mode = True       # start in AI mode so it learns immediately
    self.fast_mode = False    # run multiple updates per frame
    self.fast_ticks = 5
    self.visualiser = SNNVisualiser(sh)
    self.vis_surface = None

    # Frame counter behind the periodic survival reward.
    self._survive_timer = 0
| # ── background helpers ────────────────────────────────────────────────── | |
def _solid_bg(self, color):
    """Return a screen-sized surface filled with ``color``."""
    backdrop = pygame.Surface((self.screen_width, self.screen_height))
    backdrop.fill(color)
    return backdrop
def _fallback_base(self):
    """Return a 112 px tall striped ground strip as wide as the screen."""
    ground = pygame.Surface((self.screen_width, 112))
    ground.fill((222,184,135))
    # One vertical stripe every 24 px.
    stripe_x = 0
    while stripe_x < self.screen_width:
        pygame.draw.line(ground, (200,160,115), (stripe_x,0), (stripe_x,112), 2)
        stripe_x += 24
    return ground
| def _gen_clouds(self): | |
| return [{"x": random.randint(0, self.screen_width*2), | |
| "y": random.randint(20, 150), | |
| "speed": random.uniform(0.2,0.5), | |
| "size": random.randint(30,60)} for _ in range(5)] | |
def _draw_clouds(self, surface, sx=0, sy=0):
    """Draw each cloud as two white ellipses at its scrolled position."""
    wrap_width = self.screen_width*2
    visible_limit = self.screen_width+100
    for cloud in self.clouds:
        # Scroll by cloud_offset and wrap around twice the screen width.
        x = int(cloud["x"] - self.cloud_offset + sx) % wrap_width
        if x >= visible_limit:
            continue
        size = cloud["size"]
        top = cloud["y"]
        pygame.draw.ellipse(surface, (255,255,255,100), (x-size, top-size//3, size*2, size//2))
        pygame.draw.ellipse(surface, (255,255,255,150), (x-size//2, top-size//2, size, size//2))
def _load_numbers(self):
    """Load the digit sprites "0".."9" into ``number_imgs`` (None if missing)."""
    digit_images = []
    for digit in range(10):
        digit_images.append(load_image(str(digit)))
    self.number_imgs = digit_images
| # ── pipe / score ──────────────────────────────────────────────────────── | |
def spawn_pipe(self):
    """Append a new pipe that starts at the right edge of the screen."""
    new_pipe = Pipe(self.screen_width, self.screen_height, self.base_height)
    self.pipes.append(new_pipe)
def draw_score(self, surface, sx=0, sy=0):
    """Draw the score horizontally centred near the top of ``surface``.

    Digit sprites are used where available; a missing sprite falls back to
    font-rendered text with an assumed width of 24 px.
    """
    digits = str(self.score)
    glyphs = [self.number_imgs[int(ch)] for ch in digits]

    total_w = 0
    for glyph in glyphs:
        total_w += glyph.get_width() if glyph else 24

    x = (self.screen_width - total_w)//2 + sx
    y = 32 + sy
    for ch, glyph in zip(digits, glyphs):
        if glyph:
            surface.blit(glyph, (x, y))
            x += glyph.get_width()
        else:
            rendered = self.font.render(ch, True, (255,255,255))
            surface.blit(rendered, (x,y))
            x += 24
| # ── main update ───────────────────────────────────────────────────────── | |
def update(self):
    """Advance the game by one tick.

    Handles, in order: the game-over state (death fall and AI auto-restart),
    the pre-start idle state, the agent's decision and learning step, bird
    physics, pipe spawning, collisions and scoring, and ground scrolling.
    """
    if self.game_over:
        # Effects keep animating while the game-over screen is shown
        self.particles.update(); self.screen_shake.update()
        if self.flash_screen > 0: self.flash_screen -= 1
        self.game_over_timer += 1
        if self.bird_falling:
            # Death fall: double gravity and spin until the bird hits the ground
            self.bird.velocity += self.bird.gravity * 2
            self.bird.y += self.bird.velocity; self.bird.angle += 5
            if self.bird.y > self.screen_height - self.base_height - 24:
                self.bird.y = self.screen_height - self.base_height - 24
                self.bird_falling = False
        # AI auto-restart
        if self.ai_mode and not self.bird_falling and self.game_over_timer > 30:
            self.reset()
        return

    if not self.started:
        # Idle screen: bobbing bird plus an occasional random sparkle
        self.bird.update(game_started=False, base_height=self.base_height)
        if random.random() < 0.02:
            self.particles.spawn_sparkle(random.randint(0,self.screen_width),
                                         random.randint(0,self.screen_height))
        self.particles.update()
        # AI starts immediately
        if self.ai_mode:
            self.started = True
        return

    # NOTE(review): within this chunk bird_falling is only set True together
    # with game_over, which returns above; whether this branch is reachable
    # depends on reset(), which is not visible here.
    if self.bird_falling:
        self.bird.velocity += self.bird.gravity * 2
        self.bird.y += self.bird.velocity; self.bird.angle += 5
        if self.bird.y > self.screen_height - self.base_height - 24:
            self.bird.y = self.screen_height - self.base_height - 24
            self.bird_falling = False
        self.particles.update(); self.screen_shake.update()
        return

    # ── AI decision ──────────────────────────────────────────────────────
    if self.ai_mode:
        obs = self.agent.observe(self.bird.y, self.bird.velocity,
                                 self.pipes, self.screen_width, self.screen_height)
        should_flap = self.agent.forward(obs)
        if should_flap:
            self.bird.jump()
            self.play_sound("wing")
        # Survival reward every 10 frames
        self._survive_timer += 1
        if self._survive_timer >= 10:
            self._survive_timer = 0
            self.agent.give_reward(0.05)
        self.agent.learn_step()

    self.bird.update(game_started=True, base_height=self.base_height)
    self.particles.spawn_bird_trail(self.bird.x, self.bird.y+12)
    self.cloud_offset += 0.3

    # Hitting the ground ends the run; the top of the screen only stops the bird
    if self.bird.y > self.screen_height - self.base_height - 24:
        self.trigger_game_over(); return
    if self.bird.y < 0:
        self.bird.y = 0; self.bird.velocity = 0

    # A new pipe is spawned every 100 ticks
    self.spawn_timer += 1
    if self.spawn_timer >= 100:
        self.spawn_timer = 0; self.spawn_pipe()

    for pipe in self.pipes:
        pipe.update()
        if pipe.top_rect.colliderect(self.bird.rect) or \
           pipe.bottom_rect.colliderect(self.bird.rect):
            self.trigger_game_over(); break
        # Score once per pipe, when its right edge is left of the bird's x
        if not pipe.passed and pipe.x + pipe.pipe_width < self.bird.x:
            pipe.passed = True; pipe.flash()
            self.score += 1; self.play_sound("point")
            self.score_popups.append(ScorePopup(self.screen_width//2, self.screen_height//2-50))
            self.particles.spawn_explosion(self.bird.x+17, self.bird.y+12, (255,255,100), 10)
            if self.ai_mode:
                self.agent.give_reward(1.0)  # big reward for passing pipe

    self.score_popups = [p for p in self.score_popups if p.update()]
    self.particles.update(); self.screen_shake.update()
    if self.flash_screen > 0: self.flash_screen -= 1
    self.pipes = [p for p in self.pipes if not p.off_screen()]

    # Scroll the ground; base_x snaps back to 0 at the threshold below
    self.base_x -= 2
    if self.base_x <= -self.base_width + self.screen_width:
        self.base_x = 0
def trigger_game_over(self):
    """Switch the game into its game-over state (runs at most once per round).

    Starts the death effects (screen flash, shake, explosion particles and
    the "hit"/"die" sounds), hands the agent its death penalty when AI mode
    is on, and raises the stored high score if this round beat it.
    Calling it again while ``self.game_over`` is already set does nothing.
    """
    if self.game_over:
        return

    self.game_over = True
    self.bird_falling = True
    self.flash_screen = 5
    self.screen_shake.shake(15)
    self.particles.spawn_explosion(
        self.bird.x + 17, self.bird.y + 12, (255, 100, 50), 30)

    self.play_sound("hit")
    # The short pause separates the "hit" and "die" sounds for a human
    # listener.  pygame.time.delay blocks the whole game loop, so it is
    # skipped during fast AI training, where several update ticks are run
    # per rendered frame and every death would otherwise stall the loop.
    if not (self.ai_mode and self.fast_mode):
        pygame.time.delay(50)
    self.play_sound("die")

    if self.ai_mode:
        # Penalise the death, apply the update, then let the agent do its
        # end-of-episode bookkeeping with the final score.
        self.agent.give_reward(-1.0)
        self.agent.learn_step()
        self.agent.on_death(self.score)

    if self.score > self.high_score:
        self.high_score = self.score
def draw(self, actual_screen, vis_x_offset=0):
    """Render one frame and flip the display.

    The whole scene is composed on ``self.game_surface`` (background,
    clouds, pipes, particles, ground, bird, score, overlays and the mode
    badge), which is then blitted onto ``actual_screen`` at
    ``(vis_x_offset, 0)``.  In AI mode the SNN visualiser panel is blitted
    immediately to the right of the game area.
    """
    canvas = self.game_surface
    width, height = self.screen_width, self.screen_height
    mid_x, mid_y = width // 2, height // 2
    shake_x, shake_y = self.screen_shake.get_offset()

    # ── world ──
    canvas.fill((0, 0, 0))
    canvas.blit(self.bg, (0, 0))
    self._draw_clouds(canvas, shake_x, shake_y)
    for pipe in self.pipes:
        pipe.draw(canvas)
    self.particles.draw(canvas, shake_x, shake_y)
    if self.base:
        # Two copies of the ground image side by side, both shifted by
        # the current shake offset.
        ground_y = height - self.base_height + shake_y
        for tile_x in (self.base_x, self.base_x + self.base_width):
            canvas.blit(self.base, (tile_x + shake_x, ground_y))
    self.bird.draw(canvas)

    # ── score ──
    if self.started or self.score > 0:
        self.draw_score(canvas, shake_x, shake_y)
    for popup in self.score_popups:
        popup.draw(canvas, self.big_font)

    # ── white flash whose alpha scales with the remaining flash counter ──
    if self.flash_screen > 0:
        flash = pygame.Surface((width, height))
        flash.fill((255, 255, 255))
        flash.set_alpha(int(200 * (self.flash_screen / 5)))
        canvas.blit(flash, (0, 0))

    # ── start message, shown until the round has started ──
    if self.message_img and not self.started:
        message_rect = self.message_img.get_rect(center=(mid_x, mid_y))
        canvas.blit(self.message_img, message_rect)

    # ── game-over overlay ──
    if self.game_over:
        shade = pygame.Surface((width, height))
        shade.fill((0, 0, 0))
        shade.set_alpha(150)
        canvas.blit(shade, (0, 0))

        if self.gameover_img:
            # The image banner bobs vertically with the game-over timer.
            bob = math.sin(self.game_over_timer / 10) * 3
            banner_rect = self.gameover_img.get_rect(
                center=(mid_x + shake_x, mid_y - 40 + shake_y + bob))
            canvas.blit(self.gameover_img, banner_rect)
        else:
            banner = self.font.render("GAME OVER", True, (255, 0, 0))
            canvas.blit(banner, banner.get_rect(center=(mid_x, mid_y - 40)))

        score_line = self.font.render(
            f"Score: {self.score}", True, (255, 255, 255))
        canvas.blit(score_line,
                    score_line.get_rect(center=(mid_x, mid_y + 10)))

        if self.score > 0 and self.score >= self.high_score:
            best_line = self.small_font.render(
                "NEW HIGH SCORE!", True, (255, 215, 0))
        else:
            best_line = self.small_font.render(
                f"Best: {self.high_score}", True, (200, 200, 200))
        canvas.blit(best_line,
                    best_line.get_rect(center=(mid_x, mid_y + 40)))

        if not self.ai_mode:
            hint = self.small_font.render(
                "Click to restart", True, (255, 255, 255))
            canvas.blit(hint, hint.get_rect(center=(mid_x, mid_y + 70)))

    # ── AI / speed mode badge ──
    if self.ai_mode:
        badge_text, badge_colour = "AI [RSTDP]", (80, 255, 120)
    else:
        badge_text, badge_colour = "HUMAN", (255, 150, 80)
    canvas.blit(self.ui_font.render(badge_text, True, badge_colour), (4, 4))
    if self.fast_mode:
        speed_label = self.ui_font.render(
            f"×{self.fast_ticks}", True, (255, 220, 60))
        canvas.blit(speed_label, (4, 20))

    actual_screen.blit(canvas, (vis_x_offset, 0))

    # ── SNN visualiser ──
    if self.ai_mode:
        panel = self.visualiser.draw(self.agent)
        actual_screen.blit(panel, (vis_x_offset + width, 0))

    pygame.display.flip()
def reset(self):
    """Put the game back into its not-yet-started state for a new round.

    Clears the round flags, counters, pipes, popups and particles, and
    creates a fresh bird and screen-shake helper.  ``self.high_score`` is
    not touched here.
    """
    # Round flags
    self.started = False
    self.game_over = False
    self.bird_falling = False

    # Counters and timers
    self.score = 0
    self.spawn_timer = 0
    self.game_over_timer = 0
    self.flash_screen = 0
    self._survive_timer = 0

    # Entities and effects
    self.bird = Bird(self.screen_width, self.screen_height)
    self.pipes = []
    self.score_popups = []
    self.particles.clear()
    self.screen_shake = ScreenShake()
def play_sound(self, name):
    """Play the sound registered under *name*, if there is one.

    Unknown names and falsy entries (e.g. a sound stored as ``None``) are
    ignored.  Playback is best-effort: an ordinary error raised by
    ``play()`` is swallowed so audio problems never stop the game, but
    ``KeyboardInterrupt`` / ``SystemExit`` are allowed to propagate.
    """
    sound = self.sounds.get(name)
    if not sound:
        return
    try:
        sound.play()
    except Exception:
        # Deliberately ignored: sound effects are optional.
        pass
def resize(self, nw, nh):
    """Adapt the game to a new playfield size of ``nw`` x ``nh`` pixels.

    Recreates the off-screen game surface and the background, passes the
    new dimensions to the bird and to every existing pipe, and rebuilds
    the SNN visualiser for the new height.
    """
    self.screen_width, self.screen_height = nw, nh
    self.game_surface = pygame.Surface((nw, nh))

    # Scale the day background when one is available; otherwise fall back
    # to a solid-colour background.
    if self.bg_day:
        self.bg = pygame.transform.scale(self.bg_day, (nw, nh))
    else:
        self.bg = self._solid_bg((135, 206, 235))

    self.bird.screen_width, self.bird.screen_height = nw, nh

    ground_h = self.base_height
    for pipe in self.pipes:
        pipe.screen_height = nh
        pipe.base_height = ground_h
        # Bottom pipe height: what is left of the new height after the
        # pipe's own height, its gap and the ground are subtracted.
        pipe.bottom_rect.height = nh - pipe.height - pipe.GAP - ground_h

    self.visualiser = SNNVisualiser(nh)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # MAIN | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _handle_flap_input(game):
    """Apply one human "flap" press (SPACE or a mouse click) to *game*.

    Restarts a finished round once the bird is no longer falling, starts a
    round that has not begun yet (with an initial jump), or makes the bird
    jump while it is not falling.
    """
    if game.game_over and not game.bird_falling:
        game.reset()
    elif not game.started:
        game.started = True
        game.bird.jump()
        game.play_sound("wing")
    elif not game.bird_falling:
        game.bird.jump()
        game.play_sound("wing")


def main():
    """Create the window and run the game loop until the user quits.

    Keys:
        ESC    quit
        A      toggle AI mode (shows/hides the SNN panel and resets the round)
        S      toggle fast mode
        R      replace both SNN layers with fresh ones and zero the dopamine
        SPACE  flap / start / restart (human mode only)
    A mouse click acts like SPACE in human mode.
    """
    cache_assets()
    pygame.mixer.init()
    pygame.init()

    GAME_W, GAME_H = 288, 512
    VIS_W = SNNVisualiser.W
    # Start with vis panel visible (AI mode on)
    total_w = GAME_W + VIS_W
    screen = pygame.display.set_mode((total_w, GAME_H), pygame.RESIZABLE)
    pygame.display.set_caption("Flappy Bird – SNN RSTDP Agent")
    clock = pygame.time.Clock()
    game = Game(GAME_W, GAME_H)

    running = True
    vis_visible = True  # mirrors ai_mode
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.VIDEORESIZE:
                screen = pygame.display.set_mode(
                    (event.w, event.h), pygame.RESIZABLE)
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_a:
                    # Toggle AI mode; the window is widened or narrowed so
                    # the visualiser panel is only present in AI mode.
                    game.ai_mode = not game.ai_mode
                    vis_visible = game.ai_mode
                    new_w = GAME_W + (VIS_W if vis_visible else 0)
                    screen = pygame.display.set_mode(
                        (new_w, GAME_H), pygame.RESIZABLE)
                    game.reset()
                elif event.key == pygame.K_s:
                    # Toggle fast mode
                    game.fast_mode = not game.fast_mode
                elif event.key == pygame.K_r:
                    # Reset SNN weights
                    n_in = SNNAgent.N_IN
                    n_hid = SNNAgent.N_HID
                    n_out = SNNAgent.N_OUT
                    game.agent.layer1 = SNNLayer(n_in, n_hid, tau_e=80.0)
                    game.agent.layer2 = SNNLayer(n_hid, n_out, tau_e=80.0)
                    game.agent.dopamine = 0.0
                    print("SNN weights reset.")
                elif event.key == pygame.K_SPACE:
                    if not game.ai_mode:
                        _handle_flap_input(game)
            elif event.type == pygame.MOUSEBUTTONDOWN:
                if not game.ai_mode:
                    _handle_flap_input(game)

        # Fast mode: run multiple update ticks per frame (AI mode only)
        ticks = game.fast_ticks if game.fast_mode and game.ai_mode else 1
        for _ in range(ticks):
            game.update()
            if game.game_over and not game.bird_falling:
                break  # let it restart

        game.draw(screen, vis_x_offset=0)
        clock.tick(60)

    pygame.quit()
    sys.exit()
# Script entry point: start the game only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment