Skip to content

Instantly share code, notes, and snippets.

@ekatrukha
Created May 7, 2025 14:12
Show Gist options
  • Save ekatrukha/aae9713056aed954a45fd1ac5617e1a0 to your computer and use it in GitHub Desktop.
Save ekatrukha/aae9713056aed954a45fd1ac5617e1a0 to your computer and use it in GitHub Desktop.
NeRF-style ReLU network with Fourier features for N-dimensional, single-channel input
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 20 16:29:11 2025
@author: ekatrukha
"""
import os
from dataclasses import dataclass
from pathlib import Path
from collections.abc import Iterable
from typing import List
import imageio
from skimage.transform import resize
from skimage.color import rgba2rgb
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
#!pip install livelossplot --quiet
from livelossplot import PlotLosses
# Load the training image and rescale its intensities to [0, 1]. Every axis of
# the loaded array is treated as a spatial axis (single-channel input).
#image = imageio.v2.imread('uu-logo-1-8bit.tif')[...]
image = imageio.v2.imread('circles_1.tif')[...]
# NOTE(review): an RGB(A) file would have its colour axis treated as one more
# spatial dimension here — confirm the input is single-channel.
# Normalize in float64: computing max - min in the file's integer dtype can
# overflow for signed images (e.g. int16) with a wide intensity range.
image_f64 = np.asarray(image, dtype=np.float64)
image_range = image_f64.max() - image_f64.min()
if image_range == 0:
    # A constant image would otherwise give 0/0 = NaN everywhere.
    imagenp = np.zeros_like(image_f64)
else:
    imagenp = (image_f64 - image_f64.min()) / image_range
image = torch.tensor(imagenp, dtype=torch.float32)
image_dims = list(image.shape)
nSpaceDim = len(image_dims)
nOutDim = 1
# Sample a regular grid of coordinates in [0, 1) along every axis; these are
# the model inputs, one coordinate vector per pixel/voxel.
gridlist = list(image_dims)
npgrid = np.stack(
    np.meshgrid(*[np.linspace(0, 1, d, endpoint=False) for d in gridlist], indexing='ij'),
    -1,
)
grid = torch.tensor(npgrid, dtype=torch.float32)
# Flatten to (n_points, nSpaceDim) inputs and (n_points, nOutDim) targets.
X = grid.reshape(-1, nSpaceDim)
Y = image.reshape(-1, nOutDim)
# Interleaved split on the flattened index: even points train, odd points test.
test_X, test_y = X[1::2], Y[1::2]
train_X, train_y = X[::2], Y[::2]
class NeuralField(nn.Module):
    """Fully connected coordinate network with ReLU hidden layers and a sigmoid output.

    Args:
        hidden_layers: number of extra Linear+ReLU layers after the input layer.
        neurons_per_layer: width of the input layer and of every hidden layer.
        input_dimension: size of the last axis of the input tensor (e.g. the
            Fourier-feature width when coordinates are Fourier-encoded).
        output_dimension: number of output channels; defaults to 1, the value
            that used to be hard-coded.
    """

    def __init__(self, hidden_layers=2, neurons_per_layer=1024, input_dimension=2,
                 output_dimension=1):
        super().__init__()
        self.input_layer = nn.Linear(input_dimension, neurons_per_layer)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(neurons_per_layer, neurons_per_layer) for _ in range(hidden_layers)]
        )
        self.output_layer = nn.Linear(neurons_per_layer, output_dimension)

    def forward(self, input):
        """Map (..., input_dimension) inputs to (..., output_dimension) values in (0, 1)."""
        x = F.relu(self.input_layer(input))
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return torch.sigmoid(self.output_layer(x))
def mse(gt, pred):
    """Return half the mean-squared error between *gt* and *pred*.

    The squared residual is averaged over the last two axes, the result is
    summed over the next remaining axis and then averaged over everything
    left. For the 2-D (n_samples, channels) tensors used in this script that
    is simply ``0.5 * mean((gt - pred) ** 2)``.
    """
    squared_error = (gt - pred) ** 2.
    per_item = squared_error.mean(dim=(-1, -2))
    return 0.5 * per_item.sum(dim=-1).mean()
def psnr(gt, pred):
    """Return the peak signal-to-noise ratio (in dB) between *gt* and *pred*.

    Assumes intensities lie in [0, 1] (peak value 1), so PSNR = -10 * log10(MSE).
    Identical inputs give +inf.
    """
    # The factor 2 from the original formulation only applies when the argument
    # is the halved loss (0.5 * MSE). Applying it to the raw squared error
    # understated PSNR by 10*log10(2) ~= 3.01 dB.
    return -10 * torch.log10(torch.mean((gt - pred) ** 2.))
# Number of random frequencies. The encoding emits one sin and one cos per
# frequency, so the network input width is twice this value.
FOURIER_DIM = 256
# Multiplier applied to the standard-normal frequency matrix, i.e. the standard
# deviation of each frequency entry; larger values yield higher-frequency
# sinusoids in the encoding.
FOURIER_SCALE = 3.
INPUT_DIMS = FOURIER_DIM * 2
# Random (nSpaceDim, FOURIER_DIM) frequency matrix, drawn once and kept fixed
# (it is not a trainable parameter). No seed is set, so it differs between runs.
B = torch.randn(size=(nSpaceDim, FOURIER_DIM), requires_grad=False) * FOURIER_SCALE
def apply_fourier_features(x, B):
    """Encode coordinates with random Fourier features.

    Args:
        x: coordinate tensor of shape (..., d).
        B: frequency matrix of shape (d, m).

    Returns:
        Tensor of shape (..., 2*m) holding ``sin(2*pi*x@B)`` followed by
        ``cos(2*pi*x@B)`` along the last axis.
    """
    # Match B to x's device and dtype so the encoding also works for CUDA or
    # float64 coordinates; for same-device, same-dtype inputs this is a no-op.
    B = B.to(device=x.device, dtype=x.dtype)
    projection = (2 * np.pi * x) @ B
    return torch.cat([torch.sin(projection), torch.cos(projection)], dim=-1)
# Build the model, optimizer and loss plotter, then fit the network to the
# training samples. Every 25 iterations the full coordinate grid is evaluated
# and the reconstruction is stored, so training progress can be saved below as
# a multi-frame image.
# Use CUDA when present and fall back to the CPU otherwise, instead of failing
# on a hard-coded .cuda() call; without CUDA, nn.DataParallel simply calls the
# wrapped module.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(NeuralField(input_dimension=INPUT_DIMS).to(device))
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ff_images = []
liveloss = PlotLosses()
# train_X, B and train_y never change, so encode the training coordinates and
# move the targets to the device once, not on each of the 700 iterations.
train_features = apply_fourier_features(train_X, B).to(device)
train_target = train_y.to(device)
for i in range(700):
    optimizer.zero_grad(set_to_none=True)
    prediction = model(train_features)
    loss = mse(train_target, prediction)
    loss.backward()
    optimizer.step()
    if i % 25 == 0:
        with torch.no_grad():
            # Predict every grid point (train and test) to rebuild the whole image.
            reconstruction = model(apply_fourier_features(X, B)).cpu()
        print('loss '+str(loss.item()))
        print('iteration' + str(i))
        # Test points are the odd flattened indices, hence [1::2] below.
        # liveloss.update({'PSNR train': psnr(train_y, prediction.detach().cpu()),
        #                  'Loss train': mse(train_y, prediction.detach().cpu()),
        #                  'PSNR test': psnr(test_y, reconstruction[1::2]),
        #                  'Loss test': mse(test_y, reconstruction[1::2])},
        #                 current_step=i)
        # liveloss.send()
        ff_images.append(reconstruction.numpy().reshape(image.shape))
# Stack the snapshots, clip to [0, 1] and quantize to 8 bits for saving.
all_images = np.stack(ff_images)
data8 = (255*np.clip(all_images,0,1)).astype(np.uint8)
# NOTE(review): the file name says "no_ff", but this run uses Fourier
# features — confirm the intended output name.
f = os.path.join('uulogotrain_no_ff.tif')
imageio.mimwrite(f, data8)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment