Skip to content

Instantly share code, notes, and snippets.

@yun-long
Created July 9, 2019 08:08
Show Gist options
  • Save yun-long/712ca38abf618c130a89ea8fbfaf8d38 to your computer and use it in GitHub Desktop.
Save yun-long/712ca38abf618c130a89ea8fbfaf8d38 to your computer and use it in GitHub Desktop.
A simple Kalman Filter example
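"""
Just a simple example of the Kalman Filter for scalar state estimation.
Predict:
x_i = a * x_{i-1}
p_i = a * p_{i-1} * a
Update:
g_i = p_i / (p_i + r)
x_i <-- x_i + g_i * (z_i - x_i)
p_i <-- (1 - g_i) * p_i
Where:
x_i : estimated state
z_i : observed state
p_i : variance of the estimation error
g_i : Kalman gain
a   : state-transition coefficient
r   : observation-noise variance assumed by the filter (the demo in main()
      also uses it as the half-width of the uniform measurement noise)
Process noise is neglected in this example.
"""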
import numpy as np
class KalmanFilter(object):
    """A minimal scalar Kalman filter.

    The filter tracks a single scalar state estimate and the variance ``p``
    of its estimation error.

    Attributes:
        p: current estimation-error variance (starts at the caller's guess).
        r: default observation-noise variance used to weigh measurements.
        q: process-noise variance added on every prediction step; the
            default of 0 reproduces the model ``p_i = a * p_{i-1} * a``.
    """

    def __init__(self, p, r, q=0):
        self.p = p  # estimation-error variance (initial value)
        self.r = r  # observation-noise variance
        self.q = q  # process-noise variance; 0 keeps the noise-free model

    def _gain_for(self, r):
        """Return the Kalman gain ``p / (p + r)`` for observation variance ``r``."""
        return self.p / (self.p + r)

    @property
    def gain(self):
        """Kalman gain for the current error variance and ``self.r``."""
        return self._gain_for(self.r)

    def predict(self, x, a):
        """Prediction step: propagate the estimate through ``x_i = a * x_{i-1}``.

        params:
        ----------
        x : previous estimated state
        a : state-transition coefficient

        Returns the predicted (prior) state ``a * x``. Side effect: the error
        variance becomes ``a * p * a + q``.
        """
        hat_x = a * x
        self.p = a * self.p * a + self.q
        return hat_x

    def update(self, hat_x, z, r=None):
        """Update step: blend the predicted state with an observation.

        params:
        ---------
        hat_x : predicted state at the current time
        z : observed state at the current time
        r : observation-noise variance for this measurement; when omitted
            (or None) the filter's own ``self.r`` is used

        Returns the corrected (posterior) state. Side effect: the error
        variance is scaled by ``(1 - g)``.
        """
        if r is None:
            r = self.r
        g = self._gain_for(r)
        # Equivalent to hat_x + g * (z - hat_x): a convex blend of the two.
        hat_x_ = hat_x * (1 - g) + g * z
        self.p = (1 - g) * self.p
        return hat_x_
def _simulate_kalman(a, r, p0, x0=1000., steps=10):
    """Run the scalar Kalman filter on a synthetic decaying signal.

    The true state is ``x0 * a**i``; each observation adds noise drawn
    uniformly from ``[-r, r]``.

    Returns:
        Tuple ``(X, Z, X_hat, P, G)``: true states, observations, estimates,
        error variance held by the filter after each step, and the gain the
        filter reports at each step.
    """
    X = x0 * np.array([a**i for i in range(steps)])  # real states
    X_hat, Z = [], []
    P, G = [], []
    kf = KalmanFilter(p0, r)
    for i, x in enumerate(X):
        # Scalar draw (no size argument) so observations and estimates are
        # plain floats rather than 1-element arrays.
        z = x + np.random.uniform(-kf.r, kf.r)
        Z.append(z)
        if i == 0:
            # No prior estimate yet: the first observation is taken as-is.
            X_hat.append(z)
            P.append(kf.p)
            G.append(kf.gain)
            continue
        prior = kf.predict(X_hat[i - 1], a)
        # Record the gain after predict(): this is the gain update() applies.
        G.append(kf.gain)
        X_hat.append(kf.update(prior, z, r))
        # Record the posterior error variance for this time step.
        P.append(kf.p)
    return X, Z, X_hat, P, G


def _plot_kalman(X, Z, X_hat, P, G):
    """Plot states, error variance and gain in three stacked panels."""
    import matplotlib.pyplot as plt
    _, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    t = np.arange(len(X))
    # plot states
    axes[0].plot(t, X, label='Real states')
    axes[0].plot(t, Z, label='Observed states')
    axes[0].plot(t, X_hat, label='Estimated states')
    axes[0].legend()
    # prediction error (variance)
    axes[1].bar(t, P, 0.35, label='Prediction Error')
    axes[1].legend()
    # gain
    axes[2].bar(t, G, 0.35, label='Gain')
    axes[2].set_xlabel('Time steps')
    axes[2].legend()
    # Call plt.savefig('kalman.pdf') here to write the figure to disk instead.
    plt.show()


def main():
    """Simulate a decaying signal, filter noisy observations of it, and plot."""
    # system parameters
    a = 0.75  # state-transition coefficient, e.g., x_i = 0.75 * x_{i-1}
    # Observation-noise parameter: half-width of the uniform measurement
    # noise, also passed to the filter as its observation-noise variance.
    r = 200.
    # Initial error variance. With r = 200 this is small, so the initial gain
    # p0 / (p0 + r) is about 0.005: the filter trusts its model far more than
    # the measurements from the start.
    p0 = 1.
    _plot_kalman(*_simulate_kalman(a, r, p0))
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment