Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Last active March 4, 2023 06:05
Show Gist options
  • Select an option

  • Save norabelrose/21186b1c72cc63e6e58f0c7d211979e3 to your computer and use it in GitHub Desktop.

Select an option

Save norabelrose/21186b1c72cc63e6e58f0c7d211979e3 to your computer and use it in GitHub Desktop.
from itertools import product
from scipy.optimize import curve_fit
from typing import NamedTuple, Sequence
import numpy as np
class Break(NamedTuple):
    """Parameters of a single break term of a ``BNSL`` curve.

    ``BNSL.__call__`` multiplies its running value by
    ``(1 + (x / d) ** (1 / f)) ** (-c * f)`` for every ``Break`` it holds.
    """

    c: float  # scales the outer exponent of the factor: -c * f
    d: float  # divides x inside the factor; kept as log(d) in flat parameter vectors
    f: float  # 1 / f is the inner exponent; f also scales the outer exponent
class BNSL(NamedTuple):
    """Smoothly broken power-law curve, fitted in log space.

    ``__call__`` evaluates::

        y(x) = a + b * x ** -c0 * prod_i (1 + (x / d_i) ** (1 / f_i)) ** (-c_i * f_i)

    where ``(c_i, d_i, f_i)`` are the entries of ``breaks``.  The flat
    parameter vector shared by ``fit``, ``from_params`` and ``to_params`` is
    ``(a, log(b), c0, c_1, log(d_1), f_1, c_2, log(d_2), f_2, ...)``: ``b`` and
    every ``d_i`` are stored as logarithms there.
    """

    a: float  # additive offset of the curve
    b: float  # scale of the leading power-law term
    c0: float  # exponent of the leading term, applied as x ** -c0
    # Quoted forward reference: creating this class does not need ``Break``
    # to be resolved; it is only looked up when ``from_params`` builds breaks.
    breaks: Sequence["Break"]

    @classmethod
    def fit(cls, x, y, num_breaks: int = 1):
        """Fit a curve with ``num_breaks`` break terms to positive data.

        A grid of starting points is scored with ``loss``; the best one seeds
        ``scipy.optimize.curve_fit``, which fits ``log`` of the prediction to
        ``log(y)`` under box bounds.  The grid has ``250 * 75 ** num_breaks``
        candidates, so it grows quickly with ``num_breaks``.

        Args:
            x: Array-like of strictly positive inputs.
            y: Array-like of strictly positive targets.
            num_breaks: Number of break terms in the fitted curve.

        Returns:
            The fitted ``BNSL`` instance.

        Raises:
            ValueError: If ``x`` or ``y`` is empty or holds a value that is
                not strictly positive (NaN included).
        """
        # Coerce first so lists/tuples work and ``.min()`` / ``.max()`` exist.
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        # Explicit raises instead of ``assert``: asserts vanish under ``-O``.
        if x.size == 0 or y.size == 0:
            raise ValueError("x and y must not be empty")
        if not (np.all(x > 0) and np.all(y > 0)):
            raise ValueError("x and y must be strictly positive")

        # Interior quartiles of x, in log space, seed the break locations.
        x_quantiles = np.log(np.quantile(x, np.linspace(0, 1, 5)[1:-1]))
        y_quantiles = np.quantile(y, (0.0, 0.25, 0.5, 0.75, 1.0))

        # Test grid of initializations
        exp_grid = np.linspace(0.1, 0.99, 5)
        log_grid = np.linspace(1, 10, 10)
        break_grid = (exp_grid, x_quantiles, exp_grid) * num_breaks
        # ``min`` keeps the first candidate that attains the lowest loss.
        best_p = min(
            product(y_quantiles, log_grid, exp_grid, *break_grid),
            key=lambda p: cls.from_params(p).loss(x, y),
        )

        def fn(x, *p):
            # Model in log space, matching the ``np.log(y)`` targets below.
            return np.log(cls.from_params(p)(x))

        # Per break: c_i in [0, 1], log(d_i) within the data range, f_i >= 0.
        break_lb = [0, np.log(x.min()), 0] * num_breaks
        break_ub = [1, np.log(x.max()), np.inf] * num_breaks
        p_star, *_ = curve_fit(
            fn, x, np.log(y), best_p,
            bounds=(
                # a and log(b) are unbounded; c0 is confined to [0, 1].
                np.array([-np.inf, -np.inf, 0] + break_lb),
                np.array([np.inf, np.inf, 1] + break_ub),
            ),
            maxfev=None,
        )
        return cls.from_params(p_star)

    @classmethod
    def from_params(cls, params):
        """Build an instance from a flat parameter vector.

        Args:
            params: Iterable laid out as
                ``(a, log(b), c0, c_1, log(d_1), f_1, ...)``.

        Returns:
            A ``BNSL`` with ``b`` and each ``d_i`` exponentiated.

        Raises:
            ValueError: If the values after ``c0`` do not come in groups of 3.
        """
        a, log_b, c0, *break_params = params
        triples = (
            break_params[i:i + 3] for i in range(0, len(break_params), 3)
        )
        breaks = [
            Break(c_i, np.exp(log_d_i), f_i) for c_i, log_d_i, f_i in triples
        ]
        return cls(a, np.exp(log_b), c0, breaks)

    def to_params(self):
        """Return the flat parameter tuple; the inverse of ``from_params``."""
        break_params = []
        # Positional unpacking (as in ``__call__``) accepts any 3-tuple break.
        for c_i, d_i, f_i in self.breaks:
            break_params.extend([c_i, np.log(d_i), f_i])
        return (self.a, np.log(self.b), self.c0, *break_params)

    def __call__(self, x):
        """Evaluate the curve at ``x``.

        Args:
            x: Scalar, array, or list/tuple of inputs.  Lists and tuples are
                converted to float arrays; other inputs are used as given.

        Returns:
            The curve values, with the shape that arithmetic on ``x`` yields.
        """
        if isinstance(x, (list, tuple)):
            x = np.asarray(x, dtype=float)
        y = self.b * x ** -self.c0
        for c_i, d_i, f_i in self.breaks:
            # Rebind rather than ``*=`` so no intermediate is mutated in place.
            y = y * (1.0 + (x / d_i) ** (1.0 / f_i)) ** (-c_i * f_i)
        return self.a + y

    def loss(self, x, y):
        """Mean squared log error"""
        log_diff = np.log(self(x)) - np.log(y)
        return np.mean(log_diff ** 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment