Skip to content

Instantly share code, notes, and snippets.

@going-digital
Created August 30, 2021 17:56
Show Gist options
  • Save going-digital/67e7db8e86319e19246ebe00248ac971 to your computer and use it in GitHub Desktop.
Save going-digital/67e7db8e86319e19246ebe00248ac971 to your computer and use it in GitHub Desktop.
RBF Bunny.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "RBF Bunny.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyO0n+qM2MLLG73JsiG9Zf7B",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/going-digital/67e7db8e86319e19246ebe00248ac971/rbf-bunny.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XEsT9va9BPbU"
},
"source": [
"Mesh to Shadertoy using RBFs\n",
"\n",
"See this working: https://www.shadertoy.com/view/7dtGR2\n",
"\n",
"Original concept by Blackle Mori https://www.shadertoy.com/view/wtVyWK\n",
"\n",
"Blackle used a SIne REpresentation Network (SIREN).\n",
"I've since found that for small network sizes, Radial Basis Functions seem to\n",
"work better, as they are just a linear combination of simple primitives. In\n",
"this case I've chosen the simplest of all - the SDF of a sphere.\n",
"\n",
"This seems to reproduce fine detail (bunny ears!) better with smaller model sizes."
]
},
{
"cell_type": "code",
"metadata": {
"id": "eV7h2-cFBceA"
},
"source": [
"# -*- coding: utf-8 -*-\n",
"!wget https://github.com/going-digital/ml_sdf/raw/main/bunny2.obj\n",
"!pip install mesh_to_sdf\n",
"import numpy as np\n",
"from mesh_to_sdf import get_surface_point_cloud\n",
"from mesh_to_sdf.utils import sample_uniform_points_in_unit_sphere\n",
"import trimesh\n",
"import re\n",
"import tensorflow as tf\n",
"import tensorflow.keras.backend as K"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NINdQgjEBgHf"
},
"source": [
"# Total number of SDF training samples. SDFFitting draws roughly half of them\n",
"# near the mesh surface and the rest uniformly inside the unit sphere.\n",
"model_samples = 4 * 256 * 256\n",
"# Number of radial basis functions; each becomes one S(...) term in the GLSL.\n",
"# Hand tuned to hit 2048 bytes of shader source: a quality/size tradeoff.\n",
"number_of_rbfs = 78"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OUQf4VJsAgFF"
},
"source": [
"seed = 1234\n",
"# Seed numpy's global RNG; assumed to make the mesh_to_sdf sampling below reproducible.\n",
"np.random.seed(seed)\n",
"\n",
"def SDFFitting(filename, samples):\n",
"    \"\"\"Sample the signed distance field of the mesh stored in `filename`.\n",
"\n",
"    `samples` points are drawn in total: samples // 2 near the mesh surface and\n",
"    the remainder uniformly inside the unit sphere.\n",
"\n",
"    Returns (coords, sdf): the sample positions and their signed distances,\n",
"    with all near-surface samples first.\n",
"    \"\"\"\n",
"    n_surface = int(samples // 2)\n",
"    n_volume = samples - n_surface\n",
"    point_cloud = get_surface_point_cloud(\n",
"        trimesh.load(filename), surface_point_method='sample')\n",
"    near_coords, near_sdf = point_cloud.sample_sdf_near_surface(\n",
"        n_surface, use_scans=False, sign_method='normal')\n",
"    far_coords = sample_uniform_points_in_unit_sphere(n_volume)\n",
"    far_sdf = point_cloud.get_sdf_in_batches(far_coords, use_depth_buffer=False)\n",
"    return (np.concatenate([near_coords, far_coords]),\n",
"            np.concatenate([near_sdf, far_sdf]))\n",
"#%%\n",
"# The SDF is sampled in the mesh's own coordinate frame: size and position are\n",
"# NOT normalised, and the uniform samples only cover the unit sphere. So the\n",
"# source model should be of unit size and centred at the origin. For best RBF\n",
"# results it should also be simple, fully solid and watertight.\n",
"#\n",
"# bunny2.obj is the Stanford Bunny, as modified by Blackle.\n",
"coord, target_sdf = SDFFitting(filename=\"bunny2.obj\", samples=model_samples)\n",
"#%%\n",
"class RBFLayer(tf.keras.layers.Layer):\n",
"    \"\"\"Radial basis layer: distance from each input point to learned centres.\n",
"\n",
"    For inputs of shape (batch, dims) the output has shape (batch, units), with\n",
"    output[i, j] = ||inputs[i] - centres[j]||, i.e. the SDF of a zero-radius\n",
"    sphere at centre j. A following Dense layer takes a weighted sum of these\n",
"    terms plus a bias; any sphere radii are absorbed into that bias.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, units, **kwargs):\n",
"        # Initialise the Keras base class before setting any attributes.\n",
"        super().__init__(**kwargs)\n",
"        self.units = units\n",
"        # Keras' default RandomUniform range is [-0.05, 0.05].\n",
"        self.initializer = tf.keras.initializers.RandomUniform()\n",
"\n",
"    def build(self, input_shape):\n",
"        # One trainable centre per unit, in the input's coordinate space.\n",
"        self.centres = self.add_weight(\n",
"            name='centres',\n",
"            shape=(self.units, input_shape[1]),\n",
"            initializer=self.initializer,\n",
"            trainable=True,\n",
"        )\n",
"        super().build(input_shape)\n",
"\n",
"    def call(self, inputs):\n",
"        # SDF of a sphere, written in Tensorflow:\n",
"        # (batch, 1, dims) - (1, units, dims) -> (batch, units, dims), then the\n",
"        # Euclidean norm over the last axis gives (batch, units) distances.\n",
"        diff = tf.expand_dims(inputs, 1) - tf.expand_dims(self.centres, 0)\n",
"        return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))\n",
"\n",
"    def compute_output_shape(self, input_shape):\n",
"        return (input_shape[0], self.units)\n",
"\n",
"    def get_config(self):\n",
"        # Lets Keras re-create the layer (e.g. when loading a saved model).\n",
"        config = super().get_config()\n",
"        config.update({'units': self.units})\n",
"        return config\n",
"\n",
"#%%\n",
"tf.random.set_seed(123)\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.InputLayer(input_shape=(3,)),\n",
"    RBFLayer(number_of_rbfs),\n",
"    tf.keras.layers.Dense(units=1, activation='linear'),\n",
"])\n",
"\n",
"model.compile(\n",
"    # Learning rate below is tweakable. Increase to speed up learning.\n",
"    # Decrease to improve learning stability.\n",
"    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),\n",
"    # NOTE(review): Keras' MSLE clips values below epsilon before the log, so all\n",
"    # negative SDF targets (presumably points inside the mesh) look alike to it.\n",
"    loss=tf.keras.losses.MeanSquaredLogarithmicError(reduction=tf.keras.losses.Reduction.NONE),\n",
"    metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()],\n",
")\n",
"\n",
"# SDFFitting returns every near-surface sample before every volume sample, so\n",
"# batching in that order would give batches holding only one kind of sample\n",
"# (shuffling the batched dataset later only reorders whole batches).\n",
"# Permute the samples once, using the seeded numpy RNG, before batching.\n",
"sample_order = np.random.permutation(len(coord))\n",
"dataset = tf.data.Dataset.from_tensor_slices(\n",
"    (coord[sample_order], target_sdf[sample_order])).batch(256).cache()\n",
"#%%\n",
"# Convert network to a compact shadertoy representation\n",
"rbf_glsl = \"S({0:.3f},{1:.0f},{2:.0f},{3:.0f})\"\n",
"\n",
"def model_to_shadertoy(model):\n",
"    \"\"\"Render a trained RBF model as compact GLSL source for Shadertoy.\n",
"\n",
"    Expects model.weights to be [RBF centres (units, 3), Dense kernel (units, 1),\n",
"    Dense bias (1,)], i.e. the Sequential RBF model built in this notebook.\n",
"    Each RBF becomes one S(weight, ...) term. Centres are scaled by 100 and\n",
"    rounded to integers, and the S macro scales them back by .01.\n",
"\n",
"    Returns the GLSL text for the S macro and `float scene(vec3 p)`.\n",
"    \"\"\"\n",
"    # Fetch each weight tensor once rather than slicing it element by element.\n",
"    centres = model.weights[0].numpy()\n",
"    kernel = model.weights[1].numpy()\n",
"    bias = model.weights[2].numpy()\n",
"    # Centres are emitted in (z, x, y) order, presumably to match the shader's\n",
"    # axis convention; keep this in sync with the S macro below.\n",
"    radials = [\n",
"        rbf_glsl.format(kernel[i][0], 100*c[2], 100*c[0], 100*c[1])\n",
"        for i, c in enumerate(centres)\n",
"    ]\n",
"    output = \"\".join([\n",
"        \"#define S(a,b,c,d) a*length(p-.01*vec3(b,c,d))\\n\",\n",
"        \"float scene(vec3 p){\\n\",\n",
"        \" return {:0.3f}\\n + \".format(bias[0]),\n",
"        \"\\n + \".join(radials) + \";\\n}\\n\",\n",
"    ])\n",
"\n",
"    # Some easily automated code size tweaks.\n",
"    # Lazy \\d*? so ALL trailing zeros go, eg. 1.000 => 1. and 0.500 => 0.5\n",
"    output = re.sub(r\"(\\d+\\.\\d*?)0+\\b\", r\"\\1\", output)\n",
"    output = re.sub(r\"\\b(\\.\\d+)0+\\b\", r\"\\1\", output)  # Remove trailing zeros eg. .60 => .6\n",
"    output = re.sub(r\"\\b0(\\.\\d+)\\b\", r\"\\1\", output)  # Remove leading zeros eg. 0.5 => .5\n",
"    output = re.sub(r\"-\\.0+\\b\", r\".0\", output)  # Make all zeros positive eg. -.0 => .0\n",
"    output = re.sub(r\"-(0\\.?)(?![\\d.])\", r\"\\1\", output)  # ... and -0. => 0., -0 => 0\n",
"    output = re.sub(r\"\\+-\", r\"-\", output)  # Change +-1. into -1.\n",
"    output = re.sub(r\"\\+ S\\(-\", r\"- S(\", output)  # Express sign of weight more compactly.\n",
"    return output\n",
"#%%\n",
"# Train, then print the GLSL for pasting into Shadertoy.\n",
"# The dataset is already batched, so shuffle() here reorders whole batches each\n",
"# epoch. (Model.fit's own `shuffle` argument is ignored for tf.data datasets.)\n",
"# To checkpoint the best weights, pass e.g. callbacks=[tf.keras.callbacks.ModelCheckpoint(\n",
"#     filepath='cp', monitor='loss', save_weights_only=True, save_best_only=True)].\n",
"model.fit(\n",
"    dataset.shuffle(model_samples),\n",
"    epochs=1000,\n",
")\n",
"print(model_to_shadertoy(model))"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment