Skip to content

Instantly share code, notes, and snippets.

@metric-space
Created June 26, 2025 09:18
Show Gist options
  • Save metric-space/42f98747e91e0cec0f9a90e0d39467c0 to your computer and use it in GitHub Desktop.
Save metric-space/42f98747e91e0cec0f9a90e0d39467c0 to your computer and use it in GitHub Desktop.
#import jax
#import jax.numpy as jnp
import matplotlib.pyplot as plt
#import jax.numpy.linalg as linalg
import numpy as jnp
# Gaussian RBF interpolation demo: fit weights w so that the interpolant
# s(x) = sum_j w_j * exp(-(eta * |x - x_j|)**2) reproduces `function`
# exactly at the nodes x_j stored in `range_`.
eta = 3  # shape parameter: larger eta -> narrower basis functions


def rbf_(r):
    """Gaussian radial basis function phi(r) = exp(-(eta * r)**2)."""
    return jnp.exp(-(eta * r) ** 2.0)


def function(x):
    """Target function f(x) = exp(x * cos(3*pi*x)) that is being interpolated."""
    return jnp.exp(x * jnp.cos(x * 3 * jnp.pi))


n_centers = 14  # number of interpolation nodes / RBF centres
range_ = jnp.linspace(0.0, 1.0, n_centers)  # equispaced nodes on [0, 1]
function_values = function(range_)
# Collocation matrix A[i, j] = phi(|x_i - x_j|); symmetric by construction.
rbf = rbf_(jnp.abs(range_[:, None] - range_[None, :]))
# Solve A w = f for the weights, kept as an (n_centers, 1) column vector.
weights = jnp.linalg.solve(rbf, function_values[:, None])
def interpolate(x_scalar, centers=None, coeffs=None, kernel=None):
    """Evaluate the RBF interpolant s(x) = sum_j w_j * phi(|x - c_j|).

    Parameters
    ----------
    x_scalar : float or array_like
        Evaluation point(s). Despite the name, arrays are accepted and are
        evaluated elementwise via broadcasting; a scalar input gives a scalar.
    centers : array_like of shape (n,), optional
        RBF centres c_j. Defaults to the module-level ``range_``.
    coeffs : array_like of shape (n,) or (n, 1), optional
        Expansion weights w_j. Defaults to the module-level ``weights``.
    kernel : callable, optional
        Radial function phi(r). Defaults to the module-level ``rbf_``.

    Returns
    -------
    The interpolant value(s), with the same shape as ``x_scalar``.
    """
    # Resolve defaults at call time so the module-level fit can be changed later.
    if centers is None:
        centers = range_
    if coeffs is None:
        coeffs = weights
    if kernel is None:
        kernel = rbf_
    centers = jnp.asarray(centers)
    # Accept either a flat (n,) vector or an (n, 1) column such as `weights`.
    coeffs = jnp.asarray(coeffs).reshape(-1)
    # Add a trailing axis that indexes the centres: (...,) -> (..., n).
    diffs = jnp.asarray(x_scalar)[..., None] - centers
    return (kernel(jnp.abs(diffs)) * coeffs).sum(axis=-1)
# Compare the true function with the RBF interpolant on a fine grid.
range2_ = jnp.linspace(0, 1, 100)
actual = function(range2_)
# numpy's vectorize evaluates `interpolate` point by point over the grid.
interpolated = jnp.vectorize(interpolate)(range2_)
fig, ax = plt.subplots()
# Draw on the Axes we created instead of the implicit pyplot state machine.
ax.plot(range2_, actual, label="f(x)")
ax.plot(range2_, interpolated, label="RBF interpolant")
ax.plot(range_, function_values, "o", label="nodes")
ax.set_xlabel("x")
ax.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment