# Linear Kalman filter and animation
# (GitHub gist Judithcodes/aa56349643670a0b1e845993381fadbb)
# GitHub flagged this file as containing bidirectional Unicode text; review it
# in an editor that reveals hidden Unicode characters.
# Multivariate normal (linear Gaussian) Kalman filter demo.
# library() stops with an error when a package is missing, whereas require()
# only returns FALSE and lets the script fail later with a less clear message.
library(dplyr)
library(tidyr)
library(ggplot2)
library(animation)
# Simulate N observations of an ARMA(1,1) process plus a linear trend
# (no differencing is applied, so this is ARMA, not ARIMA):
#   u[t] = phi1 * u[t-1] + e[t] + theta1 * e[t-1],  e[t] = sigma * N(0, 1)
#   y[t] = u[t] + delta * t
N <- 50
phi1 <- .5
theta1 <- .2
sigma <- 1
delta <- .5
set.seed(42)
y <- arima.sim(
  model = list(ar = phi1, ma = theta1),
  n = N,
  innov = rnorm(N) * sigma
)
# Add the deterministic linear trend.
y <- y + delta * seq_len(N)
# Quick visual check of the simulated series.
ts.plot(y)
# One step of the linear Gaussian Kalman filter: update the prior for time t
# with the observation y[t], then predict time t + 1.
#
# State-space model:
#   x[t+1] = Z x[t] + G v[t],   v[t] ~ N(0, Q)
#   y[t]   = H x[t] + w[t],     w[t] ~ N(0, R)
#
# Args:
#   Z, G, H, Q, R: system matrices of the model above.
#   y:    observation at t. If any element is NA, the measurement update is
#         skipped entirely (the filtered estimate equals the prior).
#   xhat: prior (one-step-ahead) state mean for t.
#   P:    prior state covariance for t.
# Returns a list with
#   xpost, Ppost: filtered state mean / covariance at t,
#   xpred, Ppred: prior state mean / covariance for t + 1,
#   K:            Kalman gain at t (always computed; not applied when y has NA).
Kf.filter.linear <- function(Z, G, H, Q, R, y, xhat, P) {
  y <- matrix(y, ncol = 1)
  xhat <- matrix(xhat, ncol = 1)
  # Innovation and its covariance.
  v <- y - H %*% xhat
  Vv <- H %*% P %*% t(H) + R
  # Kalman gain.
  K <- P %*% t(H) %*% solve(Vv)
  if (!anyNA(v)) {
    # Measurement update.
    xpost <- xhat + K %*% v
    Ppost <- P - K %*% Vv %*% t(K)
  } else {
    # Missing observation: keep the prior as the filtered estimate.
    xpost <- xhat
    Ppost <- P
  }
  # Time update. Process noise enters as G v[t], so its covariance
  # contribution is G Q G' (symmetric for any G).
  xpred <- Z %*% xpost
  Ppred <- Z %*% Ppost %*% t(Z) + G %*% Q %*% t(G)
  list(xpost = xpost, xpred = xpred,
       Ppost = Ppost, Ppred = Ppred,
       K = K)
}
# Reshape the observations into an N x 1 matrix (one row per time point) so
# that y[i, ] is the observation at time i.
y <- matrix(data = as.vector(y), nrow = N, ncol = 1L)
# Missing-data variant (uncomment to treat t = 20..29 as unobserved):
# y[20:29] <- NA
#### Model parameters ####
# State-space form used by the filter:
#   system: x[t+1] = A x[t] + C v[t],  v[t] ~ NID(0, Q)
#   obs:    y[t]   = B x[t] + e[t],    e[t] ~ NID(0, R)
# The values are fixed by hand here; in practice they would need to be
# estimated.
#
# Structure encoded by the matrices below (4 state components):
#   x1[t+1] = 0.5 x1[t] + x2[t] + x3[t] + v1[t]   (observed: y[t] = x1[t])
#   x2[t+1] = 0.5 v1[t]
#   x3[t+1] = x3[t] + x4[t]                        (no noise)
#   x4[t+1] = x4[t]                                (no noise)
# NOTE(review): in the usual ARMA(1,1) state-space form the loading of v1 on
# x2 is the MA coefficient (theta1 = .2 in the simulation), not .5 -- confirm
# C[2, 1].
A <- matrix(c(.5, 1, 1, 0,
              0,  0, 0, 0,
              0,  0, 1, 1,
              0,  0, 0, 1), nrow = 4, ncol = 4, byrow = TRUE)
# Only the first state component is observed.
B <- matrix(c(1, 0, 0, 0), nrow = 1, ncol = 4)
C <- matrix(c(1,  0, 0, 0,
              .5, 0, 0, 0,
              0,  0, 0, 0,
              0,  0, 0, 0), nrow = 4, ncol = 4, byrow = TRUE)
Q <- diag(1, 4)
# Zero observation noise: y[t] is treated as exactly x1[t].
R <- diag(0, 1)
# Reference sketch of the filtering loop (superseded by kalman.l() below;
# kept commented out):
# for (i in 1:N) {
#   temp <- Kf.filter.linear(Z = A, G = C, H = B,
#                            Q = Q, R = R,
#                            y = y[i, ],
#                            xhat = xpri[i, ],
#                            P = P[[i]])
#   xpost[i, ] <- temp$xpost
#   if (i < N)
#     xpri[i + 1, ] <- temp$xpred
#   P[[i + 1]] <- temp$Ppred
#   K[[i]] <- temp$K
# }
# Run the Kalman filter (Kf.filter.linear) over time points 1..N.
#
# Args:
#   A, C, B, Q, R: system matrices of
#                    x[t+1] = A x[t] + C v[t],  v[t] ~ N(0, Q)
#                    y[t]   = B x[t] + e[t],    e[t] ~ N(0, R)
#                  (passed to Kf.filter.linear as Z, G, H).
#   y:    observations, one row per time point (a plain vector is treated as
#         a single column). Rows containing NA are treated as missing.
#   xini: prior state mean at t = 1.
#   Pini: prior state covariance at t = 1.
#   N:    number of time points to process; defaults to nrow(y). Time points
#         beyond nrow(y) are treated as missing, i.e. the filter only predicts.
# Returns a list:
#   xpost: N x k matrix of filtered state means (k = nrow(A)),
#   xpri:  N x k matrix of prior (one-step-ahead) state means,
#   P:     list of N + 1 prior covariances; P[[t]] is the prior covariance at
#          t and P[[N + 1]] is the prediction for N + 1,
#   K:     list of N Kalman gains.
kalman.l <- function(A, C, B, Q, R, y, xini, Pini, N = NULL) {
  y <- as.matrix(y)
  if (is.null(N)) {
    N <- nrow(y)
  }
  stopifnot(N >= 1)
  k <- nrow(A)
  n_obs <- nrow(y)
  xpri <- matrix(0, nrow = N, ncol = k)
  xpost <- matrix(0, nrow = N, ncol = k)
  P <- vector("list", N + 1)
  K <- vector("list", N)
  # Initial prior for the state.
  xpri[1, ] <- xini
  P[[1]] <- Pini
  for (i in seq_len(N)) {
    # Past the end of the data there is no observation: predict only.
    y_i <- if (i <= n_obs) y[i, ] else rep(NA_real_, ncol(y))
    step <- Kf.filter.linear(Z = A, G = C, H = B,
                             Q = Q, R = R,
                             y = y_i,
                             xhat = xpri[i, ],
                             P = P[[i]])
    xpost[i, ] <- step$xpost
    if (i < N) {
      xpri[i + 1, ] <- step$xpred
    }
    # Stored for every i (including i == N), so P ends with N + 1 entries.
    P[[i + 1]] <- step$Ppred
    K[[i]] <- step$K
  }
  list(xpost = xpost, xpri = xpri, P = P, K = K)
}
# Run the filter on the simulated data, starting from prior mean
# (0, 0, delta, delta) and prior covariance 1.5 * I.
result <- kalman.l(A, C, B, Q, R, y,
                   c(0, 0, delta, delta),
                   diag(1.5, 4))
# Plot data: the raw observations, the time index, and the prior and
# filtered estimates of the first state component.
df <- data.frame(raw = y[, 1], t = seq_len(N))
df$xpri <- result$xpri[, 1]
df$xpost <- result$xpost[, 1]
### Plotting function ###
# Print the frames of the Kalman filter animation.
#
# For every time point `now` in 1..max(df$t) two frames are printed:
#   1. the prior (one-step-ahead) estimate at `now` (green point), then
#   2. the posterior (filtered) estimate at `now` (blue point).
# Estimates for earlier time points are drawn as points plus a blue step line
# that goes prior -> posterior at each t. While now < max(df$t), the filter is
# re-run with all observations after `now` set to NA. The resulting forecast
# of the first state component is drawn as a dashed line with a `ci` interval
# band over the shaded forecast period.
#
# Args:
#   df: data frame with columns t, raw, xpri, xpost.
#   ci: coverage probability of the forecast interval band.
# Returns:
#   A list whose element [[now]], for now < max(df$t), is the forecast data
#   frame (columns t, xpost, xlow, xup; only rows with t > now).
#
# NOTE(review): also reads the globals A, C, B, Q, R, y and delta, which must
# be defined before this is called.
drawKalman <- function(df, ci = .95) {
  t_max <- max(df$t)
  forecasts <- vector("list", t_max - 1)
  y.range <- range(y, na.rm = TRUE)
  g.raw <- ggplot(df) +
    geom_point(aes(x = t, y = raw), color = "red4", size = 1)
  for (now in seq_len(t_max)) {
    y.pred <- df$xpost[now]
    # Rows strictly before `now`. seq_len() is empty at now == 1, whereas
    # 1:(now - 1) would be c(1, 0) and select row 1.
    past <- df[seq_len(now - 1), ]
    current <- df[now, ]
    # Earlier estimates in time order: prior then posterior for each t.
    past.long <- past %>%
      dplyr::select(-raw) %>%
      pivot_longer(cols = -t, names_to = "type", values_to = "state") %>%
      arrange(t)
    if (now < t_max) {
      # Forecast from `now`: hide the later observations and re-run the filter.
      y.temp <- y
      y.temp[(now + 1):t_max, ] <- NA
      fit <- kalman.l(A, C, B, Q, R, y.temp,
                      c(0, 0, delta, delta), diag(1.5, 4))
      n <- nrow(y.temp)
      # fit$P[[t]] is the prior covariance at t. For t > now no observation
      # is used, so it is also the covariance of fit$xpost[t]. Using P[-1]
      # would pair each t with the covariance for t + 1.
      x1.sd <- vapply(fit$P[seq_len(n)], function(p) sqrt(p[1, 1]),
                      numeric(1))
      x1.mean <- fit$xpost[, 1]
      df.for <- data.frame(
        t = seq_len(n),
        xpost = x1.mean,
        xlow = qnorm((1 - ci) / 2, mean = x1.mean, sd = x1.sd),
        xup = qnorm((1 - ci) / 2, mean = x1.mean, sd = x1.sd,
                    lower.tail = FALSE)
      ) %>%
        dplyr::filter(t > now)
      forecasts[[now]] <- df.for
    }
    # TRUE = prior (one-step-ahead prediction) frame,
    # FALSE = posterior (filtering) frame.
    for (kalmanstep in c(TRUE, FALSE)) {
      g <- g.raw +
        geom_step(data = past.long, aes(x = t, y = state),
                  color = "blue", size = .6) +
        geom_point(data = past, aes(x = t, y = xpri),
                   size = 2, color = "green") +
        geom_point(data = past, aes(x = t, y = xpost),
                   size = 2, color = "blue")
      if (kalmanstep) {
        g <- g +
          geom_point(data = current, aes(x = t, y = xpri),
                     size = 2, color = "green") +
          annotate("text", x = now, y = y.pred * .8, hjust = 0,
                   label = "prior estimation")
      } else {
        g <- g +
          geom_point(data = current, aes(x = t, y = xpost),
                     size = 2, color = "blue") +
          annotate("text", x = now, y = y.pred * .8, hjust = 0,
                   label = "posterior estimation")
      }
      if (now < t_max) {
        g <- g +
          # Forecast mean.
          geom_line(data = df.for, aes(x = t, y = xpost),
                    color = "blue", linetype = "dashed", alpha = .5,
                    size = 1) +
          # Forecast interval band and its edges.
          geom_ribbon(data = df.for, aes(x = t, ymin = xlow, ymax = xup),
                      alpha = .2, fill = "blue") +
          geom_line(data = df.for, aes(x = t, y = xlow),
                    color = "black", linetype = "dashed") +
          geom_line(data = df.for, aes(x = t, y = xup),
                    color = "black", linetype = "dashed") +
          # Shade the forecast period. This layer has no `data`, so it
          # inherits df and draws one rectangle per row; the stacking is what
          # makes alpha = .02 visible.
          geom_rect(aes(xmin = now + 1, xmax = t_max,
                        ymin = -Inf, ymax = Inf),
                    fill = "grey", alpha = .02)
      }
      # Theme, titles, fixed y-range and credit.
      g <- g +
        theme_bw() +
        guides(col = guide_legend()) +
        labs(y = "y", title = "Kalman filter") +
        coord_cartesian(ylim = y.range) +
        theme(plot.title = element_text(hjust = .5)) +
        annotate("text", x = t_max, y = y.range[1], hjust = 1, vjust = 1,
                 label = "http://ill-identified.hatenablog.com/")
      print(g)
    }
  }
  forecasts
}
# Render every frame printed by drawKalman() into an animated GIF.
# NOTE(review): animation::saveGIF presumably needs an external image
# converter (ImageMagick/GraphicsMagick) installed -- confirm before running.
animation::saveGIF(
  drawKalman(df),
  movie.name = "kalman.gif",
  interval = 0.2,
  ani.width = 400,
  ani.height = 250
)
# (End of gist. GitHub page footer: sign up or sign in to comment on it.)