A simple Kalman Filter example
""" | |
Just A simple example of the Kalman Filter for State Estimation. | |
Predict: | |
x_i = a * x_{i-1} | |
p_i = a * p_{i-1} * a | |
Update: | |
g_i = p_i / (p_i + r) | |
x_i <-- x_i + g_i * (z_i - x_i) | |
p_i <-- (1 - g_i) * p_i | |
Where: | |
x_i : estimated state | |
z_i : observed state | |
p_i : prediction error | |
g_i : gain | |
r : mean of the observation noise | |
""" | |
import numpy as np


class KalmanFilter(object):
    """A simple implementation of the Kalman filter."""

    def __init__(self, p, r):
        self.p = p  # prediction error (initial value)
        self.r = r  # magnitude of the observation noise

    @property
    def gain(self):
        return self.p / (self.p + self.r)  # gain
    def predict(self, x, a):
        """
        Prediction step

        params:
        ----------
        x : previous estimated state
        a : coefficient of the system state
        """
        hat_x = a * x
        self.p = a * self.p * a
        return hat_x
    def update(self, hat_x, z):
        """
        Update step

        params:
        ---------
        hat_x : predicted state at the current time step
        z : observed state at the current time step
        """
        g = self.gain
        hat_x_ = hat_x * (1 - g) + g * z
        self.p = (1 - g) * self.p
        return hat_x_
def main():
    # system parameters
    a = 0.75  # coefficient of the system state (constant), e.g., x_i = 0.75 * x_{i-1}
    r = 200.  # magnitude of the observation noise in the environment
    X = 1000. * np.array([a**i for i in range(10)])  # real states

    # placeholders
    X_hat, Z = [], []
    P, G = [], []

    p0 = 1.  # initial prediction error; the prediction error is assumed to be large at the beginning
    kf = KalmanFilter(p0, r)
    for i, x in enumerate(X):
        P.append(kf.p)
        G.append(kf.gain)

        noise = np.random.uniform(-kf.r, kf.r)
        z = x + noise
        Z.append(z)
        if i == 0:
            # at the first time step, the estimated state
            # is simply the observed state
            X_hat.append(z)
            continue
        # predict the next state
        hat_x_ = kf.predict(X_hat[i - 1], a)
        # compute the gain, use it to trade off the prediction
        # against the observation, and update the prediction error
        hat_x = kf.update(hat_x_, z)
        X_hat.append(hat_x)

    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    t = np.arange(len(X))
    # plot states
    axes[0].plot(t, X, label='Real states')
    axes[0].plot(t, Z, label='Observed states')
    axes[0].plot(t, X_hat, label='Estimated states')
    axes[0].legend()
    # prediction error
    axes[1].bar(t, P, 0.35, label='Prediction Error')
    axes[1].legend()
    # gain
    axes[2].bar(t, G, 0.35, label='Gain')
    axes[2].set_xlabel('Time steps')
    axes[2].legend()
    # plt.savefig('kalman.pdf')
    plt.show()


if __name__ == "__main__":
    main()
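If you only want the filtering without the plotting code, a minimal sketch of driving the KalmanFilter class above directly might look like the snippet below; the observation values are made up for illustration, and the class is assumed to be defined in (or imported from) the same file.

a, r = 0.75, 200.
kf = KalmanFilter(p=1., r=r)
observations = [980., 760., 540., 430., 310.]  # hypothetical noisy readings
estimate = observations[0]  # bootstrap the estimate with the first observation
for z in observations[1:]:
    predicted = kf.predict(estimate, a)  # propagate the previous estimate
    estimate = kf.update(predicted, z)   # blend the prediction with the observation
    print(round(estimate, 2), round(kf.p, 4), round(kf.gain, 4))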