Last active October 29, 2024 01:55
-
-
Save GaryLee/818dc104eded5ce9d81ad842b8ed4fb3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!python | |
# coding: utf-8 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
dt = 0.1  # The time interval of sampling, in seconds.

# Make input data: a ramp rising by dt per sample (i.e. 1 unit per second)
# plus unit-variance white noise.
z = np.arange(1, 100, dt)
noise = np.random.randn(np.size(z))
z += noise

# Process-noise level of the model. The smaller the number, the more the
# filter trusts the model over the observations.
model_accuracy = 0.0001

# Define states and related matrices as plain 2-D ndarrays (NumPy no longer
# recommends np.matrix; `@` gives the same matrix products here).
# We don't use input control. Therefore, the u vector and B matrix are both zero.
x = np.array([[0.0], [0.0]])  # State vector (2x1): [position, velocity].
u = np.array([[0.0], [0.0]])  # Input vector.
B = np.zeros((2, 2))  # Control matrix. Assume no control input.
P = np.eye(2)  # State covariance matrix.
# Constant-velocity state transition: position advances by velocity * dt each
# step, so the velocity state is expressed per second rather than per sample.
F = np.array([[1.0, dt], [0.0, 1.0]])
Q = model_accuracy * np.eye(2)  # Process noise covariance.
H = np.array([[1.0, 0.0]])  # Observation matrix (1x2): only position is observed.
R = 1  # Observation noise variance (matches the unit-variance randn noise above).
I = np.eye(2)  # Identity matrix.

# Create data log arrays for plotting.
position = np.zeros_like(z)
velocity = np.zeros_like(z)
time_serial = np.zeros_like(z)
# Loop for data processing: one Kalman predict/update cycle per observation.
for k, data in enumerate(z):
    # Predict: propagate the state and its covariance through the model.
    x_ = F @ x + B @ u  # A-priori state estimate.
    P_ = (F @ P @ F.T) + Q  # A-priori state covariance.
    # Update: blend the prediction with the new observation.
    # Innovation (residual) covariance; 1x1 because H observes a single value.
    S = (H @ P_ @ H.T) + R
    # Kalman gain; dividing by the 1x1 S is the same as multiplying by S^-1.
    K = (P_ @ H.T) / S
    x = x_ + K @ (data - H @ x_)
    P = (I - (K @ H)) @ P_
    # Log data. Explicit [row, 0] indexing yields true scalars: assigning a
    # 1x1 array/matrix into a float slot is deprecated since NumPy 1.25.
    position[k] = x[0, 0]
    velocity[k] = x[1, 0]
    # Store elapsed time in seconds, not the sample index, so the plot's
    # 'Time (s)' axis is correct.
    time_serial[k] = k * dt
# Mean of the logged velocity estimates (vectorized instead of a Python-level
# sum over the ndarray).
velocity_avg = np.mean(velocity)
print(f"{velocity_avg=}")

# Plot the noisy observations together with the filtered position and
# velocity estimates against time.
plt.figure(figsize=(12, 10))
plt.plot(time_serial, z, label="observation")
plt.plot(time_serial, position, label="position")
plt.plot(time_serial, velocity, label="velocity")
plt.xlabel('Time (s)')
plt.title('Simple kalman filter')
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.