Skip to content

Instantly share code, notes, and snippets.

@edxmorgan
Created December 6, 2024 01:37
Show Gist options
  • Save edxmorgan/8d08a5403fce6e0198fa7ff134aeb1b9 to your computer and use it in GitHub Desktop.
Save edxmorgan/8d08a5403fce6e0198fa7ff134aeb1b9 to your computer and use it in GitHub Desktop.
Bayesian Parameter updating
import numpy as np
import matplotlib.pyplot as plt
class BayesianParameterEstimator:
    """
    Grid-based Bayesian estimator for the mean of a normal observation model.

    The parameter theta is discretized on an evenly spaced grid. The current
    belief is stored in ``self.prior`` as density values on that grid; it
    starts out uniform and is replaced by the posterior on every call to
    :meth:`update`.
    """

    def __init__(self, param_min=0.0, param_max=5.0, num_points=500, known_variance=1.0):
        """
        Initialize the Bayesian estimator.

        :param param_min: Minimum value of the parameter space
        :param param_max: Maximum value of the parameter space
        :param num_points: Number of discrete points in the parameter grid
        :param known_variance: Known variance of the observation model
        :raises ValueError: if the parameter range is empty, the grid has
            fewer than two points, or the variance is not positive
        """
        # Written as "not a > b" so that NaN arguments are rejected as well.
        if not param_max > param_min:
            raise ValueError("param_max must be greater than param_min")
        if num_points < 2:
            # A single grid point has zero width, so nothing can be integrated.
            raise ValueError("num_points must be at least 2")
        if not known_variance > 0:
            raise ValueError("known_variance must be positive")

        self.param_min = param_min
        self.param_max = param_max
        self.num_points = num_points
        self.known_variance = known_variance
        self.grid = np.linspace(self.param_min, self.param_max, self.num_points)
        # Initialize prior as uniform: constant density integrating to 1.
        self.prior = np.ones(self.num_points) / (self.param_max - self.param_min)

    @staticmethod
    def _integrate(values, grid):
        """Trapezoidal integral of ``values`` over ``grid``."""
        # np.trapz was renamed to np.trapezoid in NumPy 2.0; prefer the new
        # name and fall back to the old one on older installations.
        trapezoid = getattr(np, "trapezoid", None)
        if trapezoid is None:
            trapezoid = np.trapz
        return trapezoid(values, grid)

    def _log_space_unnormalized_posterior(self, data_point):
        """Unnormalized posterior computed in log space, rescaled so its peak is 1."""
        # The constant factor of the normal density is dropped: it cancels
        # when the posterior is normalized.
        log_likelihood = -0.5 * (data_point - self.grid) ** 2 / self.known_variance
        # log(0) = -inf is expected for grid points carrying no prior mass.
        with np.errstate(divide="ignore", invalid="ignore"):
            log_posterior = np.log(self.prior) + log_likelihood
            return np.exp(log_posterior - np.max(log_posterior))

    def likelihood(self, data_point, theta_values):
        """
        Compute the likelihood of observing `data_point` given theta.

        Assume a normal likelihood: x ~ N(theta, sigma^2)

        :param data_point: Observed value x
        :param theta_values: Candidate value(s) of theta (scalar or array)
        :return: Normal density of `data_point` for each candidate theta
        """
        sigma = np.sqrt(self.known_variance)
        normalization = 1 / (np.sqrt(2 * np.pi) * sigma)
        exponent = -0.5 * ((data_point - theta_values) ** 2) / self.known_variance
        return normalization * np.exp(exponent)

    def update(self, new_data_point):
        """
        Update the posterior distribution given a new data point.

        The new posterior is stored in ``self.prior`` so that it serves as
        the prior for the next observation.

        :param new_data_point: Newly observed value
        :raises ValueError: if the posterior cannot be normalized (for
            example a non-finite observation); the stored belief is left
            unchanged in that case
        """
        # Posterior (unnormalized) = Prior * Likelihood on every grid point.
        posterior_unnormalized = self.prior * self.likelihood(new_data_point, self.grid)
        evidence = self._integrate(posterior_unnormalized, self.grid)

        if not (np.isfinite(evidence) and evidence > 0.0):
            # An observation far away from the grid makes every likelihood
            # value underflow to 0, which would turn the posterior into NaN.
            # Redo the product in log space, where only relative sizes matter.
            posterior_unnormalized = self._log_space_unnormalized_posterior(new_data_point)
            evidence = self._integrate(posterior_unnormalized, self.grid)
            if not (np.isfinite(evidence) and evidence > 0.0):
                raise ValueError(
                    "posterior cannot be normalized for data point "
                    f"{new_data_point!r}"
                )

        # Update the prior to the new (normalized) posterior.
        self.prior = posterior_unnormalized / evidence

    def get_posterior_stats(self):
        """
        Compute posterior mean and 95% credible interval from the posterior.

        :return: Tuple ``(mean_estimate, (lower, upper))`` where the interval
            bounds are the grid values at the 2.5% and 97.5% points of the
            discrete cumulative distribution
        """
        # Discrete cumulative distribution, rescaled to end at exactly 1.
        cdf = np.cumsum(self.prior)
        cdf /= cdf[-1]
        mean_estimate = self._integrate(self.grid * self.prior, self.grid)
        # argmax on a boolean array returns the first index that is True,
        # i.e. the first grid point at which the CDF reaches the quantile.
        lower_idx = np.argmax(cdf >= 0.025)
        upper_idx = np.argmax(cdf >= 0.975)
        credible_interval = (self.grid[lower_idx], self.grid[upper_idx])
        return mean_estimate, credible_interval

    def plot_posterior(self):
        """
        Plot the current posterior distribution.
        """
        plt.figure(figsize=(8, 5))
        plt.plot(self.grid, self.prior, label='Posterior')
        plt.xlabel('Parameter θ')
        plt.ylabel('Density')
        plt.title('Posterior Distribution')
        plt.grid(True)
        plt.legend()
        plt.show()
# Example usage:
if __name__ == "__main__":
    # The true parameter value; the estimator only ever sees noisy samples of it.
    true_theta = 2.5

    # Uniform prior over [0, 5] with an observation variance of 1.0.
    estimator = BayesianParameterEstimator(
        param_min=0,
        param_max=5,
        num_points=500,
        known_variance=1.0,
    )

    # Fixed seed so the simulated stream of 10 observations is reproducible.
    np.random.seed(42)
    data_stream = np.random.normal(loc=true_theta, scale=1.0, size=10)

    # Feed the observations in one at a time and report after every update.
    for i, d in enumerate(data_stream, 1):
        estimator.update(d)
        mean_estimate, ci = estimator.get_posterior_stats()
        print(
            f"After {i} data points:",
            f" Observed data point: {d:.2f}",
            f" Posterior mean estimate: {mean_estimate:.3f}",
            f" 95% credible interval: ({ci[0]:.3f}, {ci[1]:.3f})",
            sep="\n",
        )
        # Uncomment below to see plots after each update:
        # estimator.plot_posterior()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment