Created
May 7, 2025 14:12
-
-
Save ekatrukha/aae9713056aed954a45fd1ac5617e1a0 to your computer and use it in GitHub Desktop.
NeRF ReLU network with Fourier features and N-D single-channel input
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 20 16:29:11 2025

@author: ekatrukha
"""
import os | |
from dataclasses import dataclass | |
from pathlib import Path | |
from collections.abc import Iterable | |
from typing import List | |
import imageio | |
from skimage.transform import resize | |
from skimage.color import rgba2rgb | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchvision | |
import torchvision.transforms as transforms | |
#!pip install livelossplot --quiet | |
from livelossplot import PlotLosses | |
# ---------------------------------------------------------------------------
# Load the target image and build the (coordinate, intensity) sample set.
# ---------------------------------------------------------------------------
#image = imageio.v2.imread('uu-logo-1-8bit.tif')[...]
image = imageio.v2.imread('circles_1.tif')[...]

# Normalise intensities to [0, 1].
# The arithmetic is done in float64: subtracting the minimum in the image's
# native dtype can overflow for signed integer images (e.g. int16 data that
# spans the whole range).
image_float = np.asarray(image, dtype=np.float64)
intensity_min = image_float.min()
intensity_range = image_float.max() - intensity_min
if intensity_range > 0:
    imagenp = (image_float - intensity_min) / intensity_range
else:
    # A constant image has zero range; dividing would fill the target with
    # NaN/inf, so map it to all zeros instead.
    imagenp = np.zeros_like(image_float)

#plt.imshow(image);
#image=np.swapaxes(image,0,2)
#image=np.swapaxes(image,0,1)
#image = torch.unsqueeze(torch.tensor(image, dtype=torch.float32),-1)
image = torch.tensor(imagenp, dtype=torch.float32)

# Every axis of the image is treated as a spatial axis; the value stored at
# each grid point is a single channel.
image_dims = list(image.shape)
nSpaceDim = len(image_dims)
nOutDim = 1

# sample the grid, which will be the input to the model
# Each axis is sampled at d evenly spaced positions in [0, 1) (endpoint
# excluded); 'ij' indexing keeps the axis order identical to the image's, so
# flattening the grid and the image below pairs every coordinate with its
# own pixel.
gridlist = list(image_dims[0:nSpaceDim])
npgrid = np.stack(
    np.meshgrid(
        *[np.linspace(0, 1, d, endpoint=False) for d in gridlist],
        indexing='ij',
    ),
    -1,
)
grid = torch.tensor(npgrid, dtype=torch.float32)

# X: (number of pixels, nSpaceDim) coordinates.
# Y: (number of pixels, nOutDim) normalised intensities.
X, Y = [torch.reshape(grid, (-1, nSpaceDim)), torch.reshape(image, (-1, nOutDim))]
#X, Y = [grid.view(-1, nSpaceDim), image.view(-1, nOutDim)]

# Interleaved split: odd-indexed samples are held out for testing,
# even-indexed samples are used for training.
test_X, test_y = [X[1::2], Y[1::2]]
train_X, train_y = [X[::2], Y[::2]]

# The coordinates are fixed inputs, not parameters to optimise.
test_X.requires_grad = False
train_X.requires_grad = False
class NeuralField(nn.Module):
    """Fully connected ReLU network with a sigmoid output.

    The network is an input linear layer, ``hidden_layers`` further linear
    layers of equal width, and an output linear layer. Every layer except
    the last is followed by ReLU; the last is followed by a sigmoid, so all
    outputs lie in the interval [0, 1].

    Args:
        hidden_layers: Number of hidden linear layers placed between the
            input layer and the output layer.
        neurons_per_layer: Width of the input layer and of every hidden
            layer.
        input_dimension: Size of the last axis of the tensors passed to
            ``forward``.
        output_dimension: Number of values predicted per sample. Defaults to
            1, which reproduces the previously hard-coded single output.
    """

    def __init__(self, hidden_layers=2, neurons_per_layer=1024,
                 input_dimension=2, output_dimension=1):
        super().__init__()
        self.input_layer = nn.Linear(input_dimension, neurons_per_layer)
        # ModuleList (rather than a plain list) registers the layers so their
        # parameters are visible to ``parameters()`` and the optimizer.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(neurons_per_layer, neurons_per_layer)
             for _ in range(hidden_layers)]
        )
        self.output_layer = nn.Linear(neurons_per_layer, output_dimension)

    def forward(self, input):
        """Map a batch of feature vectors to predictions in [0, 1].

        Args:
            input: Tensor whose last axis has size ``input_dimension``.

        Returns:
            Tensor with the same leading axes as ``input`` and a last axis
            of size ``output_dimension``.
        """
        x = F.relu(self.input_layer(input))
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return torch.sigmoid(self.output_layer(x))
def mse(gt, pred):
    """Return the halved squared-error loss between ``gt`` and ``pred``.

    The squared error is averaged over the last two axes, summed over the
    last axis that remains, averaged over whatever is left, and finally
    multiplied by 0.5. For 2-D inputs of shape (samples, channels) this is
    exactly half of the mean squared error.

    Args:
        gt: Ground-truth tensor with at least two axes.
        pred: Prediction tensor broadcastable against ``gt``.

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    squared_error = (gt - pred) ** 2.
    # Reduction order is kept exactly as written: mean, then sum, then mean.
    return 0.5 * torch.mean(squared_error, (-1, -2)).sum(-1).mean()
def psnr(gt, pred):
    """Return the peak signal-to-noise ratio, in dB, for signals with peak 1.

    Computed as ``-10 * log10(mean((gt - pred) ** 2))``.

    The previous version multiplied the mean squared error by 2 before taking
    the logarithm. That factor only makes sense when the argument is a loss
    that has already been halved; applied to the raw mean squared error it
    under-reported the PSNR by 10 * log10(2), about 3.01 dB.

    Args:
        gt: Ground-truth tensor.
        pred: Prediction tensor broadcastable against ``gt``.

    Returns:
        A zero-dimensional tensor holding the PSNR in dB. The value is
        ``inf`` when ``gt`` and ``pred`` are identical.
    """
    mean_squared_error = torch.mean((gt - pred) ** 2.)
    return -10 * torch.log10(mean_squared_error)
# Number of random frequency vectors used by the Fourier feature mapping.
FOURIER_DIM = 256
# Feature scale: multiplies the normally distributed entries of B, i.e. it is
# the standard deviation of the sampled frequencies.
FOURIER_SCALE = 3.
# Width of the network input: a sine and a cosine feature per frequency.
INPUT_DIMS = 2 * FOURIER_DIM
# Random projection matrix, one column per frequency vector. It is sampled
# once and never trained.
B = FOURIER_SCALE * torch.randn(size=(nSpaceDim, FOURIER_DIM), requires_grad=False)
def apply_fourier_features(x, B):
    """Map coordinates to sine/cosine features of random projections.

    Computes ``p = (2 * pi * x) @ B`` and returns ``[sin(p), cos(p)]``
    concatenated along the last axis.

    Args:
        x: Coordinate tensor whose last axis has size ``B.shape[0]``.
        B: Projection matrix of shape (coordinate dimension, number of
            frequencies). It is converted to the device and dtype of the
            scaled coordinates before the product, so a matrix created on
            another device or with another floating dtype can be passed
            directly.

    Returns:
        Tensor with the leading axes of ``x`` and a last axis of size
        ``2 * B.shape[1]``: all sine features first, then all cosine
        features.
    """
    scaled = 2 * np.pi * x
    # No copy is made when B already matches; otherwise this avoids a
    # device/dtype mismatch error in the matrix product.
    frequencies = B.to(device=scaled.device, dtype=scaled.dtype)
    projection = scaled @ frequencies
    return torch.cat([torch.sin(projection), torch.cos(projection)], dim=-1)
# ---------------------------------------------------------------------------
# Train the network on the even-indexed samples and periodically record a
# reconstruction of the full image.
# ---------------------------------------------------------------------------
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the script no longer fails outright on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.DataParallel(NeuralField(input_dimension=INPUT_DIMS).to(device))
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

ff_images = []
liveloss = PlotLosses()

# Neither the training coordinates nor B change between iterations, so the
# Fourier features and the device copy of the target are computed once
# instead of on every step.
train_features = apply_fourier_features(train_X, B)
train_target = train_y.to(device)

for i in range(700):
    optimizer.zero_grad(set_to_none=True)
    prediction = model(train_features)
    loss = mse(train_target, prediction)
    loss.backward()
    optimizer.step()

    if i % 25 == 0:
        # Snapshot: evaluate the network on every pixel coordinate. No
        # gradients are needed here.
        with torch.no_grad():
            reconstruction = model(apply_fourier_features(X, B)).detach().cpu()
        print('loss '+str(loss.item()))
        print('iteration' + str(i))
        # Optional live plotting. The test split consists of the odd-indexed
        # samples, so the reconstruction is indexed with [1::2] to match it.
        # liveloss.update({'PSNR train': psnr(train_y, prediction.detach().cpu()),
        #                  'Loss train': mse(train_y, prediction.detach().cpu()),
        #                  'PSNR test': psnr(test_y, reconstruction[1::2]),
        #                  'Loss test': mse(test_y, reconstruction[1::2])},
        #                 current_step=i)
        # liveloss.send()
        ff_images.append(reconstruction.numpy().reshape(image.shape))

# Stack the snapshots along a new leading axis, convert them to 8-bit and
# write them to a single file.
all_images = np.stack(ff_images)
data8 = (255*np.clip(all_images,0,1)).astype(np.uint8)
f = os.path.join('uulogotrain_no_ff.tif')
imageio.mimwrite(f, data8)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment