Skip to content

Instantly share code, notes, and snippets.

@ivanstepanovftw
Last active March 24, 2024 15:09
Show Gist options
  • Save ivanstepanovftw/bb883c63f26d3665fcfb0a9d6d52a310 to your computer and use it in GitHub Desktop.
Save ivanstepanovftw/bb883c63f26d3665fcfb0a9d6d52a310 to your computer and use it in GitHub Desktop.
Conway's Game of Life computed using only convolutions with hand-crafted weights: a PyTorch implementation of the paper "It's Hard for Neural Networks To Learn the Game of Life" (https://arxiv.org/abs/2009.01398).
import logging
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
from plotly.subplots import make_subplots
# Root logging configuration: INFO and above, each line prefixed with a
# wall-clock timestamp.
logging.basicConfig(
    format="[%(asctime)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# This module's own logger is set to DEBUG so its debug records are not
# filtered out by the root logger's INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class GameOfLifeNN(nn.Module):
    """Hand-weighted convolutional network that computes one Game of Life step.

    Follows the minimal architecture from "It's Hard for Neural Networks To
    Learn the Game of Life" (arXiv:2009.01398): two 3x3 filters, a 1x1
    combining filter, and a sharpened sigmoid read-out. No training is
    involved; every weight is set by hand.

    For a cell with centre value ``c`` and live-neighbour count ``n`` (binary
    input) the layers compute:

    * ``a = relu(n + 0.1*c - 3)`` -- positive when ``n + 0.1*c > 3``;
    * ``b = relu(n + c - 2)``     -- positive when the 3x3 sum exceeds 2;
    * ``y = relu(-10*a + b)``     -- about 1 for cells alive in the next
      generation (``n == 3``, or ``n == 2`` with a live centre), 0 otherwise,
      because ``-10*a`` vetoes ``b`` in overcrowded neighbourhoods;
    * ``out = sigmoid(2*s*y - s)`` -- maps ``y`` in {0, 1} to roughly
      ``sigmoid(-s)`` / ``sigmoid(s)``.

    Args:
        s: Sharpness of the output sigmoid; larger values push outputs closer
            to exact 0/1.
        padding_mode: Boundary handling of the 3x3 convolution, forwarded to
            ``torch.nn.Conv2d``. ``"zeros"`` (default) treats cells beyond the
            edge as dead; ``"circular"`` wraps the board into a torus.
    """

    def __init__(self, s=20, padding_mode="zeros"):
        super().__init__()
        # First layer: two 3x3 filters over the single input channel.
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=True,
                               padding_mode=padding_mode)
        self.conv1.weight = nn.Parameter(torch.tensor(
            [[[[1., 1., 1.], [1., 0.1, 1.], [1., 1., 1.]]],
             [[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]]))
        self.conv1.bias = nn.Parameter(torch.tensor([-3., -2.]))
        # Second layer: a 1x1 filter combining the two maps as -10*a + b.
        self.conv2 = nn.Conv2d(2, 1, kernel_size=1, bias=False)
        self.conv2.weight = nn.Parameter(torch.tensor([[[[-10.]], [[1.]]]]))
        # Read-out affine map applied before the sigmoid. Registered as
        # non-persistent buffers so .to()/.cuda() move them together with the
        # convolutions, while state_dict() keeps exactly its previous keys.
        self.register_buffer(
            "final_weight", torch.tensor([[2 * s]], dtype=torch.float32),
            persistent=False)
        self.register_buffer(
            "final_bias", torch.tensor(-s, dtype=torch.float32),
            persistent=False)

    def forward(self, x):
        """Return the next-generation board for ``x``.

        Args:
            x: Single-channel board tensor, e.g. ``(N, 1, H, W)``. The weights
                are designed for (near-)binary 0/1 cell values -- behaviour on
                other inputs is not what the hand-set weights target.

        Returns:
            Tensor with the same spatial size (3x3 kernel, padding 1) and
            values in (0, 1): next-generation live cells are close to
            ``sigmoid(s)``, dead cells close to ``sigmoid(-s)``.
        """
        # First layer with ReLU activation.
        x = F.relu(self.conv1(x))
        # Second layer with ReLU activation.
        x = F.relu(self.conv2(x))
        # Sharpened sigmoid read-out.
        return torch.sigmoid(x * self.final_weight + self.final_bias)
# Module-level network used by update_state(); s=20 is the constructor's
# default sigmoid sharpness, spelled out here for readability.
model = GameOfLifeNN(s=20)
def add_pattern(state, pattern, top_left=(0, 0)):
    """Stamp ``pattern`` into the first board of ``state``, in place.

    Args:
        state: Board tensor written as ``state[0, 0, row, col]`` (e.g. the
            ``(1, 1, H, W)`` board built by ``initialize_complex_state``);
            only the ``[0, 0]`` board is modified.
        pattern: 2-D tensor, NumPy array, or nested list of cell values.
        top_left: ``(row, col)`` position of the pattern's top-left cell.

    Note:
        No bounds checking is performed; a pattern that runs off the board
        generally makes the slice assignment fail with a RuntimeError.
    """
    row, col = top_left
    # as_tensor avoids torch.tensor's copy-construct warning for tensor
    # inputs, accepts plain lists, and matches the board's dtype and device.
    cells = torch.as_tensor(pattern, dtype=state.dtype, device=state.device)
    height, width = cells.shape
    state[0, 0, row:row + height, col:col + width] = cells
def initialize_complex_state(size=100):
    """Build a ``(1, 1, size, size)`` board seeded with gliders and LWSSs.

    Placements are fixed: the furthest pattern (the glider at (75, 75))
    occupies rows/cols 75-77, so ``size`` must be at least 78 for every
    pattern to fit on the board.

    Args:
        size: Side length of the square board.

    Returns:
        Float tensor of shape ``(1, 1, size, size)`` with 1.0 for live cells.
    """
    state = torch.zeros((1, 1, size, size))
    # Glider.
    glider = torch.tensor([[0., 1., 0.],
                           [0., 0., 1.],
                           [1., 1., 1.]])
    # Lightweight spaceship (LWSS), the standard 9-cell form
    # (RLE: bo2bo$o4b$o3bo$4o!). The previous matrix had 10 live cells and
    # was not a spaceship at all.
    lwss = torch.tensor([[0., 1., 0., 0., 1.],
                         [1., 0., 0., 0., 0.],
                         [1., 0., 0., 0., 1.],
                         [1., 1., 1., 1., 0.]])
    # (pattern, top-left (row, col)) in the original stamping order.
    placements = [
        (glider, (1, 1)),
        (glider, (10, 15)),
        (glider, (20, 25)),
        (glider, (30, 35)),
        (lwss, (40, 10)),
        (glider, (50, 55)),
        (glider, (60, 65)),
        (lwss, (70, 20)),
        (lwss, (15, 45)),
        (glider, (75, 75)),
    ]
    for pattern, top_left in placements:
        add_pattern(state, pattern, top_left)
    return state
@torch.no_grad()
def update_state(state):
    """Advance ``state`` by one generation with the module-level ``model``.

    Autograd is disabled for the call, so the returned tensor carries no
    graph history.
    """
    return model(state)
def animate_game_of_life(initial_state, num_steps=20):
    """Simulate ``num_steps`` generations and show them as a Plotly animation.

    Args:
        initial_state: Board tensor accepted by ``update_state``. Its last two
            dimensions are the board's rows and columns, and it should hold a
            single board (e.g. shape ``(1, 1, H, W)``).
        num_steps: Number of generations computed after the initial one; the
            animation has ``num_steps + 1`` frames.

    Side effects:
        Displays the figure via ``fig.show()``.
    """
    def to_frame(board):
        # detach()/cpu() so grad-tracking or CUDA tensors also convert to
        # NumPy; reshape to the trailing (rows, cols) instead of squeeze() so
        # a board with a single row or column stays 2-D for the heatmap.
        return board.detach().cpu().reshape(board.shape[-2:]).numpy()

    frames = [to_frame(initial_state)]
    state = initial_state
    for _ in range(num_steps):
        state = update_state(state)
        frames.append(to_frame(state))

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Heatmap(z=frames[0], colorscale='Greys', showscale=False))
    # One animation frame per generation (including the initial board).
    fig.frames = [go.Frame(data=[go.Heatmap(z=frame, colorscale='Greys', showscale=False)])
                  for frame in frames]
    # Play / Pause controls.
    fig.update_layout(updatemenus=[dict(type="buttons", showactive=False,
                                        y=1, x=0.5, xanchor="center", yanchor="top",
                                        buttons=[dict(label="Play",
                                                      method="animate",
                                                      args=[None, {"frame": {"duration": 200, "redraw": True},
                                                                   "fromcurrent": True}]),
                                                 dict(label="Pause",
                                                      method="animate",
                                                      args=[[None], {"frame": {"duration": 0, "redraw": True},
                                                                     "mode": "immediate"}])])])
    fig.update_layout(width=600, height=600)
    # Heatmaps draw row 0 at the bottom by default; reverse the y-axis so the
    # picture matches the tensor's row order (row 0 at the top), which is the
    # orientation the seeded patterns were written in.
    fig.update_yaxes(autorange="reversed")
    fig.show()
if __name__ == "__main__":
    # Seed a 100x100 board and animate 200 generations. Guarded so importing
    # this module (e.g. to reuse GameOfLifeNN) does not run the simulation and
    # open a figure as a side effect.
    initial_state = initialize_complex_state(size=100)
    animate_game_of_life(initial_state, num_steps=200)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment