Last active
March 24, 2024 15:09
-
-
Save ivanstepanovftw/bb883c63f26d3665fcfb0a9d6d52a310 to your computer and use it in GitHub Desktop.
Conway's Game of Life using only convolutions with hand-made weights. Implementation in PyTorch of paper: "It's Hard for Neural Networks To Learn the Game of Life" https://arxiv.org/abs/2009.01398
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import plotly.graph_objects as go | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from plotly.subplots import make_subplots | |
# Install the process-wide log format first, then fetch this module's logger
# and set its own threshold to DEBUG.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class GameOfLifeNN(nn.Module):
    """One Game of Life generation computed by a small hand-weighted CNN.

    The network takes a board tensor of shape ``(N, 1, H, W)`` whose cells are
    (close to) 0 for dead and 1 for alive, and returns a tensor of the same
    shape holding the next generation. ``conv1`` uses zero padding, so cells
    outside the board count as dead.

    Args:
        s: Steepness of the final sigmoid. The last stage computes
            ``sigmoid(2*s*x - s)``, which maps ``x = 0`` to ``sigmoid(-s)``
            and ``x = 1`` to ``sigmoid(s)``.
    """

    def __init__(self, s=20):
        super().__init__()

        # First layer: two 3x3 filters over the neighbourhood of each cell.
        #   channel 0: relu(neighbours + 0.1 * centre - 3)  -> > 0 when overcrowded
        #   channel 1: relu(neighbours + centre - 2)        -> > 0 when enough cells are alive
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3, padding=1, bias=True)
        self.conv1.weight = nn.Parameter(torch.tensor(
            [[[[1., 1., 1.], [1., 0.1, 1.], [1., 1., 1.]]],
             [[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]]))
        self.conv1.bias = nn.Parameter(torch.tensor([-3., -2.]))

        # Second layer: 1x1 mix of the two channels; the -10 weight lets the
        # "overcrowded" channel veto the "enough cells" channel.
        self.conv2 = nn.Conv2d(2, 1, kernel_size=1, bias=False)
        self.conv2.weight = nn.Parameter(torch.tensor([[[[-10.]], [[1.]]]]))

        # Third layer: scale and shift applied before the sigmoid. Registered
        # as buffers so that .to()/.cuda()/.double() move them together with
        # the convolution weights; non-persistent so the state_dict keeps
        # exactly the keys it had when these were plain attributes.
        self.register_buffer(
            "final_weight", torch.tensor([[2 * s]], dtype=torch.float32), persistent=False)
        self.register_buffer(
            "final_bias", torch.tensor(-s, dtype=torch.float32), persistent=False)

    def forward(self, x):
        """Return the next generation for the board(s) in ``x`` (``(N, 1, H, W)``)."""
        # First layer with ReLU activation
        x = F.relu(self.conv1(x))
        # Second layer with ReLU activation
        x = F.relu(self.conv2(x))
        # Third layer with sigmoid activation: squashes the result towards 0 or 1
        x = torch.sigmoid(x * self.final_weight + self.final_bias)
        return x
# Shared network instance (default sigmoid steepness); update_state() reads
# this module-level name to advance the board.
model = GameOfLifeNN()
def add_pattern(state, pattern, top_left=(0, 0)):
    """Stamp a 2-D ``pattern`` onto the board ``state`` in place.

    Args:
        state: Board tensor indexed as ``state[0, 0, row, col]``; only the
            first batch entry and first channel are written.
        pattern: 2-D grid of cell values (tensor, ndarray or nested list).
        top_left: ``(row, col)`` of the board cell that receives
            ``pattern[0][0]``.

    Returns:
        None. ``state`` is modified in place; the covered rectangle is
        overwritten, including with the pattern's zeros.
    """
    row, col = top_left
    # as_tensor reuses an existing float32 tensor instead of copy-constructing
    # it (torch.tensor(tensor) emits a UserWarning on every call) and still
    # converts lists / ndarrays.
    pattern = torch.as_tensor(pattern, dtype=torch.float32)
    height, width = pattern.shape[:2]
    state[0, 0, row:row + height, col:col + width] = pattern
def initialize_complex_state(size=100):
    """Build a ``(1, 1, size, size)`` board seeded with gliders and spaceships.

    Args:
        size: Side length of the square board. The fixed placements below
            reach row/column index 77, so ``size`` should be at least 78.

    Returns:
        A float tensor of zeros with the patterns stamped in via
        ``add_pattern``.
    """
    state = torch.zeros((1, 1, size, size))

    # Glider
    glider = torch.tensor([[0., 1., 0.],
                           [0., 0., 1.],
                           [1., 1., 1.]])

    # Lightweight spaceship (LWSS), canonical 9-cell phase. The previous
    # 10-cell matrix was not a phase of the LWSS and therefore did not
    # travel as a spaceship.
    lwss = torch.tensor([[0., 1., 0., 0., 1.],
                         [1., 0., 0., 0., 0.],
                         [1., 0., 0., 0., 1.],
                         [1., 1., 1., 1., 0.]])

    # (pattern, (row, col) of its top-left corner), stamped in this order.
    placements = (
        (glider, (1, 1)),
        (glider, (10, 15)),
        (glider, (20, 25)),
        (glider, (30, 35)),
        (lwss, (40, 10)),
        (glider, (50, 55)),
        (glider, (60, 65)),
        (lwss, (70, 20)),
        (lwss, (15, 45)),
        (glider, (75, 75)),
    )
    for pattern, top_left in placements:
        add_pattern(state, pattern, top_left)
    return state
def update_state(state):
    """Advance ``state`` by one generation using the module-level ``model``.

    The forward pass runs inside ``torch.no_grad()``, so no autograd graph is
    recorded for the returned tensor.
    """
    with torch.no_grad():
        next_state = model(state)
    return next_state
def animate_game_of_life(initial_state, num_steps=20):
    """Simulate ``num_steps`` generations and show them as a plotly animation.

    Args:
        initial_state: Starting board tensor; each generation is squeezed and
            converted to a NumPy array for plotting.
        num_steps: Number of generations computed with ``update_state``; the
            animation holds ``num_steps + 1`` frames including the start.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    # Collect one 2-D snapshot per generation, starting with the input board.
    current = initial_state
    snapshots = [current.squeeze().numpy()]
    for _ in range(num_steps):
        current = update_state(current)
        snapshots.append(current.squeeze().numpy())

    # Play advances through the frames at 200 ms per frame from the current one.
    play_button = dict(
        label="Play",
        method="animate",
        args=[None, {"frame": {"duration": 200, "redraw": True},
                     "fromcurrent": True}],
    )
    # Pause animates to "[None]" with zero duration in immediate mode.
    pause_button = dict(
        label="Pause",
        method="animate",
        args=[[None], {"frame": {"duration": 0, "redraw": True},
                       "mode": "immediate"}],
    )
    controls = dict(
        type="buttons",
        showactive=False,
        y=1,
        x=0.5,
        xanchor="center",
        yanchor="top",
        buttons=[play_button, pause_button],
    )

    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Heatmap(z=snapshots[0], colorscale='Greys', showscale=False))
    # One animation frame per snapshot.
    fig.frames = [
        go.Frame(data=[go.Heatmap(z=grid, colorscale='Greys')])
        for grid in snapshots
    ]
    fig.update_layout(updatemenus=[controls])
    fig.update_layout(width=600, height=600)
    fig.show()
# Script entry: seed a 100x100 board with the demo patterns, then simulate and
# display 200 generations.
initial_state = initialize_complex_state(size=100)
animate_game_of_life(initial_state=initial_state, num_steps=200)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment