Last active
September 13, 2023 13:36
-
-
Save Logrus/e5cd1b6f70f3898f56ecdf54fcdfcfa2 to your computer and use it in GitHub Desktop.
Memorize an image with a neural network; made to play around with sin activations, positional encodings, etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implicit Neural Representations have recently gained popularity.
One of the issues, though, is that the learned representations
are biased towards low frequencies, resulting in over-smoothed representations.
A few approaches have been suggested to alleviate the issue,
for example using positional encodings.
An alternative could be using Sin/Cos activation functions,
which in essence provide learnable basis functions.
A commonly used example to get a feel for the problem is
an image-memorization problem, where an MLP has to map from (u, v)
pixel coordinates to an (r, g, b) color.
Using ReLU in this example results in an overly smoothed representation,
whereas using Sin/Cos activations helps to represent higher frequencies better.
Another parameter that can be changed is normalizing the (u, v) coordinates
between 0 and 1, which helps with ReLU activations; however, unnormalized coordinates
work better with Sin/Cos activations (since a periodic function wraps them back,
this doesn't destroy training), and the network then converges almost instantly.
In addition, this code contains a demonstration of what the MLP produces when asked about (u, v)
beyond the image bounds.
Video available on YouTube:
https://youtu.be/AYFoXcl6zyU
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import torch | |
# Switch Matplotlib to interactive mode so the figure can be redrawn
# while the script keeps running.
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# ===================================================================== | |
# A sin activation | |
class Sin(torch.nn.Module):
    """Element-wise sine activation.

    Args:
        inplace: When True, ``forward`` overwrites its input tensor with the
            result instead of allocating a new one. Defaults to False.
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``sin(input)``, written into ``input`` when ``inplace`` is set."""
        if self.inplace:
            # Honour the flag: previously it was stored and shown in the repr
            # but the computation was always out-of-place.
            return input.sin_()
        return torch.sin(input)

    def extra_repr(self) -> str:
        """Show ``inplace=True`` in the module repr only when it is enabled."""
        return "inplace=True" if self.inplace else ""
class Cos(torch.nn.Module):
    """Element-wise cosine activation.

    Args:
        inplace: When True, ``forward`` overwrites its input tensor with the
            result instead of allocating a new one. Defaults to False.
    """

    __constants__ = ["inplace"]
    inplace: bool

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``cos(input)``, written into ``input`` when ``inplace`` is set."""
        if self.inplace:
            # Honour the flag: previously it was stored and shown in the repr
            # but the computation was always out-of-place.
            return input.cos_()
        return torch.cos(input)

    def extra_repr(self) -> str:
        """Show ``inplace=True`` in the module repr only when it is enabled."""
        return "inplace=True" if self.inplace else ""
# ===================================================================== | |
# ===================================================================== | |
# Network | |
class StupidNet(torch.nn.Module):
    """Fully connected network taking 2 input features and producing 3 outputs.

    The first hidden layer always uses ``Sin``; the remaining hidden layers
    share a single activation module (ReLU as configured below).
    """

    def __init__(self) -> None:
        super().__init__()
        # Hidden activation; alternatives to experiment with:
        #   Sin(), Cos(), torch.nn.LogSigmoid()
        hidden_act = torch.nn.ReLU()

        # A relatively high amount of neurons is needed
        # for learning a proper representation
        stack = [torch.nn.Linear(2, 512), Sin()]
        for n_in, n_out in ((512, 512), (512, 512), (512, 32), (32, 32)):
            stack.append(torch.nn.Linear(n_in, n_out))
            stack.append(hidden_act)
        stack.append(torch.nn.Linear(32, 3))
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, coord):
        """Pass ``coord`` through the layer stack and return its output."""
        return self.layers(coord)
# ===================================================================== | |
# Load and normalize an image | |
# Open inside a context manager so the file handle is released, and force
# RGB so the array always has exactly three channels (a grayscale or RGBA
# file would otherwise break the shape unpacking / 3-channel targets below).
with Image.open("cat_small.jpeg") as image:
    image_array = np.array(image.convert("RGB"), dtype=np.float32)
# Image is normalized in [-0.5, 0.5]
image_normalized = (image_array / 255.0) - 0.5
H, W, _ = image_normalized.shape
print(f"Image size, height: {H}, width {W}")
# ===================================================================== | |
# Create training data | |
# Build the (row, column) -> colour training pairs in row-major order
# (sample index = i * W + j) without a Python-level loop over every pixel.
# Normalized coordinates, work better with ReLU
row_coords = np.arange(H) / H
col_coords = np.arange(W) / W
# Unnormalized coordinates, work better with Sin
# row_coords, col_coords = np.arange(H), np.arange(W)
row_grid, col_grid = np.meshgrid(row_coords, col_coords, indexing="ij")
X = np.stack((row_grid, col_grid), axis=-1).reshape(H * W, 2).astype(np.float32)
# One colour triple per pixel, flattened in the same row-major order as X.
Y = image_normalized.reshape(H * W, 3).astype(np.float32)
# ===================================================================== | |
# Query MLP beyond learned data with some padding around | |
padding = 100
# Coordinates for every pixel of the image plus `padding` extra pixels on
# each side, flattened row-major with row width (W + 2 * padding).
# Normalized coordinates
padded_rows = np.arange(-padding, H + padding) / H
padded_cols = np.arange(-padding, W + padding) / W
# Unnormalized coordinates
# padded_rows = np.arange(-padding, H + padding)
# padded_cols = np.arange(-padding, W + padding)
padded_row_grid, padded_col_grid = np.meshgrid(
    padded_rows, padded_cols, indexing="ij"
)
X_with_padding = (
    np.stack((padded_row_grid, padded_col_grid), axis=-1)
    .reshape((H + padding * 2) * (W + padding * 2), 2)
    .astype(np.float32)
)
X_tensor_padding = torch.tensor(X_with_padding).to(device)
# The dataset is small so no batching is needed
# everything can be loaded in GPU memory
X_tensor = torch.tensor(X).to(device)
Y_tensor = torch.tensor(Y).to(device)
# ===================================================================== | |
# Show original image | |
# Reassemble the flat targets into an (H, W, 3) picture and add 0.5 back
# before displaying; keep it on screen for one second.
ax.imshow(0.5 + Y.reshape(H, W, 3))
ax.set_title("Original image")
plt.pause(1.0)
def train_loop(X, y, model, loss_fn, optimizer):
    """Run a single optimisation step on the whole batch and report the loss.

    Args:
        X: Model inputs, passed straight to ``model``.
        y: Targets, passed as the second argument of ``loss_fn``.
        model: Callable producing predictions from ``X``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.

    Returns:
        The loss (as a Python float) computed before the parameter update,
        so callers can log or track it. It is also printed.
    """
    # Compute prediction and loss
    pred = model(X)
    loss = loss_fn(pred, y)
    # Backpropagation: clear stale gradients, compute new ones, update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    print(f"loss: {loss_value:>7f}")
    return loss_value
def extract_image(model, X_in, out_shape=None):
    """Evaluate ``model`` on ``X_in`` and return the result as an image array.

    Args:
        model: Callable module producing one 3-value row per input row.
        X_in: Input tensor passed to ``model``.
        out_shape: Shape the flat prediction is reshaped to. When omitted,
            the module-level ``H``, ``W`` and ``padding`` are used, i.e.
            ``(H + 2 * padding, W + 2 * padding, 3)``.

    Returns:
        A NumPy array of shape ``out_shape`` holding the prediction plus 0.5.
        Values are not clipped, so they may fall outside [0, 1].
    """
    if out_shape is None:
        out_shape = (H + padding * 2, W + padding * 2, 3)
    # Inference only: skip autograd bookkeeping for this (large) batch and go
    # through __call__ so that any registered hooks still run.
    with torch.no_grad():
        pred = model(X_in)
    np_arr = pred.cpu().numpy()
    lowest, highest = np.min(np_arr), np.max(np_arr)
    print("Image min max ", lowest, highest)
    # Predicted image overflows the allowed range [0, 1]
    # so re-normalization is possible, although not required,
    # the shown image is still ok
    # image = (np_arr.reshape(out_shape) - lowest) / (highest - lowest)
    return np_arr.reshape(out_shape) + 0.5
# ===================================================================== | |
# Initialize model and optimizer | |
model = StupidNet().to(device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 50000
for e in range(epochs):
    train_loop(X_tensor, Y_tensor, model, loss_fn, optimizer)
    # Refresh the preview every 100 epochs.
    if (e % 100) == 0:
        image_learned = extract_image(model, X_tensor_padding)
        # Remove what was drawn before; otherwise every preview adds another
        # image artist to the axes and they pile up over the whole run.
        ax.clear()
        ax.imshow(image_learned)
        ax.set_title(f"Epoch {e}")
        plt.pause(0.001)
        # plt.savefig("training_beoyond_edges/image_{:06}".format(e))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Test image: