@thomasahle
Created July 7, 2024 17:30
Randomly Initialized MLPs with Different Activation Functions
import torch
import math
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np
import seaborn as sns
device = "cpu"
width = 1000
depth = 100
iterations = 3
title = "Randomly Initialized MLPs with Different Activation Functions"
def gelu2(x):
    # GeLU shifted to have mean zero for x ~ N(0, 1): x * Phi(x) - phi(x).
    Phi = 0.5 * (1 + torch.erf(x / math.sqrt(2)))
    phi = torch.exp(-0.5 * x**2) / math.sqrt(2 * torch.pi)
    return x * Phi - phi
activation_functions = {
"Identity": lambda x: x,
"ReLU": lambda x: F.relu(x) * nn.init.calculate_gain("relu"),
"Tanh": lambda x: torch.tanh(x) * nn.init.calculate_gain("tanh"),
"GeLU": lambda x: F.gelu(x) * 1.70093,
"SiLU/Swish": lambda x: F.silu(x) * 1.78719,
"Mean 0 GeLU: x Φ(x) − ϕ(x)": lambda x: gelu2(x) * 1.533530,
# "Sine": lambda x: torch.sin(x) * 1.19267,
"SELU": lambda x: F.selu(x) * nn.init.calculate_gain("selu"),
"ELU": lambda x: F.elu(x) * 1.269457,
"Leaky ReLU": lambda x: F.leaky_relu(x) * nn.init.calculate_gain("leaky_relu"),
}
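# Note on the hand-tuned gains above (1.70093, 1.78719, 1.533530, 1.269457):
# they appear consistent with scaling each activation to unit variance under a
# standard normal input, i.e. gain ≈ 1 / std(act(Z)) with Z ~ N(0, 1). That is
# an assumption about how they were obtained; the sketch below (the helper name
# is ours, not part of the original gist) estimates such a gain by Monte Carlo.
def estimate_gain(act, n_samples=1_000_000):
    z = torch.randn(n_samples)
    return act(z).var().rsqrt().item()  # 1 / std(act(Z))
# e.g. estimate_gain(F.gelu) comes out near 1.70 and estimate_gain(F.silu) near 1.79.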
class MLP(nn.Module):
    def __init__(self, width, depth, act_func):
        super().__init__()
        self.layer0 = nn.Linear(5, width, bias=False)
        self.layers = nn.ModuleList(
            [nn.Linear(width, width, bias=False) for _ in range(depth)]
        )
        self.act_func = act_func
        # Scale hidden weights to variance 1/width so each matmul preserves
        # the expected squared norm of its input.
        for p in self.layers.parameters():
            p.data[:] = torch.randn_like(p.data) / (width**0.5)
    def forward(self, x):
        norms = []
        angles = []
        x = self.layer0(x)
        # Track the mean norm of the two samples and the cosine similarity
        # between them after every layer.
        norms.append(x.norm(dim=1).mean().item())
        x_normalized = x / x.norm(dim=1, keepdim=True)
        angles.append((x_normalized[0] @ x_normalized[1]).item())
        for layer in self.layers:
            x = self.act_func(x)
            x = layer(x)
            norms.append(x.norm(dim=1).mean().item())
            x_normalized = x / x.norm(dim=1, keepdim=True)
            angles.append((x_normalized[0] @ x_normalized[1]).item())
        return x, norms, angles
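# Quick sanity check (hypothetical small configuration, not part of the original
# gist): a 2-sample batch of 5-dim inputs should give norm/angle traces of
# length depth + 1 and an output of shape (2, width).
_tiny = MLP(width=8, depth=4, act_func=torch.tanh)
_out, _norms, _angles = _tiny(torch.randn(2, 5))
assert _out.shape == (2, 8) and len(_norms) == len(_angles) == 5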
sns.set(style="whitegrid")
fig, axes = plt.subplots(3, 3, figsize=(16, 10))
fig.suptitle(title)
axes = axes.flatten()
def add_variation_to_color(color, variation=0.1):
"""Adds a little randomness to a given base color."""
color = np.array(mcolors.to_rgb(color))
noise = np.random.uniform(-variation, variation, color.shape)
new_color = np.clip(color + noise, 0, 1)
return new_color
oranges = [add_variation_to_color("orange") for _ in range(iterations)]
blues = [add_variation_to_color("tab:blue") for _ in range(iterations)]
data = torch.randn(iterations, 2, 5, device=device)
data /= data.norm(dim=2, keepdim=True)
net = MLP(width=width, depth=depth, act_func=None).to(device)
for i, (act_name, act_func) in enumerate(activation_functions.items()):
    print(i, act_name)
    net.act_func = act_func
    all_norms = []
    all_angles = []
    for j in range(iterations):
        with torch.no_grad():
            outputs, norms, angles = net(data[j])
        all_norms.append(norms)
        all_angles.append(angles)
    ax = axes[i]
    # Plotting the angles
    ax.set_title(f"{act_name}")
    ax.set_xlabel("Layer")
    ax.set_ylabel("Angle", color="orange")
    for j, angles in enumerate(all_angles):
        ax.plot(
            angles,
            color=oranges[j],
            label=f"Angle Iteration {j + 1}",
        )
    # Plotting the norms
    ax2 = ax.twinx()
    ax2.set_ylabel("Norm", color="tab:blue")
    for j, norms in enumerate(all_norms):
        ax2.plot(norms, color=blues[j], label=f"Norm Iteration {j + 1}")
    # Setting the grid and hiding extra spines
    ax.grid(True, axis="y")
    ax.xaxis.grid(False)
    ax2.grid(False)
    for spine_ax in [ax, ax2]:
        spine_ax.spines["right"].set_visible(False)
        spine_ax.spines["left"].set_visible(False)
        spine_ax.spines["top"].set_visible(False)
plt.tight_layout()
plt.show()