# Flappy Bird played by an online-learning spiking neural network
# (LIF neurons + reward-modulated STDP with eligibility traces).
# Source: GitHub gist CypherpunkSamurai/30b5ebb8eb7329421f0bdf90aca166e9
# (GitHub page navigation text that was pasted along with the code has been
# reduced to this header so the file is valid Python.)
import pygame
import random
import sys
import os
import urllib.request
import math
import numpy as np
# ─────────────────────────────────────────────────────────────────────────────
# SPIKING NEURAL NETWORK ─ LIF + RSTDP with eligibility traces
# ─────────────────────────────────────────────────────────────────────────────
class LIFNeuron:
    """Population of leaky integrate-and-fire (LIF) neurons.

    All state vectors (`v`, `spikes`, `trace`) have length `n` and are kept
    as float32 across updates.
    """

    def __init__(self, n, tau_m=20.0, v_thresh=1.0, v_reset=0.0, v_rest=0.0):
        """Create `n` neurons resting at `v_rest`.

        Args:
            n: number of neurons in the population.
            tau_m: membrane time constant (ms).
            v_thresh: firing threshold.
            v_reset: potential assigned to a neuron right after it spikes.
            v_rest: resting potential the membrane leaks towards.
        """
        self.n = n
        self.tau_m = tau_m
        self.v_thresh = v_thresh
        self.v_reset = v_reset
        self.v_rest = v_rest
        self.v = np.full(n, v_rest, dtype=np.float32)
        self.spikes = np.zeros(n, dtype=np.float32)
        # Low-pass filtered spike train, consumed by the STDP rule.
        self.trace = np.zeros(n, dtype=np.float32)
        self.tau_trace = 20.0  # trace decay time constant

    def step(self, I_in, dt=1.0):
        """Advance the population by one timestep.

        Args:
            I_in: input current vector (one entry per neuron).
            dt: timestep length, in the same unit as the time constants.

        Returns:
            float32 vector with 1.0 for every neuron that fired this step.
        """
        # math.exp yields a Python float; an np.float64 scalar here would
        # promote the float32 state arrays to float64 under NumPy 2.
        decay = math.exp(-dt / self.tau_m)
        drive = np.asarray(I_in, dtype=np.float32) * (dt / self.tau_m)
        # Leak: exponential relaxation towards v_rest; input: dt/tau scaled.
        self.v = self.v * decay + self.v_rest * (1.0 - decay) + drive

        fired = self.v >= self.v_thresh
        self.spikes = fired.astype(np.float32)
        # Neurons that fired restart from the reset potential.
        self.v = np.where(fired, np.float32(self.v_reset), self.v)

        # Trace decays exponentially and jumps by 1 on each spike.
        self.trace = self.trace * math.exp(-dt / self.tau_trace) + self.spikes
        return self.spikes

    def reset_state(self):
        """Return membrane, spikes and traces to their initial values."""
        self.v[:] = self.v_rest
        self.spikes[:] = 0
        self.trace[:] = 0
class SNNLayer:
    """Feedforward synaptic layer with reward-modulated STDP eligibility traces.

    `W` and `e` both have shape (n_post, n_pre) and stay float32: they are
    updated in place so a float64 operand cannot change their dtype.
    """

    def __init__(self, n_pre, n_post, tau_e=100.0, w_min=-2.0, w_max=2.0):
        """Create the layer with randomly initialised weights.

        Args:
            n_pre: number of presynaptic neurons.
            n_post: number of postsynaptic neurons.
            tau_e: eligibility trace decay time constant (ms).
            w_min: lower clipping bound for weights.
            w_max: upper clipping bound for weights.
        """
        self.n_pre = n_pre
        self.n_post = n_post
        self.tau_e = tau_e
        self.w_min = w_min
        self.w_max = w_max
        # Xavier-like init
        scale = np.sqrt(2.0 / (n_pre + n_post))
        self.W = np.random.randn(n_post, n_pre).astype(np.float32) * scale
        self.e = np.zeros((n_post, n_pre), dtype=np.float32)

    def forward(self, pre_spikes):
        """Return the input current vector (n_post,) driven by `pre_spikes`."""
        return self.W @ pre_spikes

    def update_eligibility(self, pre_trace, pre_spikes, post_trace, post_spikes,
                           dt=1.0, a_plus=0.02, a_minus=0.02):
        """Decay the eligibility traces and add this step's STDP contribution.

        e_ij is incremented by
            +a_plus  * pre_trace_i  * post_spike_j   (LTP: pre before post)
            -a_minus * post_trace_j * pre_spike_i    (LTD: post before pre)

        Args:
            pre_trace, pre_spikes: presynaptic trace / spike vectors (n_pre,).
            post_trace, post_spikes: postsynaptic trace / spike vectors (n_post,).
            dt: timestep length (ms).
            a_plus: potentiation amplitude.
            a_minus: depression amplitude.
        """
        # In-place arithmetic keeps self.e float32 whatever the operand dtypes.
        self.e *= math.exp(-dt / self.tau_e)
        self.e += a_plus * np.outer(post_spikes, pre_trace)
        self.e -= a_minus * np.outer(post_trace, pre_spikes)

    def apply_dopamine(self, dopamine, eta=0.01):
        """Three-factor update: W += eta * dopamine * e, clipped to [w_min, w_max]."""
        # float() strips any NumPy scalar type so it cannot promote the weights.
        self.W += (eta * float(dopamine)) * self.e
        np.clip(self.W, self.w_min, self.w_max, out=self.W)

    def reset_state(self):
        """Clear the eligibility traces (weights are kept)."""
        self.e[:] = 0.0
class SNNAgent:
    """
    Online RL SNN agent using LIF neurons + RSTDP with eligibility traces.

    Architecture:
        Input (5 rate-coded neurons) → Hidden (16 LIF) → Output (2 LIF: [no-flap, flap])

    Learning:
        Three-factor RSTDP:
            ΔW = η × d(t) × e_ij(t)
        where d(t) is the dopamine/reward trace and e_ij is the synaptic
        eligibility trace capturing recent pre/post spike coincidences.

    Inputs (rate-coded, normalised 0→1):
        0: bird_y / screen_height
        1: (bird_vy + 10) / 20      normalised velocity
        2: dist_to_next_pipe / screen_w
        3: pipe_center_y / screen_h
        4: gap_top_y / screen_h     (top of the gap)
    """

    N_IN = 5
    N_HID = 16
    N_OUT = 2
    T_SIM = 8  # simulation timesteps per game-frame
    FIRE_RATE_MAX = 0.8
    BIRD_X = 50  # horizontal reference used when measuring pipe distance

    def __init__(self, screen_w, screen_h):
        """Build the two synaptic layers, both neuron populations and all logs."""
        self.screen_w = screen_w
        self.screen_h = screen_h
        self.layer1 = SNNLayer(self.N_IN, self.N_HID, tau_e=80.0)
        self.layer2 = SNNLayer(self.N_HID, self.N_OUT, tau_e=80.0)
        self.hid = LIFNeuron(self.N_HID, tau_m=15.0, v_thresh=1.0)
        self.out = LIFNeuron(self.N_OUT, tau_m=15.0, v_thresh=1.0)
        # Dopamine trace
        self.dopamine = 0.0
        self.tau_dopa = 60.0  # ms – dopamine decay
        # Per-neuron spike counts accumulated over the last game frame
        self.last_out_spikes = np.zeros(self.N_OUT)
        self.last_hid_spikes = np.zeros(self.N_HID)
        # History for visualisation (one column per game frame)
        self.spike_history_hid = np.zeros((self.N_HID, 60), dtype=np.float32)
        self.spike_history_out = np.zeros((self.N_OUT, 60), dtype=np.float32)
        self.dopa_history = np.zeros(60, dtype=np.float32)
        self.score_history = []
        self.total_reward = 0.0
        self.frames_alive = 0
        self.episode_count = 0
        self.episode_rewards = []
        self.eta = 0.005  # learning rate
        self.explore_eps = 0.05  # random exploration rate
        self._rng = np.random.default_rng()

    # ── encoding ──────────────────────────────────────────────────────────────
    def _encode(self, obs):
        """Rate-code observation into per-step spike probabilities."""
        return np.clip(obs, 0.0, 1.0).astype(np.float32) * self.FIRE_RATE_MAX

    # ── forward pass over T_SIM steps ─────────────────────────────────────────
    def forward(self, obs):
        """Simulate T_SIM network steps for one game frame and pick an action.

        Returns:
            True to flap, False otherwise. With probability `explore_eps` the
            action is chosen uniformly at random instead.
        """
        rates = self._encode(obs)
        out_acc = np.zeros(self.N_OUT)
        hid_acc = np.zeros(self.N_HID)
        for _ in range(self.T_SIM):
            # Bernoulli input spikes drawn from the encoded rates
            in_spikes = (self._rng.random(self.N_IN) < rates).astype(np.float32)
            hid_spikes = self.hid.step(self.layer1.forward(in_spikes))
            out_spikes = self.out.step(self.layer2.forward(hid_spikes))
            # The input layer has no trace of its own; its spikes stand in for it.
            self.layer1.update_eligibility(
                in_spikes, in_spikes, self.hid.trace, hid_spikes)
            self.layer2.update_eligibility(
                self.hid.trace, hid_spikes, self.out.trace, out_spikes)
            out_acc += out_spikes
            hid_acc += hid_spikes

        self.last_out_spikes = out_acc
        self.last_hid_spikes = hid_acc

        # Both rasters mark "fired at least once during this frame".
        self.spike_history_hid = np.roll(self.spike_history_hid, -1, axis=1)
        self.spike_history_hid[:, -1] = hid_acc > 0
        self.spike_history_out = np.roll(self.spike_history_out, -1, axis=1)
        self.spike_history_out[:, -1] = out_acc > 0

        if self._rng.random() < self.explore_eps:
            return bool(self._rng.integers(0, 2) == 1)
        # Flap when the FLAP neuron out-fires the NO-FLAP neuron.
        return bool(out_acc[1] > out_acc[0])

    # ── reward / learning ─────────────────────────────────────────────────────
    def give_reward(self, r):
        """Inject reward `r` as a dopamine pulse (dopamine is kept in [-5, 5])."""
        self.dopamine = float(np.clip(self.dopamine + r, -5.0, 5.0))
        self.total_reward += r

    def learn_step(self, dt=1.0):
        """
        Apply dopamine-gated weight updates via eligibility traces.
        Called every game frame.
        """
        # Weight updates are skipped while dopamine is negligible, but the
        # log and the frame counter still advance every frame.
        if abs(self.dopamine) >= 1e-4:
            self.layer1.apply_dopamine(self.dopamine, eta=self.eta)
            self.layer2.apply_dopamine(self.dopamine, eta=self.eta)
            self.dopamine *= float(np.exp(-dt / self.tau_dopa))
        self.dopa_history = np.roll(self.dopa_history, -1)
        self.dopa_history[-1] = self.dopamine
        self.frames_alive += 1

    def on_death(self, score):
        """Close the current episode: log it and partially reset network state."""
        self.episode_count += 1
        self.episode_rewards.append(self.total_reward)
        self.score_history.append(score)
        if len(self.episode_rewards) > 200:
            self.episode_rewards.pop(0)
            self.score_history.pop(0)
        self.total_reward = 0.0
        self.frames_alive = 0
        # Neuron state is cleared; weights are kept.
        self.hid.reset_state()
        self.out.reset_state()
        # Eligibility traces are attenuated rather than cleared.
        self.layer1.e *= 0.1
        self.layer2.e *= 0.1

    # ── observation builder (called from Game) ────────────────────────────────
    def observe(self, bird_y, bird_vy, pipes, screen_w, screen_h):
        """Build the 5-element float32 observation for the nearest pipe ahead.

        Args:
            bird_y, bird_vy: vertical position and velocity of the bird.
            pipes: iterable of objects exposing `x`, `height`, `GAP`, `pipe_width`.
            screen_w, screen_h: dimensions used for normalisation.
        """
        ahead = [(p.x, p.height, p.height + p.GAP)
                 for p in pipes if p.x + p.pipe_width > self.BIRD_X]
        if ahead:
            px, ptop, pbot = min(ahead)
            dist = (px - self.BIRD_X) / screen_w
            pcen = (ptop + pbot) / 2 / screen_h
            ptop_n = ptop / screen_h
        else:
            # No pipe on screen: neutral defaults.
            dist, pcen, ptop_n = 1.0, 0.5, 0.4
        return np.array([
            bird_y / screen_h,
            (bird_vy + 10) / 20,
            dist,
            pcen,
            ptop_n,
        ], dtype=np.float32)
# ─────────────────────────────────────────────────────────────────────────────
# ORIGINAL GAME ASSETS / HELPERS (unchanged)
# ─────────────────────────────────────────────────────────────────────────────
def _download_missing(base_url, directory, names):
    """Fetch each file in `names` that is absent from `directory`; return the count fetched."""
    os.makedirs(directory, exist_ok=True)
    fetched = 0
    for name in names:
        path = f"{directory}/{name}"
        if os.path.exists(path):
            continue
        # Download to a side file first so an interrupted transfer never
        # leaves a truncated asset that would be mistaken for a cached one.
        partial = path + ".part"
        try:
            urllib.request.urlretrieve(base_url + name, partial)
            os.replace(partial, path)
        except Exception as exc:  # best effort: the game has drawn fallbacks
            print(f"Could not download {name}: {exc}")
            try:
                os.remove(partial)
            except OSError:
                pass
        else:
            fetched += 1
    return fetched


def cache_assets():
    """Download any missing sprite/audio assets into ./sprites and ./audio.

    Failures are reported and skipped; files already on disk are left alone.
    """
    sprites = [
        "0.png", "1.png", "2.png", "3.png", "4.png", "5.png", "6.png", "7.png", "8.png", "9.png",
        "background-day.png", "background-night.png", "base.png",
        "bluebird-downflap.png", "bluebird-midflap.png", "bluebird-upflap.png",
        "gameover.png", "message.png", "pipe-green.png", "pipe-red.png",
        "redbird-downflap.png", "redbird-midflap.png", "redbird-upflap.png",
        "yellowbird-downflap.png", "yellowbird-midflap.png", "yellowbird-upflap.png",
    ]
    audio_files = ["die.ogg", "hit.ogg", "point.ogg", "swoosh.ogg", "wing.ogg"]
    sprite_base = "https://raw.githubusercontent.com/samuelcust/flappy-bird-assets/master/sprites/"
    audio_base = "https://raw.githubusercontent.com/samuelcust/flappy-bird-assets/master/audio/"

    downloaded = _download_missing(sprite_base, "sprites", sprites)
    downloaded += _download_missing(audio_base, "audio", audio_files)
    if downloaded:
        print(f"Downloaded {downloaded} assets")
def load_sound(name):
    """Return a pygame Sound for audio/<name>.ogg, or None when it cannot be loaded."""
    path = f"audio/{name}.ogg"
    if not os.path.exists(path):
        return None
    try:
        return pygame.mixer.Sound(path)
    except (pygame.error, OSError):
        # Mixer unavailable or file unreadable: the game simply runs silent.
        return None
def load_image(name):
    """Return sprites/<name>.png as an alpha-converted Surface, or None when it cannot be loaded."""
    path = f"sprites/{name}.png"
    if not os.path.exists(path):
        return None
    try:
        return pygame.image.load(path).convert_alpha()
    except (pygame.error, OSError):
        # Corrupt file or no display yet: callers fall back to drawn shapes.
        return None
# ─────────────────────────────────────────────────────────────────────────────
# EFFECTS
# ─────────────────────────────────────────────────────────────────────────────
class Particle:
    """A single square particle that falls under gravity and fades out."""

    def __init__(self, x, y, color, velocity, lifetime, size=2):
        self.x = x
        self.y = y
        self.color = color
        self.vx, self.vy = velocity
        self.lifetime = lifetime
        self.max_lifetime = lifetime
        self.size = size
        self.gravity = 0.2

    def update(self):
        """Advance one frame; return True while the particle is still alive."""
        self.vy += self.gravity
        self.x += self.vx
        self.y += self.vy
        self.lifetime -= 1
        return self.lifetime > 0

    def draw(self, surface, shake_x=0, shake_y=0):
        """Blit the particle with opacity proportional to its remaining life."""
        remaining = self.lifetime / self.max_lifetime
        rgba = (*self.color[:3], int(255 * remaining))
        sprite = pygame.Surface((self.size, self.size), pygame.SRCALPHA)
        sprite.fill(rgba)
        position = (int(self.x + shake_x), int(self.y + shake_y))
        surface.blit(sprite, position)
class ParticleSystem:
    """Owns every live particle and provides the effect spawners."""

    def __init__(self):
        self.particles = []

    def spawn_explosion(self, x, y, color=(255, 200, 0), count=20):
        """Emit `count` particles flying outwards from (x, y) in random directions."""
        for _ in range(count):
            angle = random.uniform(0, 2 * math.pi)
            speed = random.uniform(2, 8)
            velocity = (math.cos(angle) * speed, math.sin(angle) * speed)
            life = random.randint(20, 40)
            size = random.choice([2, 3, 4])
            self.particles.append(Particle(x, y, color, velocity, life, size))

    def spawn_sparkle(self, x, y, color=(255, 255, 255)):
        """Emit one small particle drifting upwards from (x, y)."""
        drift = (random.uniform(-1, 1), random.uniform(-2, -0.5))
        self.particles.append(Particle(x, y, color, drift, 30, 2))

    def spawn_bird_trail(self, x, y):
        """With 30% probability emit one trail particle moving left from (x, y)."""
        if random.random() >= 0.3:
            return
        drift = (random.uniform(-1, -3), random.uniform(-0.5, 0.5))
        self.particles.append(Particle(x, y, (255, 255, 100, 150), drift, 15, 2))

    def update(self):
        """Step every particle and drop the ones that expired."""
        survivors = []
        for particle in self.particles:
            if particle.update():
                survivors.append(particle)
        self.particles = survivors

    def draw(self, surface, sx=0, sy=0):
        """Draw all particles, shifted by the screen-shake offset (sx, sy)."""
        for particle in self.particles:
            particle.draw(surface, sx, sy)

    def clear(self):
        """Remove every particle."""
        self.particles = []
class ScreenShake:
    """Exponentially decaying random screen offset."""

    def __init__(self):
        self.shake_amount = 0
        self.shake_decay = 0.9

    def shake(self, amount=10):
        """Start (or restart) a shake with the given amplitude."""
        self.shake_amount = amount

    def update(self):
        """Decay the amplitude; anything under 0.5 snaps to zero."""
        decayed = self.shake_amount * self.shake_decay
        self.shake_amount = 0 if decayed < 0.5 else decayed

    def get_offset(self):
        """Return an (x, y) offset drawn uniformly within the current amplitude."""
        amplitude = self.shake_amount
        if amplitude == 0:
            return (0, 0)
        dx = random.uniform(-amplitude, amplitude)
        dy = random.uniform(-amplitude, amplitude)
        return (dx, dy)
class ScorePopup:
    """Floating "+N" label that rises, pulses in size and fades out."""

    def __init__(self, x, y, score=1):
        self.x = x
        self.y = y
        self.lifetime = 30
        self.max_lifetime = 30
        self.scale = 1.0
        self.score = score

    def update(self):
        """Advance one frame; return True while the popup should stay on screen."""
        self.lifetime -= 1
        progress = 1 - (self.lifetime / self.max_lifetime)
        self.y -= 1
        if progress < 0.3:
            # Grow during the first 30% of the lifetime...
            self.scale = 1.0 + progress * 1.5
        else:
            # ...then shrink slowly.
            self.scale = 1.45 - (progress - 0.3) * 0.7
        return self.lifetime > 0

    def draw(self, surface, font):
        """Render the label centred on (x, y) at the current scale and opacity."""
        label = font.render(f"+{self.score}", True, (255, 255, 100))
        label.set_alpha(int(255 * (self.lifetime / self.max_lifetime)))
        target_size = (int(label.get_width() * self.scale),
                       int(label.get_height() * self.scale))
        scaled = pygame.transform.scale(label, target_size)
        surface.blit(scaled, scaled.get_rect(center=(self.x, self.y)))
# ─────────────────────────────────────────────────────────────────────────────
# BIRD & PIPE (same as original)
# ─────────────────────────────────────────────────────────────────────────────
class Bird:
    """The player-controlled bird: physics, flap animation and drawing."""

    def __init__(self, sw, sh):
        """Place the bird at x=50, vertically centred on a screen of size (sw, sh)."""
        self.screen_width = sw
        self.screen_height = sh
        self.x = 50
        self.y = sh // 2
        self.velocity = 0
        self.gravity = 0.25
        self.jump_strength = -4.5
        self.angle = 0
        self.target_angle = 0
        loaded = [load_image(f"yellowbird-{pose}flap") for pose in ("down", "mid", "up")]
        # Substitute the drawn fallback per frame, so a partially downloaded
        # sprite set never leaves the bird invisible on some frames.
        self.frames = [img if img is not None else self._fallback() for img in loaded]
        self.current_frame = 0
        self.animation_timer = 0
        self.flap_timer = 0
        self.rect = pygame.Rect(self.x, self.y, 34, 24)
        self.bounce_offset = 0
        self.bounce_speed = 0.05
        self.bounce_time = 0

    def _fallback(self):
        """Return a plain yellow ellipse used when a sprite is unavailable."""
        s = pygame.Surface((34, 24), pygame.SRCALPHA)
        pygame.draw.ellipse(s, (255, 255, 0), (0, 0, 34, 24))
        return s

    def jump(self):
        """Flap: set the upward velocity and start the fast wing animation."""
        self.velocity = self.jump_strength
        self.target_angle = -25
        self.flap_timer = 10

    def update(self, game_started=False, base_height=112):
        """Advance one frame.

        Args:
            game_started: when True apply gravity; otherwise bob in place.
            base_height: accepted for call compatibility; not used here.
        """
        if game_started:
            self.velocity += self.gravity
            self.y += self.velocity
            self.target_angle = max(-25, min(90, self.velocity * 3))
        else:
            self.bounce_time += self.bounce_speed
            self.bounce_offset = math.sin(self.bounce_time) * 5
            self.target_angle = math.sin(self.bounce_time * 2) * 5
        # Ease the drawn angle towards the target.
        self.angle += (self.target_angle - self.angle) * 0.1

        # Wings cycle faster (every 3 frames) right after a flap, else every 5.
        self.animation_timer += 1
        flapping = self.flap_timer > 0
        if self.animation_timer >= (3 if flapping else 5):
            self.animation_timer = 0
            self.current_frame = (self.current_frame + 1) % 3
        if flapping:
            self.flap_timer -= 1

        self.rect.y = int(self.y + self.bounce_offset)

    def draw(self, surface):
        """Draw a soft glow and the rotated current animation frame."""
        glow = pygame.Surface((50, 50), pygame.SRCALPHA)
        pygame.draw.ellipse(glow, (255, 255, 100, 30), (8, 8, 34, 34))
        surface.blit(glow, (self.x - 8, self.y + self.bounce_offset - 5))
        frame = self.frames[self.current_frame]
        if frame is not None:
            rotated = pygame.transform.rotate(frame, -self.angle)
            rect = rotated.get_rect(center=(self.x + 17, self.y + 12 + self.bounce_offset))
            surface.blit(rotated, rect)
class Pipe:
    """A top/bottom pipe pair with a fixed-size gap, scrolling left."""

    GAP = 100
    VELOCITY = 2
    # (sprite, vertically flipped sprite), loaded once and shared by every
    # pipe instead of re-reading the file on each spawn and re-flipping the
    # image on every draw.
    _sprites = None

    def __init__(self, x, sh, base_height=112):
        """Create a pipe pair at horizontal position `x` with a random gap height.

        Args:
            x: left edge of the pipes.
            sh: screen height.
            base_height: height of the ground strip the bottom pipe stops at.
        """
        self.x = x
        self.screen_height = sh
        self.base_height = base_height
        # Guard the range so an unusually short screen cannot make randint raise.
        lowest_top = max(50, sh - self.GAP - base_height - 50)
        self.height = random.randint(50, lowest_top)
        self.passed = False
        self.flash_timer = 0

        if Pipe._sprites is None:
            img = load_image("pipe-green")
            if img is not None:
                Pipe._sprites = (img, pygame.transform.flip(img, False, True))
        if Pipe._sprites is not None:
            self.pipe_img, self.pipe_img_top = Pipe._sprites
            self.pipe_width = self.pipe_img.get_width()
            self.pipe_height = self.pipe_img.get_height()
        else:
            self.pipe_img = self.pipe_img_top = None
            self.pipe_width = 52
            self.pipe_height = 320

        self.top_rect = pygame.Rect(x, 0, self.pipe_width, self.height)
        self.bottom_rect = pygame.Rect(
            x, self.height + self.GAP,
            self.pipe_width, sh - self.height - self.GAP - base_height)

    def update(self):
        """Scroll left by VELOCITY and tick down the flash timer."""
        self.x -= self.VELOCITY
        self.top_rect.x = self.x
        self.bottom_rect.x = self.x
        if self.flash_timer > 0:
            self.flash_timer -= 1

    def draw(self, surface, flash=False):
        """Draw both pipes; brightened while `flash` or the flash timer is active."""
        if self.pipe_img is None:
            pygame.draw.rect(surface, (0, 200, 0), self.top_rect)
            pygame.draw.rect(surface, (0, 200, 0), self.bottom_rect)
            return
        top_y = self.height - self.pipe_height
        bottom_y = self.height + self.GAP
        if flash or self.flash_timer > 0:
            for img, y in ((self.pipe_img_top, top_y), (self.pipe_img, bottom_y)):
                # Copy first: the sprite is shared between all pipes.
                over = img.copy()
                over.fill((255, 255, 255, 100), special_flags=pygame.BLEND_RGBA_ADD)
                surface.blit(over, (self.x, y))
        else:
            surface.blit(self.pipe_img_top, (self.x, top_y))
            surface.blit(self.pipe_img, (self.x, bottom_y))

    def flash(self):
        """Highlight the pipe for the next 5 frames."""
        self.flash_timer = 5

    def off_screen(self):
        """Return True once the pipe has fully scrolled past the left edge."""
        return self.x < -self.pipe_width
# ─────────────────────────────────────────────────────────────────────────────
# SNN VISUALISER
# ─────────────────────────────────────────────────────────────────────────────
class SNNVisualiser:
    """
    Side panel that renders the agent's live state: episode stats, dopamine
    level, spike rasters, output activity, weight norms and the score curve.
    It is drawn onto its own surface, which the caller blits next to the game.
    """

    W = 260

    def __init__(self, game_h):
        self.game_h = game_h
        self.font_sm = pygame.font.Font(None, 18)
        self.font_md = pygame.font.Font(None, 22)
        self.surf = pygame.Surface((self.W, game_h), pygame.SRCALPHA)

    # ── small drawing primitives ─────────────────────────────────────────────
    def _rule(self, y):
        """Draw a horizontal separator at height `y`."""
        pygame.draw.line(self.surf, (60, 80, 120), (4, y), (self.W - 4, y))

    def _text(self, message, colour, pos, font=None):
        """Render `message` at `pos` (small font unless `font` is given)."""
        rendered = (font or self.font_sm).render(message, True, colour)
        self.surf.blit(rendered, pos)

    # ── sections (each takes the current y and returns the next one) ─────────
    def _draw_header(self, agent, y):
        self._text("SNN Agent [RSTDP + e-trace]", (150, 220, 255), (6, y), self.font_md)
        y += 22
        self._text(
            f"Episode {agent.episode_count} η={agent.eta:.4f} ε={agent.explore_eps:.2f}",
            (180, 180, 180), (6, y))
        y += 16
        if agent.score_history:
            recent = agent.score_history[-10:]
            best = max(agent.score_history)
            avg = sum(recent) / len(recent)
            self._text(f"Best: {best} Avg10: {avg:.1f}", (255, 200, 80), (6, y))
        y += 18
        self._rule(y)
        return y + 6

    def _draw_dopamine(self, agent, y, bar_w):
        self._text("Dopamine (reward trace)", (200, 150, 255), (6, y))
        y += 14
        bar_h = 20
        half = bar_w // 2
        mid = 6 + half
        pygame.draw.rect(self.surf, (30, 20, 50), (6, y, bar_w, bar_h))
        # Bar grows right (green) for positive, left (red) for negative values.
        level = np.clip(agent.dopamine / 5.0, -1, 1)
        if level > 0:
            pygame.draw.rect(self.surf, (80, 200, 80),
                             (mid, y + 2, int(level * half), bar_h - 4))
        elif level < 0:
            span = int(-level * half)
            pygame.draw.rect(self.surf, (200, 60, 60),
                             (mid - span, y + 2, span, bar_h - 4))
        pygame.draw.line(self.surf, (120, 120, 120), (mid, y), (mid, y + bar_h))
        self._text(f"{agent.dopamine:+.3f}", (220, 220, 220), (6, y + 4))
        return y + bar_h + 6

    def _draw_hidden_raster(self, agent, y, bar_w):
        self._rule(y)
        y += 4
        self._text(f"Hidden spikes ({SNNAgent.N_HID} neurons × 60 frames)",
                   (150, 200, 150), (6, y))
        y += 13
        n_hid = SNNAgent.N_HID
        raster_h = n_hid * 4
        pygame.draw.rect(self.surf, (15, 20, 30), (6, y, bar_w, raster_h))
        history = agent.spike_history_hid  # (N_HID, 60)
        cell_w = bar_w / 60
        cell_px = max(1, int(cell_w))
        for row in range(n_hid):
            # Row colour shifts with the neuron index.
            frac = row / n_hid
            colour = (int(60 + 195 * frac), int(200 - 100 * frac), 255)
            for col in range(60):
                if history[row, col] > 0:
                    pygame.draw.rect(self.surf, colour,
                                     (6 + int(col * cell_w), y + row * 4, cell_px, 3))
        return y + raster_h + 6

    def _draw_outputs(self, agent, y, bar_w):
        self._rule(y)
        y += 4
        self._text("Output neurons", (255, 200, 100), (6, y))
        y += 14
        # (label, idle colour, colour when the neuron fired this frame)
        rows = [
            ("NO-FLAP", (80, 150, 255), (160, 200, 255)),
            ("FLAP", (255, 150, 80), (255, 210, 100)),
        ]
        fired_now = agent.spike_history_out[:, -1]
        for i in range(SNNAgent.N_OUT):
            label, idle, active = rows[i]
            colour = active if fired_now[i] else idle
            count = agent.last_out_spikes[i]
            pygame.draw.rect(self.surf, (30, 30, 50), (6, y, bar_w, 16))
            filled = int(min(count / SNNAgent.T_SIM, 1.0) * bar_w)
            pygame.draw.rect(self.surf, colour, (6, y, filled, 16))
            self._text(f"{label} {count:.0f}/{SNNAgent.T_SIM}", (230, 230, 230), (10, y + 2))
            y += 19
        return y + 4

    def _draw_weight_norms(self, agent, y, bar_w):
        self._rule(y)
        y += 4
        self._text("Weight L2 norms", (150, 180, 220), (6, y))
        y += 14
        track_w = bar_w - 40
        for tag, layer in (("L1", agent.layer1), ("L2", agent.layer2)):
            w_norm = float(np.linalg.norm(layer.W))
            e_norm = float(np.linalg.norm(layer.e))
            fraction = min(w_norm / 5.0, 1.0)
            pygame.draw.rect(self.surf, (20, 30, 50), (6, y, track_w, 10))
            pygame.draw.rect(self.surf, (80, 140, 220), (6, y, int(fraction * track_w), 10))
            self._text(f"{tag} W:{w_norm:.2f} e:{e_norm:.3f}", (180, 200, 220), (6, y + 12))
            y += 26
        return y

    def _draw_score_plot(self, agent, y, bar_w):
        self._rule(y)
        y += 4
        self._text("Score per episode", (150, 200, 150), (6, y))
        y += 13
        plot_h = 50
        plot_w = bar_w
        pygame.draw.rect(self.surf, (10, 20, 15), (6, y, plot_w, plot_h))
        scores = agent.score_history[-60:]
        if scores:
            top = max(max(scores), 1)
            last_index = max(len(scores) - 1, 1)

            def to_pixel(index, value):
                return (6 + int(index / last_index * (plot_w - 1)),
                        y + plot_h - 1 - int(value / top * (plot_h - 2)))

            raw = [to_pixel(i, sc) for i, sc in enumerate(scores)]
            if len(raw) > 1:
                pygame.draw.lines(self.surf, (80, 255, 120), False, raw, 1)
            # Trailing 5-episode running average.
            if len(scores) >= 5:
                smooth = []
                for i in range(len(scores)):
                    window = scores[max(0, i - 4):i + 1]
                    smooth.append(to_pixel(i, sum(window) / len(window)))
                pygame.draw.lines(self.surf, (255, 200, 60), False, smooth, 1)
        return y + plot_h + 4

    def _draw_hints(self, y):
        # Pin the hints to the bottom of the panel when there is room.
        y = max(y, self.game_h - 50)
        self._rule(y)
        y += 4
        for hint in ["[A] toggle AI", "[S] toggle speed", "[R] reset weights"]:
            self._text(hint, (120, 130, 150), (6, y))
            y += 14

    def draw(self, agent: "SNNAgent"):
        """Redraw the whole panel for `agent` and return the panel surface."""
        self.surf.fill((10, 10, 25, 220))
        bar_w = self.W - 12
        y = self._draw_header(agent, 8)
        y = self._draw_dopamine(agent, y, bar_w)
        y = self._draw_hidden_raster(agent, y, bar_w)
        y = self._draw_outputs(agent, y, bar_w)
        y = self._draw_weight_norms(agent, y, bar_w)
        if len(agent.score_history) > 1:
            y = self._draw_score_plot(agent, y, bar_w)
        self._draw_hints(y)
        return self.surf
# ─────────────────────────────────────────────────────────────────────────────
# GAME
# ─────────────────────────────────────────────────────────────────────────────
class Game:
    """Flappy Bird game state, update loop and renderer, with an optional SNN player."""

    def __init__(self, sw, sh):
        """Load assets and create the bird, effects, fonts and the SNN agent.

        Args:
            sw: width of the playfield in pixels.
            sh: height of the playfield in pixels.
        """
        self.screen_width = sw
        self.screen_height = sh
        self.bird = Bird(sw, sh)
        self.pipes = []
        self.score = 0
        self.high_score = 0
        self.game_over = False
        self.spawn_timer = 0
        self.started = False
        self.particles = ParticleSystem()
        self.screen_shake = ScreenShake()
        self.score_popups = []
        self.flash_screen = 0
        self.sounds = {n: load_sound(n) for n in ["wing", "hit", "die", "point"]}
        self.bg_day = load_image("background-day")
        if self.bg_day is not None:
            self.bg = self.bg_day
        else:
            self.bg = self._solid_bg((135, 206, 235))
        base_img = load_image("base")
        self.base = base_img if base_img is not None else self._fallback_base()
        self.base_width = self.base.get_width()
        self.base_height = self.base.get_height()
        self.gameover_img = load_image("gameover")
        self.message_img = load_image("message")
        self.base_x = 0
        self.cloud_offset = 0
        self.clouds = self._gen_clouds()
        self.font = pygame.font.Font(None, 48)
        self.small_font = pygame.font.Font(None, 36)
        self.big_font = pygame.font.Font(None, 72)
        self.ui_font = pygame.font.Font(None, 24)
        self._load_numbers()
        self.game_over_timer = 0
        self.bird_falling = False
        self.game_surface = pygame.Surface((sw, sh))
        # ── SNN stuff ──
        self.agent = SNNAgent(sw, sh)
        self.ai_mode = True  # start in AI mode so it learns immediately
        self.fast_mode = False  # run multiple updates per frame
        self.fast_ticks = 5
        self.visualiser = SNNVisualiser(sh)
        self.vis_surface = None
        # Frame counter for the periodic survival reward
        self._survive_timer = 0

    # ── background helpers ──────────────────────────────────────────────────
    def _solid_bg(self, color):
        """Return a screen-sized surface filled with `color`."""
        s = pygame.Surface((self.screen_width, self.screen_height))
        s.fill(color)
        return s

    def _fallback_base(self):
        """Return a striped ground strip used when the base sprite is missing."""
        s = pygame.Surface((self.screen_width, 112))
        s.fill((222, 184, 135))
        for x in range(0, self.screen_width, 24):
            pygame.draw.line(s, (200, 160, 115), (x, 0), (x, 112), 2)
        return s

    def _gen_clouds(self):
        """Return five randomly placed cloud descriptors."""
        return [{"x": random.randint(0, self.screen_width * 2),
                 "y": random.randint(20, 150),
                 "speed": random.uniform(0.2, 0.5),
                 "size": random.randint(30, 60)} for _ in range(5)]

    def _draw_clouds(self, surface, sx=0, sy=0):
        """Draw the clouds, wrapped over a strip twice the screen width."""
        for c in self.clouds:
            x = int(c["x"] - self.cloud_offset + sx) % (self.screen_width * 2)
            if x < self.screen_width + 100:
                z = c["size"]
                pygame.draw.ellipse(surface, (255, 255, 255, 100), (x - z, c["y"] - z // 3, z * 2, z // 2))
                pygame.draw.ellipse(surface, (255, 255, 255, 150), (x - z // 2, c["y"] - z // 2, z, z // 2))

    def _load_numbers(self):
        """Load the digit sprites 0-9 (entries are None when a sprite is missing)."""
        self.number_imgs = [load_image(str(i)) for i in range(10)]

    def _ground_y(self):
        """Return the bird's y coordinate when it rests on the ground."""
        return self.screen_height - self.base_height - 24

    # ── pipe / score ────────────────────────────────────────────────────────
    def spawn_pipe(self):
        """Append a new pipe pair at the right edge of the screen."""
        self.pipes.append(Pipe(self.screen_width, self.screen_height, self.base_height))

    def draw_score(self, surface, sx=0, sy=0):
        """Draw the score centred at the top, using digit sprites when available."""
        digits = [int(ch) for ch in str(self.score)]
        total_w = sum(self.number_imgs[d].get_width() if self.number_imgs[d] is not None else 24
                      for d in digits)
        x = (self.screen_width - total_w) // 2 + sx
        y = 32 + sy
        for d in digits:
            img = self.number_imgs[d]
            if img is not None:
                surface.blit(img, (x, y))
                x += img.get_width()
            else:
                surface.blit(self.font.render(str(d), True, (255, 255, 255)), (x, y))
                x += 24

    # ── main update ─────────────────────────────────────────────────────────
    def _update_fall(self):
        """Drop the dead bird until it lands on the ground."""
        self.bird.velocity += self.bird.gravity * 2
        self.bird.y += self.bird.velocity
        self.bird.angle += 5
        if self.bird.y > self._ground_y():
            self.bird.y = self._ground_y()
            self.bird_falling = False

    def _update_game_over(self):
        """Animate the death sequence; in AI mode restart once it has finished."""
        self.particles.update()
        self.screen_shake.update()
        if self.flash_screen > 0:
            self.flash_screen -= 1
        self.game_over_timer += 1
        if self.bird_falling:
            self._update_fall()
        # AI auto-restart
        if self.ai_mode and not self.bird_falling and self.game_over_timer > 30:
            self.reset()

    def _update_idle(self):
        """Animate the title screen; the AI starts playing immediately."""
        self.bird.update(game_started=False, base_height=self.base_height)
        if random.random() < 0.02:
            self.particles.spawn_sparkle(random.randint(0, self.screen_width),
                                         random.randint(0, self.screen_height))
        self.particles.update()
        if self.ai_mode:
            self.started = True

    def _ai_step(self):
        """Let the agent observe, act, collect its survival reward and learn."""
        obs = self.agent.observe(self.bird.y, self.bird.velocity,
                                 self.pipes, self.screen_width, self.screen_height)
        if self.agent.forward(obs):
            self.bird.jump()
            self.play_sound("wing")
        # Survival reward every 10 frames
        self._survive_timer += 1
        if self._survive_timer >= 10:
            self._survive_timer = 0
            self.agent.give_reward(0.05)
        self.agent.learn_step()

    def _update_pipes(self):
        """Move the pipes, detect collisions and score the ones just passed."""
        for pipe in self.pipes:
            pipe.update()
            if pipe.top_rect.colliderect(self.bird.rect) or \
                    pipe.bottom_rect.colliderect(self.bird.rect):
                self.trigger_game_over()
                break
            if not pipe.passed and pipe.x + pipe.pipe_width < self.bird.x:
                pipe.passed = True
                pipe.flash()
                self.score += 1
                self.play_sound("point")
                self.score_popups.append(
                    ScorePopup(self.screen_width // 2, self.screen_height // 2 - 50))
                self.particles.spawn_explosion(self.bird.x + 17, self.bird.y + 12, (255, 255, 100), 10)
                if self.ai_mode:
                    self.agent.give_reward(1.0)  # big reward for passing pipe

    def update(self):
        """Advance the game by one tick."""
        # bird_falling is only ever set together with game_over, so the fall
        # animation is handled entirely inside the game-over branch.
        if self.game_over:
            self._update_game_over()
            return
        if not self.started:
            self._update_idle()
            return

        if self.ai_mode:
            self._ai_step()

        self.bird.update(game_started=True, base_height=self.base_height)
        self.particles.spawn_bird_trail(self.bird.x, self.bird.y + 12)
        self.cloud_offset += 0.3
        if self.bird.y > self._ground_y():
            self.trigger_game_over()
            return
        if self.bird.y < 0:
            # Clamp at the ceiling instead of killing the bird.
            self.bird.y = 0
            self.bird.velocity = 0

        self.spawn_timer += 1
        if self.spawn_timer >= 100:
            self.spawn_timer = 0
            self.spawn_pipe()
        self._update_pipes()

        self.score_popups = [p for p in self.score_popups if p.update()]
        self.particles.update()
        self.screen_shake.update()
        if self.flash_screen > 0:
            self.flash_screen -= 1
        self.pipes = [p for p in self.pipes if not p.off_screen()]
        self.base_x -= 2
        if self.base_x <= -self.base_width + self.screen_width:
            self.base_x = 0

    def trigger_game_over(self):
        """Enter the game-over state once: effects, sounds, agent penalty, high score."""
        if self.game_over:
            return
        self.game_over = True
        self.bird_falling = True
        self.flash_screen = 5
        self.screen_shake.shake(15)
        self.particles.spawn_explosion(self.bird.x + 17, self.bird.y + 12, (255, 100, 50), 30)
        self.play_sound("hit")
        pygame.time.delay(50)  # short gap between the two sounds
        self.play_sound("die")
        if self.ai_mode:
            self.agent.give_reward(-1.0)
            self.agent.learn_step()
            self.agent.on_death(self.score)
        if self.score > self.high_score:
            self.high_score = self.score

    def _draw_game_over(self, canvas, sx, sy):
        """Draw the darkened overlay with the final score and best score."""
        cx = self.screen_width // 2
        cy = self.screen_height // 2
        dark = pygame.Surface((self.screen_width, self.screen_height))
        dark.fill((0, 0, 0))
        dark.set_alpha(150)
        canvas.blit(dark, (0, 0))
        if self.gameover_img is not None:
            bounce = math.sin(self.game_over_timer / 10) * 3
            go_r = self.gameover_img.get_rect(center=(cx + sx, cy - 40 + sy + bounce))
            canvas.blit(self.gameover_img, go_r)
        else:
            t = self.font.render("GAME OVER", True, (255, 0, 0))
            canvas.blit(t, t.get_rect(center=(cx, cy - 40)))
        st = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        canvas.blit(st, st.get_rect(center=(cx, cy + 10)))
        if self.score >= self.high_score and self.score > 0:
            ht = self.small_font.render("NEW HIGH SCORE!", True, (255, 215, 0))
        else:
            ht = self.small_font.render(f"Best: {self.high_score}", True, (200, 200, 200))
        canvas.blit(ht, ht.get_rect(center=(cx, cy + 40)))
        if not self.ai_mode:
            rt = self.small_font.render("Click to restart", True, (255, 255, 255))
            canvas.blit(rt, rt.get_rect(center=(cx, cy + 70)))

    def draw(self, actual_screen, vis_x_offset=0):
        """Render the frame onto `actual_screen` and flip the display.

        Args:
            actual_screen: the display surface.
            vis_x_offset: x position of the playfield; the SNN panel (AI mode
                only) is placed directly to its right.
        """
        sx, sy = self.screen_shake.get_offset()
        canvas = self.game_surface
        canvas.fill((0, 0, 0))
        canvas.blit(self.bg, (0, 0))
        self._draw_clouds(canvas, sx, sy)
        for pipe in self.pipes:
            pipe.draw(canvas)
        self.particles.draw(canvas, sx, sy)
        if self.base is not None:
            base_y = self.screen_height - self.base_height + sy
            # Two copies side by side so the scrolling ground has no hole.
            canvas.blit(self.base, (self.base_x + sx, base_y))
            canvas.blit(self.base, (self.base_x + self.base_width + sx, base_y))
        self.bird.draw(canvas)
        if self.score > 0 or self.started:
            self.draw_score(canvas, sx, sy)
        for popup in self.score_popups:
            popup.draw(canvas, self.big_font)
        if self.flash_screen > 0:
            fs = pygame.Surface((self.screen_width, self.screen_height))
            fs.fill((255, 255, 255))
            fs.set_alpha(int(200 * (self.flash_screen / 5)))
            canvas.blit(fs, (0, 0))
        if not self.started and self.message_img is not None:
            mr = self.message_img.get_rect(center=(self.screen_width // 2, self.screen_height // 2))
            canvas.blit(self.message_img, mr)
        if self.game_over:
            self._draw_game_over(canvas, sx, sy)

        # ── AI / speed mode badge ──
        mode_col = (80, 255, 120) if self.ai_mode else (255, 150, 80)
        mode_txt = "AI [RSTDP]" if self.ai_mode else "HUMAN"
        canvas.blit(self.ui_font.render(mode_txt, True, mode_col), (4, 4))
        if self.fast_mode:
            ft = self.ui_font.render(f"×{self.fast_ticks}", True, (255, 220, 60))
            canvas.blit(ft, (4, 20))
        actual_screen.blit(canvas, (vis_x_offset, 0))

        # ── SNN visualiser ──
        if self.ai_mode:
            vs = self.visualiser.draw(self.agent)
            actual_screen.blit(vs, (vis_x_offset + self.screen_width, 0))
        pygame.display.flip()

    def reset(self):
        """Start a fresh round (score, bird, pipes and effects); the agent is kept."""
        self.bird = Bird(self.screen_width, self.screen_height)
        self.pipes = []
        self.score = 0
        self.game_over = False
        self.spawn_timer = 0
        self.started = False
        self.particles.clear()
        self.screen_shake = ScreenShake()
        self.score_popups = []
        self.game_over_timer = 0
        self.bird_falling = False
        self.flash_screen = 0
        self._survive_timer = 0

    def play_sound(self, name):
        """Play the named sound if it was loaded; playback errors are ignored."""
        sound = self.sounds.get(name)
        if not sound:
            return
        try:
            sound.play()
        except pygame.error:
            pass  # audio is optional

    def resize(self, nw, nh):
        """Adapt surfaces, bird, pipes and the visualiser to a new playfield size."""
        self.screen_width = nw
        self.screen_height = nh
        self.game_surface = pygame.Surface((nw, nh))
        if self.bg_day is not None:
            self.bg = pygame.transform.scale(self.bg_day, (nw, nh))
        else:
            self.bg = self._solid_bg((135, 206, 235))
        self.bird.screen_width = nw
        self.bird.screen_height = nh
        for p in self.pipes:
            p.screen_height = nh
            p.base_height = self.base_height
            p.bottom_rect.height = nh - p.height - p.GAP - self.base_height
        self.visualiser = SNNVisualiser(nh)
# ─────────────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────────────
def _handle_flap_input(game):
    """Apply one human 'flap' input (space bar or mouse click) to `game`."""
    if game.game_over and not game.bird_falling:
        game.reset()
    elif not game.started:
        game.started = True
        game.bird.jump()
        game.play_sound("wing")
    elif not game.bird_falling:
        game.bird.jump()
        game.play_sound("wing")


def main():
    """Run the game loop until the window is closed or ESC is pressed.

    Keys: [A] toggle AI mode, [S] toggle fast mode (AI only), [R] reset the
    SNN weights, [SPACE] / mouse click to flap in human mode.
    """
    cache_assets()
    try:
        pygame.mixer.init()
    except pygame.error as exc:
        # No audio device: keep going, sounds simply fail to load.
        print(f"Audio disabled: {exc}")
    pygame.init()

    GAME_W, GAME_H = 288, 512
    VIS_W = SNNVisualiser.W
    # Start with vis panel visible (AI mode on)
    screen = pygame.display.set_mode((GAME_W + VIS_W, GAME_H), pygame.RESIZABLE)
    pygame.display.set_caption("Flappy Bird – SNN RSTDP Agent")
    clock = pygame.time.Clock()
    game = Game(GAME_W, GAME_H)

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            if event.type == pygame.VIDEORESIZE:
                screen = pygame.display.set_mode((event.w, event.h), pygame.RESIZABLE)
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_a:
                    # Toggle AI mode; the side panel is shown only in AI mode.
                    game.ai_mode = not game.ai_mode
                    new_w = GAME_W + (VIS_W if game.ai_mode else 0)
                    screen = pygame.display.set_mode((new_w, GAME_H), pygame.RESIZABLE)
                    game.reset()
                elif event.key == pygame.K_s:
                    # Toggle fast mode
                    game.fast_mode = not game.fast_mode
                elif event.key == pygame.K_r:
                    # Reset SNN weights (neuron state and logs are kept)
                    game.agent.layer1 = SNNLayer(SNNAgent.N_IN, SNNAgent.N_HID, tau_e=80.0)
                    game.agent.layer2 = SNNLayer(SNNAgent.N_HID, SNNAgent.N_OUT, tau_e=80.0)
                    game.agent.dopamine = 0.0
                    print("SNN weights reset.")
                elif event.key == pygame.K_SPACE and not game.ai_mode:
                    _handle_flap_input(game)
            if event.type == pygame.MOUSEBUTTONDOWN and not game.ai_mode:
                _handle_flap_input(game)

        # Fast mode: run multiple update ticks per frame
        ticks = game.fast_ticks if game.fast_mode and game.ai_mode else 1
        for _ in range(ticks):
            game.update()
            if game.game_over and not game.bird_falling:
                break  # let it restart
        game.draw(screen, vis_x_offset=0)
        clock.tick(60)

    pygame.quit()
    sys.exit()


if __name__ == "__main__":
    main()
# End of gist (GitHub page footer text removed).