{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
},
"colab": {
"name": "siren_video.ipynb",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"machine_shape": "hm",
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/eyaler/20a9037a3619378b276b2303dadb558d/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "3wCS6y_o0zw6",
"cellView": "form"
},
"source": [
"#@title SIREN Image/Video Fitting\n",
"\n",
"#@markdown Based on https://vsitzmann.github.io/siren and https://www.matthewtancik.com/learnit\n",
"\n",
"#@markdown Made just a little bit more accessible by Eyal Gruss [@eyaler](https://twitter.com/eyaler) [eyalgruss.com](https://eyalgruss.com)\n",
"\n",
"#@markdown Notes:\n", | |
"#@markdown * When using load_initial_model, (hidden_layers, features) are overridden to the loaded dimensions.\n", | |
"#@markdown * For loaded learnit models and optimizer==Adam, you better set **lr=1e-4**.\n", | |
"#@markdown * After fitting a model, you can set **epochs=0** to infer with different parameters.\n", | |
"\n", | |
"url = 'https://eyalgruss.com/media/green1.mp4' #@param ['https://api.time.com/wp-content/uploads/2019/06/final.trump_.cover_.jpg?w=800', 'http://www.fmwconcepts.com/misc_tests/FFT_tests/lena_roundtrip/lena.jpg', 'https://eyalgruss.com/media/green1.mp4', 'https://www.youtube.com/watch?v=IEqccPhsqgA'] {allow-input: true}\n", | |
"load_initial_model = 'none' #@param ['none', 'learnit_official_3_256', 'learnit_eyal_10_256', 'learnit_eyal_16_256', 'last_saved'] {allow-input: true}\n", | |
"init_spacetime_factor = 0#@param {'type':'number'}\n", | |
"temporal_phase_factor = 0#@param {'type':'number'}\n", | |
"color_mode = 'original' #@param ['original','grayscale','colorize']\n", | |
"save_gradient = False #@param {'type':'boolean'}\n", | |
"start_seconds = 0#@param {'type':'number'}\n", | |
"limit_seconds = 60#@param {'type':'number'}\n", | |
"limit_frames = 150#@param {'type':'integer'}\n", | |
"batch_level = 'pixels' #@param ['pixels','frames','sequeנnces']\n", | |
"batch_frac = 0.1#@param {'type':'number'}\n", | |
"spatial_interpolation_factor = 1#@param {'type':'number'}\n", | |
"temporal_interpolation_factor = 1#@param {'type':'number'}\n", | |
"temporal_extrapolation_fraction = 0#@param {'type':'number'}\n", | |
"hidden_layers = 6#@param {'type':'integer'}\n", | |
"features = 128#@param {'type':'integer'}\n", | |
"optimizer = 'Adam' #@param ['Adam','SGD']\n", | |
"lr = 0.0001 #@param {'type':'number'}\n", | |
"reg = 0 #@param {'type':'number'}\n", | |
"epochs = 150#@param {'type':'integer'}\n", | |
"show_every = 10#@param {'type':'integer'}\n", | |
"save_every = 10#@param {'type':'integer'}\n", | |
"\n", | |
"%cd /content\n", | |
"machine = !nvidia-smi -L\n", | |
"print(machine)\n", | |
"\n", | |
"!pip install youtube-dl\n", | |
"!pip install -U imageio\n", | |
"!pip install -U imageio-ffmpeg\n", | |
"import youtube_dl\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"import os\n", | |
"import numpy as np\n", | |
"import imageio\n", | |
"import matplotlib.pyplot as plt\n", | |
"from time import time\n", | |
"import pickle\n", | |
"from IPython.display import HTML\n", | |
"from base64 import b64encode\n", | |
"\n", | |
"first_omega_0 = 30\n", | |
"hidden_omega_0 = 30\n", | |
"init = None\n", | |
"factor = 1\n", | |
"is_learnit = False\n", | |
"if load_initial_model != 'none':\n", | |
" if 'learnit' in load_initial_model:\n", | |
" is_learnit = True\n", | |
" !pip install -q git+https://github.com/deepmind/dm-haiku\n", | |
" if load_initial_model == 'learnit_official_3_256':\n", | |
" learnit_ckpt = 'https://people.eecs.berkeley.edu/~tancik/learnit/checkpoints/maml_ilr_0.01_olr_1e-05_bs_3_150000.pkl'\n", | |
" elif load_initial_model == 'learnit_eyal_10_256':\n", | |
" learnit_ckpt = 'https://eyalgruss.com/models/maml_ilr_0.010000_olr_0.000010_bs_3_53000_31.46_30_2_256_10.pkl'\n", | |
" elif load_initial_model == 'learnit_eyal_16_256':\n", | |
" learnit_ckpt = 'https://eyalgruss.com/models/maml_ilr_0.010000_olr_0.000010_bs_1_res_178_numval_100_step_65400_psnr_30.09_w0_30_inner_2_feat_256_hid_16.pkl'\n", | |
" !wget --no-check-certificate -nc $learnit_ckpt\n", | |
" with open('/content/'+learnit_ckpt.rsplit('/',1)[-1], 'rb') as file:\n", | |
" init = pickle.load(file)\n", | |
" features = len(init['siren__model/siren_layer/linear']['b'])\n", | |
" if load_initial_model == 'learnit_official_3_256':\n", | |
" first_omega_0 = 200\n", | |
" hidden_omega_0 = 200\n", | |
" factor = 2\n", | |
" else:\n", | |
" init = torch.load('/content/output/model.pt').net\n", | |
" features = init[0].linear.out_features\n", | |
" hidden_layers = len(init)-2\n", | |
" print('loaded model with hidden_layers=%d features=%d'%(hidden_layers,features))\n", | |
"\n", | |
"def get_mgrid(shape, frame=None, spatial_interpolation_factor=1, temporal_interpolation_factor=1, temporal_extrapolation_fraction=0):\n", | |
" '''Generates a flattened grid of ([t],y,x) coordinates in a range of -1 to 1.'''\n", | |
" if spatial_interpolation_factor!=1:\n", | |
" shape = (*shape[:-2], int(np.ceil(shape[-2]*spatial_interpolation_factor)), int(np.ceil(shape[-1]*spatial_interpolation_factor)))\n", | |
" if len(shape)==3 and temporal_interpolation_factor!=1:\n", | |
" shape = (int(np.ceil(shape[0]*temporal_interpolation_factor)), *shape[1:])\n", | |
" tensors = [torch.linspace(-1, 1, steps=ax) for ax in shape]\n", | |
" if len(shape)==3:\n", | |
" if temporal_extrapolation_fraction:\n", | |
" extrap = torch.linspace(1, 1+2*abs(temporal_extrapolation_fraction), steps=int(np.round(shape[0]*abs(temporal_extrapolation_fraction))))[1:]\n", | |
" if temporal_extrapolation_fraction>0:\n", | |
" tensors[0] = torch.cat((tensors[0],extrap))\n", | |
" else:\n", | |
" tensors[0] = torch.cat((-extrap.flip(0),tensors[0]))\n", | |
" elif frame is not None:\n", | |
" tensors[0] = torch.tensor(float(frame))\n", | |
" mgrid = torch.stack(torch.meshgrid(*tensors), dim=-1)\n", | |
" mgrid = mgrid.reshape(-1, len(shape))\n", | |
" return mgrid.cuda(), shape\n", | |
"\n", | |
"def laplace(y, x):\n", | |
" grad = gradient(y, x)\n", | |
" return divergence(grad, x)\n", | |
"\n", | |
"def divergence(y, x):\n", | |
" div = 0.\n", | |
" for i in range(y.shape[-1]):\n", | |
" div += torch.autograd.grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]\n", | |
" return div\n", | |
"\n", | |
"def gradient(y, x, grad_outputs=None):\n", | |
" if grad_outputs is None:\n", | |
" grad_outputs = torch.ones_like(y)\n", | |
" grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]\n", | |
" return grad\n", | |
"\n", | |
"class SineLayer(nn.Module):\n", | |
" # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.\n", | |
" \n", | |
" # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the \n", | |
" # nonlinearity. Different signals may require different omega_0 in the first layer - this is a \n", | |
" # hyperparameter.\n", | |
" \n", | |
" # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of \n", | |
" # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)\n", | |
" \n", | |
" def __init__(self, in_features, out_features, bias=True,\n", | |
" is_first=False, omega_0=hidden_omega_0, init=None, is_learnit=False, init_spacetime_factor=0):\n", | |
" super().__init__()\n", | |
" self.omega_0 = omega_0\n", | |
" self.is_first = is_first\n", | |
" self.in_features = in_features\n", | |
" self.linear = nn.Linear(in_features, out_features, bias=bias)\n", | |
" \n", | |
" self.init_weights(bias=bias, init=init, init_spacetime_factor=init_spacetime_factor)\n", | |
" \n", | |
" def init_weights(self, bias=True, init=None, init_spacetime_factor=0):\n", | |
" with torch.no_grad():\n", | |
" if self.is_first:\n", | |
" self.linear.weight.uniform_(-1 / self.in_features, \n", | |
" 1 / self.in_features) \n", | |
" elif init is None:\n", | |
" self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, \n", | |
" np.sqrt(6 / self.in_features) / self.omega_0)\n", | |
" if init is not None:\n", | |
" if is_learnit:\n", | |
" w = torch.from_numpy(init['w']).permute(1,0)\n", | |
" else:\n", | |
" w = init.weight\n", | |
" self.linear.weight[:,self.linear.weight.shape[1]-w.shape[1]:].copy_(w)\n", | |
" if bias:\n", | |
" if is_learnit:\n", | |
" b = torch.from_numpy(init['b'])\n", | |
" else:\n", | |
" b = init.bias\n", | |
" self.linear.bias.copy_(b)\n", | |
" if self.is_first and init_spacetime_factor:\n", | |
" self.linear.weight[:,0].copy_(w[:,-1]*init_spacetime_factor)\n", | |
" \n", | |
" def forward(self, input):\n", | |
" return torch.sin(self.omega_0 * self.linear(input))\n", | |
" \n", | |
" \n", | |
"class Siren(nn.Module):\n", | |
" def __init__(self, in_features, hidden_features, hidden_layers, out_features, bias=True, outermost_linear=False, \n", | |
" first_omega_0=first_omega_0, hidden_omega_0=hidden_omega_0, init=None, is_learnit=False, init_spacetime_factor=0, factor=1):\n", | |
" super().__init__()\n", | |
" \n", | |
" self.net = []\n", | |
" layer_init = None\n", | |
"\n", | |
" if is_learnit:\n", | |
" layer_init = init['siren__model/siren_layer/linear']\n", | |
" elif init is not None:\n", | |
" layer_init = init[0].linear\n", | |
" self.factor = factor\n", | |
" self.net.append(SineLayer(in_features, hidden_features, bias=bias,\n", | |
" is_first=True, omega_0=first_omega_0, init=layer_init, init_spacetime_factor=init_spacetime_factor))\n", | |
"\n", | |
" for i in range(hidden_layers):\n", | |
" if is_learnit:\n", | |
" layer_init = init['siren__model/siren_layer_%d/linear'%(i+1)]\n", | |
" elif init is not None:\n", | |
" layer_init = init[i+1].linear\n", | |
" self.net.append(SineLayer(hidden_features, hidden_features, bias=bias,\n", | |
" is_first=False, omega_0=hidden_omega_0, init=layer_init))\n", | |
"\n", | |
" if is_learnit:\n", | |
" layer_init = init['siren__model/siren_layer_%d/linear'%(i+2)]\n", | |
" elif init is not None:\n", | |
" layer_init = init[i+2]\n", | |
" if hasattr(layer_init, 'linear'):\n", | |
" layer_init = layer_init.linear\n", | |
"\n", | |
" if outermost_linear:\n", | |
" final_linear = nn.Linear(hidden_features, out_features, bias=bias)\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" if layer_init is None:\n", | |
" final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0, \n", | |
" np.sqrt(6 / hidden_features) / hidden_omega_0)\n", | |
" else:\n", | |
" if is_learnit:\n", | |
" w = torch.from_numpy(layer_init['w']).permute(1,0)\n", | |
" else:\n", | |
" w = layer_init.weight\n", | |
" if out_features==1:\n", | |
" w = w.mean(axis=0, keepdims=True)\n", | |
" final_linear.weight.copy_(w)\n", | |
" if bias:\n", | |
" if is_learnit:\n", | |
" b = torch.from_numpy(layer_init['b'])\n", | |
" else:\n", | |
" b = layer_init.bias\n", | |
" if out_features==1:\n", | |
" b = b.mean(axis=0, keepdims=True)\n", | |
" final_linear.bias.copy_(b)\n", | |
" \n", | |
" self.net.append(final_linear)\n", | |
" else:\n", | |
" self.net.append(SineLayer(hidden_features, out_features, bias=bias, \n", | |
" is_first=False, omega_0=hidden_omega_0, init=layer_init))\n", | |
" \n", | |
" self.net = nn.Sequential(*self.net)\n", | |
" \n", | |
" def forward(self, coords):\n", | |
" coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input\n", | |
" output = (self.net(coords)*self.factor+1)/2\n", | |
" return output, coords\n", | |
"\n", | |
"def is_supported(url):\n", | |
" if url.lower().endswith(('.png','.jpg','.jpeg','.bmp','.webp','.mp4','.mov','.avi','.wmv','.mpg','.mpeg','.mkv','.webm')):\n", | |
" return False\n", | |
" extractors = youtube_dl.extractor.gen_extractors()\n", | |
" for e in extractors:\n", | |
" if e.suitable(url) and e.IE_NAME != 'generic':\n", | |
" return True\n", | |
" return False\n", | |
"\n", | |
"os.makedirs('/content/output', exist_ok=True)\n", | |
"if is_supported(url):\n", | |
" !rm -f /content/output/video.mp4\n", | |
" !youtube-dl --no-playlist -f \"bestvideo[ext=mp4][vcodec!*=av01]+bestaudio[ext=m4a]/mp4\" \"$url\" --merge-output-format mp4 -o /content/output/video.mp4\n", | |
" if os.path.exists('/content/output/video.mp4'):\n", | |
" url = '/content/output/video.mp4'\n", | |
"\n", | |
"vid = []\n", | |
"reader = imageio.get_reader(url)\n", | |
"try:\n", | |
" fps = reader.get_meta_data()['fps']\n", | |
"except:\n", | |
" fps = 0\n", | |
"for i,im in enumerate(reader):\n", | |
" if start_seconds and fps and i<start_seconds*fps:\n", | |
" continue\n", | |
" if len(im.shape)==2:\n", | |
" im = im[..., None]\n", | |
" elif color_mode=='grayscale':\n", | |
" im = np.uint8(im[...,:3].mean(axis=-1, keepdims=True))\n", | |
" vid.append(im)\n", | |
" if limit_frames and len(vid)==limit_frames or limit_seconds and fps and len(vid)>=limit_seconds*fps:\n", | |
" break\n", | |
"if len(vid)>1:\n", | |
" imageio.mimwrite('/content/output/original.mp4', vid, fps=fps)\n", | |
"vid = np.array(vid)/255\n", | |
"img1 = vid[0].squeeze()\n", | |
"if len(vid)==1:\n", | |
" vid = vid[0]\n", | |
"shape = vid.shape[:-1]\n", | |
"channels = 3 if color_mode=='colorize' else vid.shape[-1]\n", | |
"if len(vid)>1:\n", | |
" img2 = vid[len(vid)//2].squeeze()\n", | |
" img3 = vid[-1].squeeze()\n", | |
" mgrid1 = get_mgrid(shape, frame=-1)[0]\n", | |
" mgrid2 = get_mgrid(shape, frame=0)[0]\n", | |
" mgrid3 = get_mgrid(shape, frame=1)[0]\n", | |
"vid = torch.from_numpy(vid)\n", | |
"pixels = vid.view(-1, vid.shape[-1]).cuda()\n", | |
"coords = get_mgrid(shape)[0]\n", | |
"flat_len = len(coords)\n", | |
"\n", | |
"frames_in_batch = 0\n", | |
"frames_in_odd_batch = 0\n", | |
"if len(shape)==2 or batch_level=='pixels':\n", | |
" if len(shape)==2:\n", | |
" batch_frac = 1\n", | |
" batch_size = max(int(np.ceil(flat_len*batch_frac)),1)\n", | |
"else:\n", | |
" frames_in_batch = max(int(np.ceil(shape[0]*batch_frac)),1)\n", | |
" frames_in_odd_batch = shape[0]%frames_in_batch\n", | |
" batch_size = frames_in_batch*shape[1]*shape[2]\n", | |
" delta_pixels = torch.arange(0,shape[1]*shape[2]).repeat(shape[0])\n", | |
" \n", | |
"siren = Siren(in_features=len(shape), hidden_features=features,\n", | |
" hidden_layers=hidden_layers, out_features=channels, \n", | |
" bias=True, outermost_linear=True,\n", | |
" first_omega_0=first_omega_0, hidden_omega_0=hidden_omega_0,\n", | |
" init=init, is_learnit=is_learnit,\n", | |
" init_spacetime_factor=init_spacetime_factor, factor=factor)\n", | |
"siren.cuda()\n", | |
"num_params = sum(p.numel() for p in siren.parameters() if p.requires_grad)\n", | |
"if optimizer=='Adam':\n", | |
" optim = torch.optim.Adam(lr=lr, params=siren.parameters())\n", | |
"else:\n", | |
" optim = torch.optim.SGD(lr=lr, params=siren.parameters())\n", | |
"num_batches = int(np.ceil(flat_len/batch_size))\n", | |
"num_full_batches = num_batches - (frames_in_odd_batch>0)\n", | |
"print('params=%d batch_size=%d batch_frames=%d odd_batch_frames=%d shape='%(num_params,batch_size,frames_in_batch,frames_in_odd_batch),shape+(channels,))\n", | |
"\n", | |
"start = time()\n", | |
"reg_loss = 0\n", | |
"max_psnr = 0\n", | |
"max_epoch = 0\n", | |
"save = None\n", | |
"if len(shape)==2 and save_every:\n", | |
" save = np.zeros((epochs//save_every+1,*shape,channels), dtype=np.uint8)\n", | |
" if channels==1:\n", | |
" save = save.squeeze(axis=-1)\n", | |
"if batch_level=='sequences':\n", | |
" delta_frames = torch.arange(0,frames_in_batch).repeat(num_batches)[:shape[0]]\n", | |
"for epoch in range(1,epochs+2):\n", | |
" if epoch==epochs+1:\n", | |
" siren = torch.load('/content/output/model.pt').cuda()\n", | |
" coords, shape = get_mgrid(shape, spatial_interpolation_factor=spatial_interpolation_factor, temporal_interpolation_factor=temporal_interpolation_factor, temporal_extrapolation_fraction=temporal_extrapolation_fraction)\n", | |
" if len(shape)==2:\n", | |
" batch_size = len(coords)\n", | |
" num_batches = int(np.ceil(len(coords)/batch_size))\n", | |
" if len(shape)==3:\n", | |
" save = np.zeros((len(coords),channels), dtype=np.uint8).squeeze()\n", | |
" save_grad = np.zeros((len(coords),3))\n", | |
" elif len(shape)==3 and batch_frac < 1:\n", | |
" if batch_level=='pixels':\n", | |
" idx = torch.randperm(flat_len)\n", | |
" else:\n", | |
" if batch_level=='frames':\n", | |
" idx = torch.randperm(shape[0])\n", | |
" elif batch_level=='sequences':\n", | |
" idx = torch.randperm(num_full_batches)\n", | |
" if frames_in_odd_batch>0:\n", | |
" idx = torch.cat((idx,torch.tensor([len(idx)])))\n", | |
" idx = idx.repeat_interleave(frames_in_batch)[:shape[0]]*frames_in_batch+delta_frames\n", | |
" idx = idx.repeat_interleave(shape[1]*shape[2])*shape[1]*shape[2]+delta_pixels\n", | |
" coords = coords[idx]\n", | |
" pixels = pixels[idx]\n", | |
" epoch_recon_loss = torch.tensor(0.0)\n", | |
" epoch_reg_loss = torch.tensor(0.0)\n", | |
" epoch_total_loss = torch.tensor(0.0)\n", | |
" num_extrap = 0\n", | |
" for i in range(num_batches):\n", | |
" model_input = coords[i*batch_size:(i+1)*batch_size]\n", | |
" if epoch == epochs+1 and len(shape)==3 and temporal_phase_factor:\n", | |
" model_input[:,-1] = model_input[:,-1]+temporal_phase_factor*model_input[:,0]\n", | |
" model_output, model_coords = siren(model_input)\n", | |
" if epoch < epochs+1:\n", | |
" ground_truth = pixels[i*batch_size:(i+1)*batch_size]\n", | |
" if color_mode=='colorize':\n", | |
" if ground_truth.shape[-1]>1:\n", | |
" ground_truth = ground_truth[...,:3].mean(axis=-1, keepdim=True)\n", | |
" recon_loss = ((model_output.mean(axis=-1, keepdim=True) - ground_truth)**2).mean()\n", | |
" else: \n", | |
" recon_loss = ((model_output - ground_truth)**2).mean()\n", | |
" total_loss = recon_loss\n", | |
" if reg and (len(shape)==2 or batch_frac==1 or batch_level!='pixels') or len(shape)==2 and (show_every and not epoch % show_every or save_every and not epoch % save_every):\n", | |
" out = model_output.cpu().view(-1, *shape[1:], channels)\n", | |
" if reg:\n", | |
" if len(shape)==2:\n", | |
" reg_loss = reg*((out[:-1]-out[1:]).abs().mean() + (out[:,:-1]-out[:,1:]).abs().mean())\n", | |
" elif len(shape)==3:\n", | |
" reg_loss = reg*((out[:,:-1]-out[:,1:]).abs().mean() + (out[:,:,:-1]-out[:,:,1:]).abs().mean())\n", | |
" if batch_frac==1 or batch_level=='sequences':\n", | |
" reg_loss = reg_loss + reg*(out[:-1]-out[1:]).abs().mean()\n", | |
" total_loss = total_loss + reg_loss\n", | |
" epoch_reg_loss = epoch_reg_loss + reg_loss.detach()*len(model_input)\n", | |
" optim.zero_grad()\n", | |
" total_loss.backward()\n", | |
" optim.step()\n", | |
" epoch_recon_loss = epoch_recon_loss + recon_loss.detach()*len(model_input)\n", | |
" epoch_total_loss = epoch_total_loss + total_loss.detach()*len(model_input)\n", | |
" else:\n", | |
" if len(shape)==2:\n", | |
" out = model_output.cpu().view(*shape, channels)\n", | |
" else:\n", | |
" save[i*batch_size:(i+1)*batch_size] = np.uint8(model_output.squeeze().cpu().detach().numpy().clip(0,1)*255)\n", | |
" if save_gradient:\n", | |
" model_output = gradient(model_output, model_coords)/channels\n", | |
" if len(shape)==2:\n", | |
" grad = model_output.norm(dim=-1).cpu().view(*shape).detach().numpy()\n", | |
" else:\n", | |
" save_grad[i*batch_size:(i+1)*batch_size] = model_output.cpu().detach().numpy()\n", | |
" if len(shape)==2 and (show_every and not epoch % show_every or save_every and not epoch % save_every or epoch==epochs+1):\n", | |
" out1 = out.squeeze().detach().numpy().clip(0,1)\n", | |
" if save_every and not epoch % save_every and epoch < epochs+1:\n", | |
" save[epoch//save_every] = np.uint8(out1*255)\n", | |
" if epoch < epochs+1:\n", | |
" psnr = -10 * torch.log10(epoch_recon_loss/flat_len)\n", | |
" if psnr>max_psnr:\n", | |
" max_psnr = psnr\n", | |
" max_epoch = epoch\n", | |
" torch.save(siren, '/content/output/model.pt')\n", | |
" if show_every and not epoch % show_every: \n", | |
" fig, axes = plt.subplots(1,2 if len(shape)==2 else 6, figsize=(30,30))\n", | |
" axes[1].imshow(img1, cmap='gray' if len(img1.shape)==2 else None)\n", | |
" if len(shape)==2:\n", | |
" axes[0].imshow(out1, cmap='gray' if channels==1 else None)\n", | |
" else:\n", | |
" out1 = siren(mgrid1)[0].cpu().view(*shape[1:], channels).squeeze().detach().numpy().clip(0,1)\n", | |
" out2 = siren(mgrid2)[0].cpu().view(*shape[1:], channels).squeeze().detach().numpy().clip(0,1)\n", | |
" out3 = siren(mgrid3)[0].cpu().view(*shape[1:], channels).squeeze().detach().numpy().clip(0,1)\n", | |
" axes[0].imshow(out1, cmap='gray' if channels==1 else None) \n", | |
" axes[2].imshow(out2, cmap='gray' if channels==1 else None)\n", | |
" axes[3].imshow(img2, cmap='gray' if len(img1.shape)==2 else None)\n", | |
" axes[4].imshow(out3, cmap='gray' if channels==1 else None)\n", | |
" axes[5].imshow(img3, cmap='gray' if len(img1.shape)==2 else None)\n", | |
" plt.show()\n", | |
" print(\"%d: psnr=%0.2f (max_psnr=%0.2f @ epoch=%d) total_loss=%0.6f recon_loss=%0.6f reg_loss=%0.6f\" % (epoch, psnr, max_psnr, max_epoch, epoch_total_loss/flat_len, epoch_recon_loss/flat_len, epoch_reg_loss/flat_len))\n", | |
" print('took: %d secs (%.2f sec/iter) on %s. CUDA memory: %.1f GB\\n'%(time()-start,(time()-start)/(epoch+1), machine[0], torch.cuda.max_memory_allocated()/1024**3))\n", | |
"\n", | |
"print('params=%d batch_size=%d batch_frames=%d odd_batch_frames=%d shape='%(num_params,batch_size,frames_in_batch,frames_in_odd_batch),shape+(channels,))\n", | |
"\n", | |
"vid_files = []\n", | |
"if len(shape)==2:\n", | |
" imageio.imwrite('/content/output/output.png', np.uint8(out1*255))\n", | |
" if save_gradient:\n", | |
" imageio.imwrite('/content/output/gradient.png', np.uint8((grad/grad.max()).clip(0,1)*255))\n", | |
" if save is not None:\n", | |
" imageio.mimwrite('/content/output/learning.mp4', save)\n", | |
" vid_files = ['/content/output/learning.mp4']\n", | |
"else:\n", | |
" fps *= temporal_interpolation_factor\n", | |
" save = save.reshape(-1, *shape[1:], channels)\n", | |
" imageio.mimwrite('/content/output/output.mp4', save, fps=fps)\n", | |
" #!cp /content/output/model.pt /content/output/model_anim.pt\n", | |
" if save_gradient:\n", | |
" save_grad = save_grad.reshape(-1, *shape[1:], 3)\n", | |
" save_grad = np.square(save_grad)\n", | |
" grad = np.sqrt(np.sum(save_grad, axis=-1))\n", | |
" imageio.mimwrite('/content/output/gradient_xyt.mp4', np.uint8((grad/grad.max()).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_xyt_frame.mp4', np.uint8((grad/grad.max(axis=(1,2), keepdims=True)).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_xyt_pixel.mp4', np.uint8((grad/grad.max(axis=0)).clip(0,1)*255), fps=fps)\n", | |
" grad = np.sqrt(np.sum(save_grad[...,1:], axis=-1))\n", | |
" imageio.mimwrite('/content/output/gradient_xy.mp4', np.uint8((grad/grad.max()).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_xy_frame.mp4', np.uint8((grad/grad.max(axis=(1,2), keepdims=True)).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_xy_pixel.mp4', np.uint8((grad/grad.max(axis=0)).clip(0,1)*255), fps=fps)\n", | |
" grad = np.sqrt(save_grad[...,0])\n", | |
" imageio.mimwrite('/content/output/gradient_t.mp4', np.uint8((grad/grad.max()).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_t_frame.mp4', np.uint8((grad/grad.max(axis=(1,2), keepdims=True)).clip(0,1)*255), fps=fps)\n", | |
" imageio.mimwrite('/content/output/gradient_t_pixel.mp4', np.uint8((grad/grad.max(axis=0)).clip(0,1)*255), fps=fps)\n", | |
" vid_files = ['/content/output/output.mp4','/content/output/original.mp4']\n", | |
" if not spatial_interpolation_factor and not temporal_extrapolation_fraction:\n", | |
" !ffmpeg -i /content/output/output.mp4 -i /content/output/original.mp4 -filter_complex hstack -pix_fmt yuv420p -profile:v baseline -movflags +faststart /content/output/combined.mp4 -y\n", | |
"data_urls = []\n", | |
"for file in vid_files:\n", | |
" with open(file, 'rb') as f:\n", | |
" data_urls.append(\"data:video/mp4;base64,\" + b64encode(f.read()).decode())\n", | |
"if data_urls:\n", | |
" display(HTML((''.join(\"\"\"\n", | |
" <video width=600 controls autoplay loop>\n", | |
" <source src=\"%s\" type=\"video/mp4\">\n", | |
" </video>\"\"\"%data_url for data_url in data_urls))))\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
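{
"cell_type": "markdown",
"metadata": {
"id": "toy_siren_note"
},
"source": [
"The cell below is a minimal, self-contained sketch of the fitting idea used above: a small sine-activated MLP with SIREN-style initialization is trained to map (y,x) coordinates to pixel values. All names here (ToySine, toy_coords, the synthetic target) are illustrative and independent of the main cell, and it runs on CPU."
]
},
{
"cell_type": "code",
"metadata": {
"id": "toy_siren_code"
},
"source": [
"#@title Toy SIREN sketch (illustrative only)\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"\n",
"torch.manual_seed(0)\n",
"h = w = 32\n",
"ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w))\n",
"toy_coords = torch.stack((ys, xs), dim=-1).reshape(-1, 2)\n",
"toy_target = ((torch.sin(3*ys)*torch.cos(5*xs)+1)/2).reshape(-1, 1)  # synthetic image in [0,1]\n",
"\n",
"class ToySine(nn.Module):\n",
"    def __init__(self, hidden=64, omega_0=30.0):\n",
"        super().__init__()\n",
"        self.omega_0 = omega_0\n",
"        self.first = nn.Linear(2, hidden)\n",
"        self.second = nn.Linear(hidden, hidden)\n",
"        self.last = nn.Linear(hidden, 1)\n",
"        with torch.no_grad():  # SIREN-style init, as in SineLayer above\n",
"            self.first.weight.uniform_(-1/2, 1/2)\n",
"            bound = np.sqrt(6/hidden)/omega_0\n",
"            self.second.weight.uniform_(-bound, bound)\n",
"            self.last.weight.uniform_(-bound, bound)\n",
"\n",
"    def forward(self, coords):\n",
"        x = torch.sin(self.omega_0 * self.first(coords))\n",
"        x = torch.sin(self.omega_0 * self.second(x))\n",
"        return (self.last(x)+1)/2  # map roughly [-1,1] -> [0,1], as in Siren.forward\n",
"\n",
"toy_model = ToySine()\n",
"toy_optim = torch.optim.Adam(toy_model.parameters(), lr=1e-4)\n",
"for step in range(501):\n",
"    toy_optim.zero_grad()\n",
"    mse = ((toy_model(toy_coords) - toy_target)**2).mean()\n",
"    mse.backward()\n",
"    toy_optim.step()\n",
"    if step % 100 == 0:\n",
"        print('step %d: psnr=%0.2f' % (step, -10*torch.log10(mse)))\n"
],
"execution_count": null,
"outputs": []
},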
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"id": "ipz3NRfh9DS_"
},
"source": [
"#@title Download\n",
"from google.colab import files\n",
"!rm -f /content/output.zip\n",
"!zip -r /content/output.zip /content/output\n",
"files.download('/content/output.zip')"
],
"execution_count": null,
"outputs": []
}
]
} |