Last active
August 29, 2015 13:57
-
-
Save chiral/9774641 to your computer and use it in GitHub Desktop.
Kalman filter for ad impression and click response prediction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ad1 | ad2 | ad3 | |
---|---|---|---|
1000 | 0 | 0 | |
1000 | 0 | 0 | |
1000 | 0 | 0 | |
1000 | 0 | 0 | |
1000 | 0 | 0 | |
0 | 0 | 1000 | |
0 | 0 | 1000 | |
0 | 0 | 1000 | |
0 | 0 | 1000 | |
0 | 1000 | 0 | |
0 | 1000 | 0 | |
0 | 1000 | 0 | |
0 | 1000 | 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Unit time step (delta t).
dt <- 1

# Layout of the interest (state) vector:
#   four preference components (two categories and two brands) followed by
#   their four velocity components, eight dimensions in total:
#   c(cat1, cat2, brandA, brandB, velo1, velo2, veloA, veloB)
interest_elems <- list("cat1", "cat2", "brandA", "brandB")
dimInterestHalf <- length(interest_elems)
dimInterest <- dimInterestHalf * 2

# There are three banners; each one is assigned a vector over the
# preference components.
dimAd <- 3
ad1 <- c(1/2, 0, 1/2, 0) # ad1 = cat1 + brandA
ad2 <- c(0, 1/2, 1/2, 0) # ad2 = cat2 + brandA
ad3 <- c(0, 1/2, 0, 1/2) # ad3 = cat2 + brandB

# Matrix H (dimAd x dimInterest): maps the interest vector x to the response
# vector z. Each row is one ad's preference vector padded with zeros for the
# velocity components, so velocities do not contribute to the response.
velo0 <- rep(0, dimInterestHalf)
H <- matrix(c(ad1, velo0,
              ad2, velo0,
              ad3, velo0),
            nrow = dimAd, ncol = dimInterest, byrow = TRUE)

# Matrix R: covariance of the observation noise on z (symmetric).
R <- diag(0.5, dimAd)

# Matrix F1: time-transition matrix of the interest vector,
#   [ I  dt*I ]
#   [ 0    I  ]
# i.e. preference += dt * velocity, velocity unchanged.
F1_00 <- diag(dimInterestHalf)
F1_10 <- matrix(0, dimInterestHalf, dimInterestHalf)
F1_01 <- dt * diag(dimInterestHalf)
F1_11 <- diag(dimInterestHalf)
F1 <- cbind(rbind(F1_00, F1_10), rbind(F1_01, F1_11))

# Matrix F2: maps the ad (impression) vector u into the interest vector x.
# At this point it is dimInterestHalf x dimAd, one column per ad.
alpha <- 0.01
F2 <- alpha * matrix(c(ad1, ad2, ad3), nrow = dimInterestHalf, ncol = dimAd)

# Matrix Q: covariance of the per-unit-time noise source w_t.
Q <- diag(0.025, dimInterestHalf)

# Matrix F3: converts the noise source w_t into its contribution to the
# interest vector (dt^2 / 2 on the preferences, dt on the velocities).
F3_0 <- (1/2 * dt * dt) * diag(dimInterestHalf)
F3_1 <- dt * diag(dimInterestHalf)
F3 <- rbind(F3_0, F3_1)

# Ad impressions are treated as a force acting on the state, so they go
# through the same conversion as the noise contribution. After this line F2
# is dimInterest x dimAd.
F2 <- F3 %*% F2

# Initial conditions x0, P0. x accumulates one state row per time index and
# P one covariance matrix per time index; both start with the initial state.
x0 <- rep(0, dimInterest)
P0 <- diag(dimInterest)
x <- rbind(x0)
P <- list(P0)
# Load the data: z is read from response.csv and u from impression.csv.
# Both stay data frames; rows are indexed by time step by the filter below.
z <- read.csv("response.csv")
u <- read.csv("impression.csv")
# The filter reads row t of both tables in the same step and multiplies them
# with dimAd-column matrices, so fail early if the shapes do not line up.
stopifnot(nrow(z) == nrow(u), ncol(z) == dimAd, ncol(u) == dimAd)
# Prediction step.
# Note: because of how the R indices are laid out, the time index of x and P
# is shifted by one relative to z and u (row 1 of x / P[[1]] is the initial
# condition, so the state for data row t lives at index t + 1).
#
# Args:
#   t: row index into u (and z) of the time step to predict.
# Side effects (on the enclosing-scope variables):
#   appends the predicted state as a new row of x, and stores the predicted
#   covariance in P[[t + 1]].
kalman_predict <- function(t) {
  # Row t of u as a plain numeric vector; works whether u is a data frame
  # or a matrix.
  u_t <- as.numeric(unlist(u[t, ]))
  x_pred <- F1 %*% x[t, ] + F2 %*% u_t
  x <<- rbind(x, as.numeric(x_pred))
  P[[t + 1]] <<- F1 %*% P[[t]] %*% t(F1) + F3 %*% Q %*% t(F3)
}
# Update (correction) step.
#
# Args:
#   t: row index into z of the observation to assimilate. The matching
#      predicted state and covariance are read from x[t + 1, ] and P[[t + 1]].
# Side effects (on the enclosing-scope variables):
#   overwrites x[t + 1, ] and P[[t + 1]] with the corrected values.
kalman_update <- function(t) {
  H_t <- t(H)
  x_pred <- x[t + 1, ]
  P_pred <- P[[t + 1]]
  # Observation row as a plain numeric vector; works whether z is a data
  # frame or a matrix.
  z_obs <- as.numeric(unlist(z[t, ]))
  # Innovation: observed response minus the response predicted from x.
  innovation <- z_obs - as.numeric(H %*% x_pred)
  S <- R + H %*% P_pred %*% H_t
  K <- P_pred %*% H_t %*% solve(S)
  I <- diag(dimInterest)
  x[t + 1, ] <<- x_pred + as.numeric(K %*% innovation)
  P[[t + 1]] <<- (I - K %*% H) %*% P_pred
}
# Simulation: run one predict/update cycle per row of z and record a
# +/- 1 standard deviation band around the state.
n_steps <- nrow(z)
# Preallocate the band matrices (one row per time step, one column per state
# component) instead of growing them with rbind() inside the loop.
upper <- matrix(NA_real_, nrow = n_steps, ncol = ncol(x))
lower <- matrix(NA_real_, nrow = n_steps, ncol = ncol(x))
for (t in seq_len(n_steps)) {
  print(paste("time =", t))
  kalman_predict(t)
  # The state for data row t lives at row t + 1 of x (row 1 is the initial
  # condition), so that is the row the predict/update steps just wrote.
  print(paste(" predict_x =", x[t + 1, ]))
  kalman_update(t)
  print(paste(" update_x =", x[t + 1, ]))
  # The band is taken at index t (state and covariance as they stood before
  # this step), which keeps row t of upper/lower aligned with row t of x.
  state_sd <- sqrt(diag(P[[t]]))
  upper[t, ] <- x[t, ] + state_sd
  lower[t, ] <- x[t, ] - state_sd
}
# Plot the results on a 3 x 2 grid: the impressions u, the responses z, and
# then one panel per preference component with its upper/lower band.
par(mfrow = c(3, 2))
xlim <- c(1, nrow(z))

# Panel 1: impressions, one line per ad, on a shared y-range.
# NOTE(review): ad1 and ad2 are both drawn in colour "1", so they cannot be
# told apart -- confirm whether distinct colours were intended.
ylim <- c(min(u), max(u))
plot(u[, 1], type = "l", col = "1", ylab = "u", xlim = xlim, ylim = ylim)
lines(u[, 2], col = "1")
lines(u[, 3], col = "2")

# Panel 2: responses, same layout and colours as the impressions panel.
ylim <- c(min(z), max(z))
plot(z[, 1], type = "l", col = "1", ylab = "z", xlim = xlim, ylim = ylim)
lines(z[, 2], col = "1")
lines(z[, 3], col = "2")

# Panels 3-6: each preference component of x with its upper band (colour "4")
# and lower band (colour "2"). Only the first dimInterestHalf columns are
# shown; the velocity components are not plotted.
for (i in seq_len(dimInterestHalf)) {
  ylim <- c(min(lower[, i]), max(upper[, i]))
  plot(x[, i], type = "l", ylab = interest_elems[[i]], xlim = xlim, ylim = ylim)
  lines(upper[, i], col = "4")
  lines(lower[, i], col = "2")
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ad1 | ad2 | ad3 | |
---|---|---|---|
5 | 0 | 0 | |
10 | 0 | 0 | |
15 | 0 | 0 | |
20 | 0 | 0 | |
25 | 0 | 0 | |
25 | 0 | 5 | |
20 | 0 | 10 | |
15 | 0 | 15 | |
10 | 0 | 20 | |
10 | 5 | 20 | |
11 | 10 | 15 | |
12 | 15 | 10 | |
13 | 20 | 5 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment